diff --git a/callbacks/associations.go b/callbacks/associations.go index d78bd968..9d5b7c21 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) + rValLen = db.Statement.ReflectValue.Len() + objs = make([]reflect.Value, 0, rValLen) fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr ) @@ -49,22 +50,21 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) - - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } - } - } else { + if reflect.Indirect(obj).Kind() != reflect.Struct { break } + + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } } if elems.Len() > 0 { diff --git a/callbacks/create.go b/callbacks/create.go index 36e165a0..df774349 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) - values.Values = make([][]interface{}, stmt.ReflectValue.Len()) - defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} - if stmt.ReflectValue.Len() == 0 { + rValLen := stmt.ReflectValue.Len() + stmt.SQL.Grow(rValLen * 18) + values.Values = make([][]interface{}, rValLen) + if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } - for i := 0; i < stmt.ReflectValue.Len(); i++ { + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) if !rv.IsValid() { stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) @@ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { - defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) + defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } - defaultValueFieldsHavingValue[field][i] = v + defaultValueFieldsHavingValue[field][i] = rvOfvalue } } } @@ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - values.Values[0] = append(values.Values[0], v) + values.Values[0] = append(values.Values[0], rvOfvalue) } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 08737505..525c0145 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if ok, mode := hasReturning(db, supportReturning); ok { - if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, mode) - rows.Close() - } - } else { + ok, mode := hasReturning(db, supportReturning) + if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } + + return + } + + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() } } } diff --git a/callbacks/preload.go b/callbacks/preload.go index c887c6c0..41405a22 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues := make([]interface{}, len(joinForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { + joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { diff --git a/callbacks/raw.go b/callbacks/raw.go index d594ab39..013e638c 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) - } else { - db.RowsAffected, _ = result.RowsAffected() + return } + + db.RowsAffected, _ = result.RowsAffected() } } diff --git a/callbacks/row.go b/callbacks/row.go index 407c32d7..56be742e 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -7,16 +7,17 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) - - if !db.DryRun { - if isRows, ok := db.Get("rows"); ok && isRows.(bool) { - db.Statement.Settings.Delete("rows") - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } - - db.RowsAffected = -1 + if db.DryRun { + return } + + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } + + db.RowsAffected = -1 } } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index f116d19f..50887ccc 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) { if !db.Config.SkipDefaultTransaction { if _, ok := db.InstanceGet("gorm:started_transaction"); ok { - if db.Error == nil { - db.Commit() - } else { + if db.Error != nil { db.Rollback() + } else { + db.Commit() } + db.Statement.ConnPool = db.ConnPool } } diff --git a/callbacks/update.go b/callbacks/update.go index 8efc3983..1f4960b5 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: if size := stmt.ReflectValue.Len(); size > 0 { var primaryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { + for i := 0; i < size; i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { diff --git a/migrator/migrator.go b/migrator/migrator.go index 95a708de..af1385e2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } func (m Migrator) GetTables() (tableList []string, err error) { - return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). + Scan(&tableList).Error + return } func (m Migrator) CreateTable(values ...interface{}) error { diff --git a/scan.go b/scan.go index 2d0c8fc6..b931aff4 100644 --- a/scan.go +++ b/scan.go @@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re type ScanMode uint8 const ( - ScanInitialized ScanMode = 1 << 0 - ScanUpdate = 1 << 1 - ScanOnConflictDoNothing = 1 << 2 + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) func Scan(rows *sql.Rows, db *DB, mode ScanMode) {