Refactor for performance

This commit is contained in:
Jinzhu 2020-06-08 13:45:41 +08:00
parent 13f96f7a15
commit aaf0725771
8 changed files with 123 additions and 130 deletions

View File

@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
return ConvertSliceOfMapToValuesForCreate(stmt, value)
default:
var (
values = clause.Values{}
values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
curTime = stmt.DB.NowFunc()
isZero = false
)
var columns int
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: db})
values.Columns[columns] = clause.Column{Name: db}
columns++
}
}
}
values.Columns = values.Columns[:columns]
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:

View File

@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) {
}
if len(db.Statement.Selects) > 0 {
for _, name := range db.Statement.Selects {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: name,
Raw: true,
})
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: f.DBName,
})
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
} else {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: name,
Raw: true,
})
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
}
}
}
// inline joins
if len(db.Statement.Joins) != 0 {
joins := []clause.Join{}
if len(db.Statement.Selects) == 0 {
for _, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: db.Statement.Table,
Name: dbName,
})
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
joins := []clause.Join{}
for name, conds := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) {
})
}
var exprs []clause.Expression
for _, ref := range relation.References {
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs = append(exprs, clause.Eq{
exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
})
}
} else {
if ref.PrimaryValue == "" {
exprs = append(exprs, clause.Eq{
exprs[idx] = clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
})
}
} else {
exprs = append(exprs, clause.Eq{
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
}
}
@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) {
}
func Preload(db *gorm.DB) {
if db.Error == nil {
if len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
for idx := range preloadFields {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
for idx := range preloadFields {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
}
}
preloadNames := make([]string, len(preloadMap))
idx := 0
for key := range preloadMap {
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
}
preloadNames := make([]string, len(preloadMap))
idx := 0
for key := range preloadMap {
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
}
preload(db, rels, db.Statement.Preloads[name])
}
preload(db, rels, db.Statement.Preloads[name])
}
}
}

View File

@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value))
var keys []string
keys := make([]string, 0, len(value))
for k := range value {
keys = append(keys, k)
}

View File

@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) {
}
func Assignments(values map[string]interface{}) Set {
var keys []string
var assignments []Assignment
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
assignments = append(assignments, Assignment{
Column: Column{Name: key},
Value: values[key],
})
assignments := make([]Assignment, len(keys))
for idx, key := range keys {
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
}
return assignments
}

79
gorm.go
View File

@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
// Callback returns callback manager
func (db *DB) Callback() *callbacks {
return db.callbacks
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// AddError add error to db
func (db *DB) AddError(err error) error {
if db.Error == nil {
@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB {
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}

View File

@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator {
return db.Dialector.Migrator(db)
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// ViewOption view option
type ViewOption struct {
Replace bool

31
scan.go
View File

@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
default:
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
reflectValueType := db.Statement.ReflectValue.Type().Elem()
isPtr := reflectValueType.Kind() == reflect.Ptr
var (
reflectValueType = db.Statement.ReflectValue.Type().Elem()
isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
)
if isPtr {
reflectValueType = reflectValueType.Elem()
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
fields := make([]*schema.Field, len(columns))
joinFields := make([][2]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
}
// pluck values into slice of data
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
// pluck
values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...))
if isPluck {
db.AddError(rows.Scan(elem.Addr().Interface()))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
}
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if joinFields[idx][0] != nil {
if len(joinFields) != 0 && joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
if initialized || rows.Next() {
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}

View File

@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error {
}
// QuoteTo write quoted value to writer
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx != 0 {
if idx > 0 {
writer.WriteString(",")
}
stmt.DB.Dialector.QuoteTo(writer, d)
@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
}
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
func (stmt *Statement) Quote(field interface{}) string {
var builder strings.Builder
stmt.QuoteTo(&builder, field)
return builder.String()
@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
}
// BuildCondition build condition
func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
if sql, ok := query.(string); ok {
// if it is a number, then treats it as primary key
if _, err := strconv.Atoi(sql); err != nil {