From aaf07257719d4b7e85574ffc6fd6546f364b492e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Jun 2020 13:45:41 +0800 Subject: [PATCH] Refactor for performance --- callbacks/create.go | 7 ++- callbacks/query.go | 108 ++++++++++++++++++++------------------------ callbacks/update.go | 2 +- clause/set.go | 13 ++---- gorm.go | 79 +++++++++++++++----------------- migrator.go | 5 ++ scan.go | 31 ++++++++----- statement.go | 8 ++-- 8 files changed, 123 insertions(+), 130 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ec4ee1d1..6dc3f10a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { return ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( - values = clause.Values{} + values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))} selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero = false ) + var columns int for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { - values.Columns = append(values.Columns, clause.Column{Name: db}) + values.Columns[columns] = clause.Column{Name: db} + columns++ } } } + values.Columns = values.Columns[:columns] switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/callbacks/query.go b/callbacks/query.go index 41f09375..571c7245 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) { } if len(db.Statement.Selects) > 0 { - for _, name := range db.Statement.Selects { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { if db.Statement.Schema == nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } else if f := db.Statement.Schema.LookUpField(name); f != nil { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: f.DBName, - }) + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} } else { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Name: name, - Raw: true, - }) + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } } } // inline joins if len(db.Statement.Joins) != 0 { - joins := []clause.Join{} - if len(db.Statement.Selects) == 0 { - for _, dbName := range db.Statement.Schema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: db.Statement.Table, - Name: dbName, - }) + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } + joins := []clause.Join{} for name, conds := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ @@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) { }) } - var exprs []clause.Expression - for _, ref := range relation.References { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { if ref.OwnPrimaryKey { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - }) + } } else { if ref.PrimaryValue == "" { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - }) + } } else { - exprs = append(exprs, clause.Eq{ + exprs[idx] = clause.Eq{ Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, - }) + } } } } @@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) { } func Preload(db *gorm.DB) { - if db.Error == nil { - if len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if db.Error == nil && len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } + } + + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + rels[idx] = rel + curSchema = rel.FieldSchema + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) } } - preloadNames := make([]string, len(preloadMap)) - idx := 0 - for key := range preloadMap { - preloadNames[idx] = key - idx++ - } - sort.Strings(preloadNames) - - for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } - } - - preload(db, rels, db.Statement.Preloads[name]) - } + preload(db, rels, db.Statement.Preloads[name]) } } } diff --git a/callbacks/update.go b/callbacks/update.go index f5287dc6..4ef33598 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) - var keys []string + keys := make([]string, 0, len(value)) for k := range value { keys = append(keys, k) } diff --git a/clause/set.go b/clause/set.go index 2d3965d3..1c2a9ef2 100644 --- a/clause/set.go +++ b/clause/set.go @@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) { } func Assignments(values map[string]interface{}) Set { - var keys []string - var assignments []Assignment - + keys := make([]string, 0, len(values)) for key := range values { keys = append(keys, key) } - sort.Strings(keys) - for _, key := range keys { - assignments = append(assignments, Assignment{ - Column: Column{Name: key}, - Value: values[key], - }) + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} } return assignments } diff --git a/gorm.go b/gorm.go index cea744f7..0de6860b 100644 --- a/gorm.go +++ b/gorm.go @@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } -func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { - var ( - tx = db.getInstance() - stmt = tx.Statement - modelSchema, joinSchema *schema.Schema - ) - - if err := stmt.Parse(model); err == nil { - modelSchema = stmt.Schema - } else { - return err - } - - if err := stmt.Parse(joinTable); err == nil { - joinSchema = stmt.Schema - } else { - return err - } - - if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { - for _, ref := range relation.References { - if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { - f.DataType = ref.ForeignKey.DataType - ref.ForeignKey = f - } else { - return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) - } - } - - relation.JoinTable = joinSchema - } else { - return fmt.Errorf("failed to found relation: %v", field) - } - - return nil -} - // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks } -// AutoMigrate run auto migration for given models -func (db *DB) AutoMigrate(dst ...interface{}) error { - return db.Migrator().AutoMigrate(dst...) -} - // AddError add error to db func (db *DB) AddError(err error) error { if db.Error == nil { @@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB { func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } + +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} diff --git a/migrator.go b/migrator.go index 865a08ef..d45e3ac2 100644 --- a/migrator.go +++ b/migrator.go @@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator { return db.Dialector.Migrator(db) } +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + // ViewOption view option type ViewOption struct { Replace bool diff --git a/scan.go b/scan.go index acba4e9f..f1cdb2e5 100644 --- a/scan.go +++ b/scan.go @@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - reflectValueType := db.Statement.ReflectValue.Type().Elem() - isPtr := reflectValueType.Kind() == reflect.Ptr + var ( + reflectValueType = db.Statement.ReflectValue.Type().Elem() + isPtr = reflectValueType.Kind() == reflect.Ptr + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + ) + if isPtr { reflectValueType = reflectValueType.Elem() } db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) - fields := make([]*schema.Field, len(columns)) - joinFields := make([][2]*schema.Field, len(columns)) for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field @@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } + // pluck values into slice of data + isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct for initialized || rows.Next() { initialized = false db.RowsAffected++ elem := reflect.New(reflectValueType).Elem() - if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { - // pluck - values[0] = elem.Addr().Interface() - db.AddError(rows.Scan(values...)) + if isPluck { + db.AddError(rows.Scan(elem.Addr().Interface())) } else { for idx, field := range fields { if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } } db.AddError(rows.Scan(values...)) for idx, field := range fields { - if joinFields[idx][0] != nil { + if len(joinFields) != 0 && joinFields[idx][0] != nil { value := reflect.ValueOf(values[idx]).Elem() relValue := joinFields[idx][0].ReflectValueOf(elem) @@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if initialized || rows.Next() { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() continue } } diff --git a/statement.go b/statement.go index 614a3ad3..e0e86019 100644 --- a/statement.go +++ b/statement.go @@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error { } // QuoteTo write quoted value to writer -func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { @@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { case []string: writer.WriteByte('(') for idx, d := range v { - if idx != 0 { + if idx > 0 { writer.WriteString(",") } stmt.DB.Dialector.QuoteTo(writer, d) @@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { } // Quote returns quoted value -func (stmt Statement) Quote(field interface{}) string { +func (stmt *Statement) Quote(field interface{}) string { var builder strings.Builder stmt.QuoteTo(&builder, field) return builder.String() @@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondition build condition -func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { if sql, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(sql); err != nil {