Refactor for performance

This commit is contained in:
Jinzhu 2020-06-08 13:45:41 +08:00
parent 13f96f7a15
commit aaf0725771
8 changed files with 123 additions and 130 deletions

View File

@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
return ConvertSliceOfMapToValuesForCreate(stmt, value) return ConvertSliceOfMapToValuesForCreate(stmt, value)
default: default:
var ( var (
values = clause.Values{} values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
curTime = stmt.DB.NowFunc() curTime = stmt.DB.NowFunc()
isZero = false isZero = false
) )
var columns int
for _, db := range stmt.Schema.DBNames { for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: db}) values.Columns[columns] = clause.Column{Name: db}
columns++
} }
} }
} }
values.Columns = values.Columns[:columns]
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:

View File

@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) {
} }
if len(db.Statement.Selects) > 0 { if len(db.Statement.Selects) > 0 {
for _, name := range db.Statement.Selects { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
Name: name,
Raw: true,
})
} else if f := db.Statement.Schema.LookUpField(name); f != nil { } else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
Name: f.DBName,
})
} else { } else {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
Name: name,
Raw: true,
})
} }
} }
} }
// inline joins // inline joins
if len(db.Statement.Joins) != 0 { if len(db.Statement.Joins) != 0 {
joins := []clause.Join{}
if len(db.Statement.Selects) == 0 { if len(db.Statement.Selects) == 0 {
for _, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ for idx, dbName := range db.Statement.Schema.DBNames {
Table: db.Statement.Table, clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
Name: dbName,
})
} }
} }
joins := []clause.Join{}
for name, conds := range db.Statement.Joins { for name, conds := range db.Statement.Joins {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {
joins = append(joins, clause.Join{ joins = append(joins, clause.Join{
@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) {
}) })
} }
var exprs []clause.Expression exprs := make([]clause.Expression, len(relation.References))
for _, ref := range relation.References { for idx, ref := range relation.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}) }
} else { } else {
if ref.PrimaryValue == "" { if ref.PrimaryValue == "" {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}) }
} else { } else {
exprs = append(exprs, clause.Eq{ exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue, Value: ref.PrimaryValue,
}) }
} }
} }
} }
@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) {
} }
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil { if db.Error == nil && len(db.Statement.Preloads) > 0 {
if len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{}
preloadMap := map[string][]string{} for name := range db.Statement.Preloads {
for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".")
preloadFields := strings.Split(name, ".") for idx := range preloadFields {
for idx := range preloadFields { preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] }
}
preloadNames := make([]string, len(preloadMap))
idx := 0
for key := range preloadMap {
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
} }
} }
preloadNames := make([]string, len(preloadMap)) preload(db, rels, db.Statement.Preloads[name])
idx := 0
for key := range preloadMap {
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
}
preload(db, rels, db.Statement.Preloads[name])
}
} }
} }
} }

View File

@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case map[string]interface{}: case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value)) set = make([]clause.Assignment, 0, len(value))
var keys []string keys := make([]string, 0, len(value))
for k := range value { for k := range value {
keys = append(keys, k) keys = append(keys, k)
} }

View File

@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) {
} }
func Assignments(values map[string]interface{}) Set { func Assignments(values map[string]interface{}) Set {
var keys []string keys := make([]string, 0, len(values))
var assignments []Assignment
for key := range values { for key := range values {
keys = append(keys, key) keys = append(keys, key)
} }
sort.Strings(keys) sort.Strings(keys)
for _, key := range keys { assignments := make([]Assignment, len(keys))
assignments = append(assignments, Assignment{ for idx, key := range keys {
Column: Column{Name: key}, assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
Value: values[key],
})
} }
return assignments return assignments
} }

79
gorm.go
View File

@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
} }
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
// Callback returns callback manager // Callback returns callback manager
func (db *DB) Callback() *callbacks { func (db *DB) Callback() *callbacks {
return db.callbacks return db.callbacks
} }
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// AddError add error to db // AddError add error to db
func (db *DB) AddError(err error) error { func (db *DB) AddError(err error) error {
if db.Error == nil { if db.Error == nil {
@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB {
func Expr(expr string, args ...interface{}) clause.Expr { func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args} return clause.Expr{SQL: expr, Vars: args}
} }
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}

View File

@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db) return db.Dialector.Migrator(db)
} }
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// ViewOption view option // ViewOption view option
type ViewOption struct { type ViewOption struct {
Replace bool Replace bool

31
scan.go
View File

@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
default: default:
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
reflectValueType := db.Statement.ReflectValue.Type().Elem() var (
isPtr := reflectValueType.Kind() == reflect.Ptr reflectValueType = db.Statement.ReflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if isPtr { if isPtr {
reflectValueType = reflectValueType.Elem() reflectValueType = reflectValueType.Elem()
} }
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
fields := make([]*schema.Field, len(columns))
joinFields := make([][2]*schema.Field, len(columns))
for idx, column := range columns { for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field fields[idx] = field
@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
} }
// pluck values into slice of data
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
for initialized || rows.Next() { for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
elem := reflect.New(reflectValueType).Elem() elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { if isPluck {
// pluck db.AddError(rows.Scan(elem.Addr().Interface()))
values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...))
} else { } else {
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} }
} }
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
for idx, field := range fields { for idx, field := range fields {
if joinFields[idx][0] != nil { if len(joinFields) != 0 && joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem() value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem) relValue := joinFields[idx][0].ReflectValueOf(elem)
@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() { if initialized || rows.Next() {
for idx, column := range columns { for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue continue
} }
} }

View File

@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error {
} }
// QuoteTo write quoted value to writer // QuoteTo write quoted value to writer
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
if v.Name == clause.CurrentTable { if v.Name == clause.CurrentTable {
@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
case []string: case []string:
writer.WriteByte('(') writer.WriteByte('(')
for idx, d := range v { for idx, d := range v {
if idx != 0 { if idx > 0 {
writer.WriteString(",") writer.WriteString(",")
} }
stmt.DB.Dialector.QuoteTo(writer, d) stmt.DB.Dialector.QuoteTo(writer, d)
@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
} }
// Quote returns quoted value // Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string { func (stmt *Statement) Quote(field interface{}) string {
var builder strings.Builder var builder strings.Builder
stmt.QuoteTo(&builder, field) stmt.QuoteTo(&builder, field)
return builder.String() return builder.String()
@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
} }
// BuildCondition build condition // BuildCondition build condition
func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
if sql, ok := query.(string); ok { if sql, ok := query.(string); ok {
// if it is a number, then treats it as primary key // if it is a number, then treats it as primary key
if _, err := strconv.Atoi(sql); err != nil { if _, err := strconv.Atoi(sql); err != nil {