feat: go code style adjust and optimize code for callbacks package (#4861)

* feat: go code style adjust and optimize code for callbacks package

* Update scan.go
This commit is contained in:
heige 2021-11-29 09:33:20 +08:00 committed by GitHub
parent b8f33a42a4
commit 9d5f315b6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 62 additions and 52 deletions

View File

@ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var (
objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) rValLen = db.Statement.ReflectValue.Len()
objs = make([]reflect.Value, 0, rValLen)
fieldType = rel.Field.FieldType fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr isPtr = fieldType.Kind() == reflect.Ptr
) )
@ -49,10 +50,12 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
break
}
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value rv := rel.Field.ReflectValueOf(obj) // relation reflect value
objs = append(objs, obj) objs = append(objs, obj)
@ -62,9 +65,6 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
elems = reflect.Append(elems, rv.Addr()) elems = reflect.Append(elems, rv.Addr())
} }
} }
} else {
break
}
} }
if elems.Len() > 0 { if elems.Len() > 0 {

View File

@ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) rValLen := stmt.ReflectValue.Len()
values.Values = make([][]interface{}, stmt.ReflectValue.Len()) stmt.SQL.Grow(rValLen * 18)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} values.Values = make([][]interface{}, rValLen)
if stmt.ReflectValue.Len() == 0 { if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice) stmt.AddError(gorm.ErrEmptySlice)
return return
} }
for i := 0; i < stmt.ReflectValue.Len(); i++ { defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() { if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
@ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero { if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 { if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
} }
defaultValueFieldsHavingValue[field][i] = v defaultValueFieldsHavingValue[field][i] = rvOfvalue
} }
} }
} }
@ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], v) values.Values[0] = append(values.Values[0], rvOfvalue)
} }
} }
} }

View File

@ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok { ok, mode := hasReturning(db, supportReturning)
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { if !ok {
gorm.Scan(rows, db, mode)
rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil { if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} }
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
rows.Close()
} }
} }
} }

View File

@ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
fieldValues := make([]interface{}, len(joinForeignFields)) fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) fieldValues[idx], _ = field.ValueOf(joinIndexValue)
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {

View File

@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
} else { return
}
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} }
}
} }

View File

@ -7,8 +7,10 @@ import (
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
BuildQuerySQL(db) BuildQuerySQL(db)
if db.DryRun {
return
}
if !db.DryRun {
if isRows, ok := db.Get("rows"); ok && isRows.(bool) { if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Settings.Delete("rows") db.Statement.Settings.Delete("rows")
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -18,5 +20,4 @@ func RowQuery(db *gorm.DB) {
db.RowsAffected = -1 db.RowsAffected = -1
} }
}
} }

View File

@ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) {
func CommitOrRollbackTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction { if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok { if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil { if db.Error != nil {
db.Commit()
} else {
db.Rollback() db.Rollback()
} else {
db.Commit()
} }
db.Statement.ConnPool = db.ConnPool db.Statement.ConnPool = db.ConnPool
} }
} }

View File

@ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 { if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < size; i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool var notZero bool
for idx, field := range stmt.Schema.PrimaryFields { for idx, field := range stmt.Schema.PrimaryFields {

View File

@ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
} }
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {

View File

@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
type ScanMode uint8 type ScanMode uint8
const ( const (
ScanInitialized ScanMode = 1 << 0 ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate = 1 << 1 ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing = 1 << 2 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
) )
func Scan(rows *sql.Rows, db *DB, mode ScanMode) { func Scan(rows *sql.Rows, db *DB, mode ScanMode) {