diff --git a/callbacks/associations.go b/callbacks/associations.go index 64d79f24..1e6f62c5 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -46,7 +46,7 @@ func SaveBeforeAssociations(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) @@ -109,7 +109,7 @@ func SaveAfterAssociations(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) @@ -181,7 +181,7 @@ func SaveAfterAssociations(db *gorm.DB) { if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(v)) @@ -241,8 +241,8 @@ func SaveAfterAssociations(db *gorm.DB) { if !isPtr { fieldType = reflect.PtrTo(fieldType) } - elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) - joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0) + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} appendToJoins := func(obj reflect.Value, elem reflect.Value) { diff --git a/callbacks/create.go b/callbacks/create.go index 8e2454e8..67f3ab14 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -57,7 +57,7 @@ func Create(config *Config) func(db *gorm.DB) { db.RowsAffected, _ = result.RowsAffected() if db.RowsAffected > 0 { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { @@ -87,11 +87,8 @@ func Create(config *Config) func(db *gorm.DB) { } } case reflect.Struct: - if insertID > 0 { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } } } else { @@ -253,7 +250,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) + stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} if stmt.ReflectValue.Len() == 0 { diff --git a/callbacks/helper.go b/callbacks/helper.go index 09ec4582..3ac63fa1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -12,7 +12,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) - var keys []string + var keys = make([]string, 0, len(mapValue)) for k := range mapValue { keys = append(keys, k) } @@ -41,7 +41,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { var ( - columns = []string{} + columns = make([]string, 0, len(mapValues)) result = map[string][]interface{}{} selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) diff --git a/callbacks/preload.go b/callbacks/preload.go index aec10ec5..d60079e4 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -112,7 +112,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) } @@ -120,7 +120,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } diff --git a/gorm.go b/gorm.go index affa8e69..2dfbb855 100644 --- a/gorm.go +++ b/gorm.go @@ -286,6 +286,7 @@ func (db *DB) getInstance() *DB { ConnPool: db.Statement.ConnPool, Context: db.Statement.Context, Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), } } else { // with clone statement diff --git a/scan.go b/scan.go index 8d737b17..c9c8f442 100644 --- a/scan.go +++ b/scan.go @@ -106,7 +106,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { reflectValueType = reflectValueType.Elem() } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { @@ -117,13 +117,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { if field := Schema.LookUpField(column); field != nil && field.Readable { fields[idx] = field } else if names := strings.Split(column, "__"); len(names) > 1 { - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } joinFields[idx] = [2]*schema.Field{rel.Field, field} continue } @@ -138,9 +138,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { // pluck values into slice of data isPluck := false if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { - isPluck = true - } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time isPluck = true } } @@ -149,9 +149,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { initialized = false db.RowsAffected++ - elem := reflect.New(reflectValueType).Elem() + elem := reflect.New(reflectValueType) if isPluck { - db.AddError(rows.Scan(elem.Addr().Interface())) + db.AddError(rows.Scan(elem.Interface())) } else { for idx, field := range fields { if field != nil { @@ -181,9 +181,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) - } else { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) } } case reflect.Struct: @@ -216,8 +216,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { diff --git a/schema/schema.go b/schema/schema.go index cffc19a7..05db641f 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -50,7 +50,7 @@ func (schema Schema) String() string { } func (schema Schema) MakeSlice() reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 0) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) return results diff --git a/schema/utils.go b/schema/utils.go index 55cbdeb4..6e5fd528 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -61,7 +61,7 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 0) + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { if _, isZero := rel.Field.ValueOf(value); !isZero { diff --git a/statement.go b/statement.go index 82ebdd91..7c0af59c 100644 --- a/statement.go +++ b/statement.go @@ -239,12 +239,12 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondition build condition -func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) { +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { if s, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { - return + return nil } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} @@ -257,6 +257,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } + conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) for _, arg := range args { if valuer, ok := arg.(driver.Valuer); ok { @@ -358,7 +359,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) } - return + return conds } } @@ -367,7 +368,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c } } - return + return conds } // Build build sql with clauses names