From 4c8810a8484df2ed450e41913c886b54367a3969 Mon Sep 17 00:00:00 2001 From: heige Date: Thu, 4 Nov 2021 13:45:44 +0800 Subject: [PATCH] Refactor if logic (#4683) * adjust code for preload * adjust code for Create --- callbacks/create.go | 119 ++++++++++++++++++++---------------- callbacks/delete.go | 139 ++++++++++++++++++++++--------------------- callbacks/preload.go | 39 ++++++------ 3 files changed, 160 insertions(+), 137 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 656273fb..36e165a0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -65,66 +65,81 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build(db.Statement.BuildClauses...) } - if !db.DryRun && db.Error == nil { + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } - if ok, mode := hasReturning(db, supportReturning); ok { - if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { - mode |= gorm.ScanOnConflictDoNothing - } + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing } - if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, mode) - rows.Close() - } - } else { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } - if err != nil { - db.AddError(err) - return - } + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() + } - db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + return + } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } - } else { - db.AddError(err) } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index a1fd0a57..08737505 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -26,82 +26,87 @@ func BeforeDelete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } - if restricted { - for column, v := range selectColumns { - if v { - if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { - switch rel.Type { - case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) - withoutConditions := false - if db.Statement.Unscoped { - tx = tx.Unscoped() - } + for column, v := range selectColumns { + if !v { + continue + } - if len(db.Statement.Selects) > 0 { - selects := make([]string, 0, len(db.Statement.Selects)) - for _, s := range db.Statement.Selects { - if s == clause.Associations { - selects = append(selects, s) - } else if strings.HasPrefix(s, column+".") { - selects = append(selects, strings.TrimPrefix(s, column+".")) - } - } + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } - if len(selects) > 0 { - tx = tx.Select(selects) - } - } + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } - for _, cond := range queryConds { - if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { - withoutConditions = true - break - } - } - - if !withoutConditions { - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } - } - case schema.Many2Many: - var ( - queryConds = make([]clause.Expression, 0, len(rel.References)) - foreignFields = make([]*schema.Field, 0, len(rel.References)) - relForeignKeys = make([]string, 0, len(rel.References)) - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) - ) - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - queryConds = append(queryConds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) - column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) - queryConds = append(queryConds, clause.IN{Column: column, Values: values}) - - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return } } } + } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9882590c..c887c6c0 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -145,27 +145,30 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues[idx], _ = field.ValueOf(elem) } - if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { - for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) - if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { - reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) - } + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", + elem.Interface())) + continue + } - reflectFieldValue = reflect.Indirect(reflectFieldValue) - switch reflectFieldValue.Kind() { - case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) - case reflect.Slice, reflect.Array: - if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) - } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) - } + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } - } else { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) } } }