mirror of https://github.com/go-gorm/gorm.git
Refactor for performance
This commit is contained in:
parent
13f96f7a15
commit
aaf0725771
|
@ -192,19 +192,22 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|||
return ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
default:
|
||||
var (
|
||||
values = clause.Values{}
|
||||
values = clause.Values{Columns: make([]clause.Column, len(stmt.Schema.DBNames))}
|
||||
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero = false
|
||||
)
|
||||
|
||||
var columns int
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
values.Columns[columns] = clause.Column{Name: db}
|
||||
columns++
|
||||
}
|
||||
}
|
||||
}
|
||||
values.Columns = values.Columns[:columns]
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
|
|
|
@ -53,38 +53,28 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
for _, name := range db.Statement.Selects {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||
for idx, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Name: name,
|
||||
Raw: true,
|
||||
})
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Name: f.DBName,
|
||||
})
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||
} else {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Name: name,
|
||||
Raw: true,
|
||||
})
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
if len(db.Statement.Joins) != 0 {
|
||||
joins := []clause.Join{}
|
||||
|
||||
if len(db.Statement.Selects) == 0 {
|
||||
for _, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: db.Statement.Table,
|
||||
Name: dbName,
|
||||
})
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
joins := []clause.Join{}
|
||||
for name, conds := range db.Statement.Joins {
|
||||
if db.Statement.Schema == nil {
|
||||
joins = append(joins, clause.Join{
|
||||
|
@ -101,24 +91,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
})
|
||||
}
|
||||
|
||||
var exprs []clause.Expression
|
||||
for _, ref := range relation.References {
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs = append(exprs, clause.Eq{
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs = append(exprs, clause.Eq{
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
exprs = append(exprs, clause.Eq{
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -146,42 +136,40 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string][]string{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
for idx := range preloadFields {
|
||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string][]string{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
for idx := range preloadFields {
|
||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
||||
}
|
||||
}
|
||||
|
||||
preloadNames := make([]string, len(preloadMap))
|
||||
idx := 0
|
||||
for key := range preloadMap {
|
||||
preloadNames[idx] = key
|
||||
idx++
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
var (
|
||||
curSchema = db.Statement.Schema
|
||||
preloadFields = preloadMap[name]
|
||||
rels = make([]*schema.Relationship, len(preloadFields))
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||
rels[idx] = rel
|
||||
curSchema = rel.FieldSchema
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||
}
|
||||
}
|
||||
|
||||
preloadNames := make([]string, len(preloadMap))
|
||||
idx := 0
|
||||
for key := range preloadMap {
|
||||
preloadNames[idx] = key
|
||||
idx++
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
var (
|
||||
curSchema = db.Statement.Schema
|
||||
preloadFields = preloadMap[name]
|
||||
rels = make([]*schema.Relationship, len(preloadFields))
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
||||
rels[idx] = rel
|
||||
curSchema = rel.FieldSchema
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
||||
}
|
||||
}
|
||||
|
||||
preload(db, rels, db.Statement.Preloads[name])
|
||||
}
|
||||
preload(db, rels, db.Statement.Preloads[name])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||
case map[string]interface{}:
|
||||
set = make([]clause.Assignment, 0, len(value))
|
||||
|
||||
var keys []string
|
||||
keys := make([]string, 0, len(value))
|
||||
for k := range value {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
|
|
@ -38,20 +38,15 @@ func (set Set) MergeClause(clause *Clause) {
|
|||
}
|
||||
|
||||
func Assignments(values map[string]interface{}) Set {
|
||||
var keys []string
|
||||
var assignments []Assignment
|
||||
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
assignments = append(assignments, Assignment{
|
||||
Column: Column{Name: key},
|
||||
Value: values[key],
|
||||
})
|
||||
assignments := make([]Assignment, len(keys))
|
||||
for idx, key := range keys {
|
||||
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
|
||||
}
|
||||
return assignments
|
||||
}
|
||||
|
|
79
gorm.go
79
gorm.go
|
@ -205,53 +205,11 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
|||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
if err := stmt.Parse(model); err == nil {
|
||||
modelSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := stmt.Parse(joinTable); err == nil {
|
||||
joinSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||
for _, ref := range relation.References {
|
||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
relation.JoinTable = joinSchema
|
||||
} else {
|
||||
return fmt.Errorf("failed to found relation: %v", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models
|
||||
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
return db.Migrator().AutoMigrate(dst...)
|
||||
}
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if db.Error == nil {
|
||||
|
@ -289,3 +247,40 @@ func (db *DB) getInstance() *DB {
|
|||
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||
return clause.Expr{SQL: expr, Vars: args}
|
||||
}
|
||||
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
if err := stmt.Parse(model); err == nil {
|
||||
modelSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := stmt.Parse(joinTable); err == nil {
|
||||
joinSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||
for _, ref := range relation.References {
|
||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
relation.JoinTable = joinSchema
|
||||
} else {
|
||||
return fmt.Errorf("failed to found relation: %v", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,11 @@ func (db *DB) Migrator() Migrator {
|
|||
return db.Dialector.Migrator(db)
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models
|
||||
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
return db.Migrator().AutoMigrate(dst...)
|
||||
}
|
||||
|
||||
// ViewOption view option
|
||||
type ViewOption struct {
|
||||
Replace bool
|
||||
|
|
31
scan.go
31
scan.go
|
@ -71,20 +71,27 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
default:
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
reflectValueType := db.Statement.ReflectValue.Type().Elem()
|
||||
isPtr := reflectValueType.Kind() == reflect.Ptr
|
||||
var (
|
||||
reflectValueType = db.Statement.ReflectValue.Type().Elem()
|
||||
isPtr = reflectValueType.Kind() == reflect.Ptr
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
)
|
||||
|
||||
if isPtr {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
|
||||
fields := make([]*schema.Field, len(columns))
|
||||
joinFields := make([][2]*schema.Field, len(columns))
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
|
||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
@ -98,26 +105,26 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
}
|
||||
|
||||
// pluck values into slice of data
|
||||
isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
|
||||
elem := reflect.New(reflectValueType).Elem()
|
||||
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
||||
// pluck
|
||||
values[0] = elem.Addr().Interface()
|
||||
db.AddError(rows.Scan(values...))
|
||||
if isPluck {
|
||||
db.AddError(rows.Scan(elem.Addr().Interface()))
|
||||
} else {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
for idx, field := range fields {
|
||||
if joinFields[idx][0] != nil {
|
||||
if len(joinFields) != 0 && joinFields[idx][0] != nil {
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||
|
||||
|
@ -145,11 +152,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
if initialized || rows.Next() {
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ func (stmt *Statement) WriteQuoted(value interface{}) error {
|
|||
}
|
||||
|
||||
// QuoteTo write quoted value to writer
|
||||
func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
switch v := field.(type) {
|
||||
case clause.Table:
|
||||
if v.Name == clause.CurrentTable {
|
||||
|
@ -109,7 +109,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
case []string:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx != 0 {
|
||||
if idx > 0 {
|
||||
writer.WriteString(",")
|
||||
}
|
||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||
|
@ -121,7 +121,7 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt Statement) Quote(field interface{}) string {
|
||||
func (stmt *Statement) Quote(field interface{}) string {
|
||||
var builder strings.Builder
|
||||
stmt.QuoteTo(&builder, field)
|
||||
return builder.String()
|
||||
|
@ -219,7 +219,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
|||
}
|
||||
|
||||
// BuildCondition build condition
|
||||
func (stmt Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
|
||||
if sql, ok := query.(string); ok {
|
||||
// if it is a number, then treats it as primary key
|
||||
if _, err := strconv.Atoi(sql); err != nil {
|
||||
|
|
Loading…
Reference in New Issue