diff --git a/callbacks.go b/callbacks.go index a9a6dd85..4f700081 100644 --- a/callbacks.go +++ b/callbacks.go @@ -105,8 +105,11 @@ func (p *processor) Execute(db *DB) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) - stmt.reinit() - // db.Config.statementPool.Put(stmt) + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + stmt.NamedVars = nil + } } } diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go new file mode 100644 index 00000000..a0e9b0e7 --- /dev/null +++ b/callbacks/callmethod.go @@ -0,0 +1,21 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{}) + if called := fc(db.Statement.Dest, tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) + } + case reflect.Struct: + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } + } +} diff --git a/callbacks/create.go b/callbacks/create.go index 99140612..ec4ee1d1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,9 +10,7 @@ import ( func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { called = true @@ -27,18 +25,7 @@ func BeforeCreate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } @@ -67,28 +54,26 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ - } + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - } else { - db.AddError(err) + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } + } else { + db.AddError(err) } } db.RowsAffected, _ = result.RowsAffected() @@ -122,19 +107,17 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.WriteString(" RETURNING ") var ( - idx int fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) ) - for dbName, field := range sch.FieldsWithDefaultDBValue { - if idx != 0 { + for idx, field := range sch.FieldsWithDefaultDBValue { + if idx > 0 { db.Statement.WriteByte(',') } fields[idx] = field - db.Statement.WriteQuoted(dbName) - idx++ + db.Statement.WriteQuoted(field.DBName) } if !db.DryRun { @@ -149,10 +132,11 @@ func CreateWithReturning(db *gorm.DB) { for idx, field := range fields { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() } + + db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) } - db.RowsAffected++ } case reflect.Struct: for idx, field := range fields { @@ -161,12 +145,10 @@ func CreateWithReturning(db *gorm.DB) { if rows.Next() { db.RowsAffected++ - err = rows.Scan(values...) + db.AddError(rows.Scan(values...)) } } - } - - if err != nil { + } else { db.AddError(err) } } @@ -182,9 +164,7 @@ func CreateWithReturning(db *gorm.DB) { func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { called = true @@ -199,18 +179,7 @@ func AfterCreate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } @@ -230,7 +199,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { ) for _, db := range stmt.Schema.DBNames { - if stmt.Schema.FieldsWithDefaultDBValue[db] == nil { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: db}) } @@ -257,13 +226,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } - for db, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(rv); !isZero { - if len(defaultValueFieldsHavingValue[db]) == 0 { - defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) + if len(defaultValueFieldsHavingValue[field.DBName]) == 0 { + defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len()) } - defaultValueFieldsHavingValue[db][i] = v + defaultValueFieldsHavingValue[field.DBName][i] = v } } } @@ -294,10 +263,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { } } - for db, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { - values.Columns = append(values.Columns, clause.Column{Name: db}) + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], v) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index f1a49c11..b246e69f 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -10,27 +10,14 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.BeforeDelete { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { - db.AddError(i.BeforeDelete(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true } - return false - } - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + return false + }) } } @@ -86,26 +73,12 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.AfterDelete { - if i, ok := value.(gorm.AfterDeleteInterface); ok { - db.AddError(i.AfterDelete(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true } return false - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/callbacks/query.go b/callbacks/query.go index b6667414..41f09375 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -188,26 +188,12 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - if db.Statement.Schema.AfterFind { - if i, ok := value.(gorm.AfterFindInterface); ok { - db.AddError(i.AfterFind(tx)) - return true - } + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(gorm.AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true } return false - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/callbacks/update.go b/callbacks/update.go index 9c922956..a41a3c59 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -30,9 +30,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { called = true @@ -46,19 +44,9 @@ func BeforeUpdate(db *gorm.DB) { db.AddError(i.BeforeUpdate(tx)) } } - return called - } - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + return called + }) } } @@ -99,9 +87,7 @@ func Update(db *gorm.DB) { func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { - tx := db.Session(&gorm.Session{}) - callMethod := func(value interface{}) bool { - var called bool + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { called = true @@ -116,18 +102,7 @@ func AfterUpdate(db *gorm.DB) { } } return called - } - - if ok := callMethod(db.Statement.Dest); !ok { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface()) - } - case reflect.Struct: - callMethod(db.Statement.ReflectValue.Addr().Interface()) - } - } + }) } } diff --git a/gorm.go b/gorm.go index e6a28635..cea744f7 100644 --- a/gorm.go +++ b/gorm.go @@ -25,9 +25,10 @@ type Config struct { NowFunc func() time.Time // DryRun generate sql without execute DryRun bool - // PrepareStmt executes the given query in cached statement PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -93,8 +94,8 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.ClauseBuilders = map[string]clause.ClauseBuilder{} } - if dialector != nil { - err = dialector.Initialize(db) + if config.Dialector != nil { + err = config.Dialector.Initialize(db) } if config.PrepareStmt { @@ -104,16 +105,14 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } } - if db.Statement == nil { - db.Statement = &Statement{ - DB: db, - ConnPool: db.ConnPool, - Context: context.Background(), - Clauses: map[string]clause.Clause{}, - } + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, } - if err == nil { + if err == nil && !config.DisableAutomaticPing { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { err = pinger.Ping() } @@ -138,17 +137,8 @@ func (db *DB) Session(config *Session) *DB { ) if config.Context != nil { - if tx.Statement != nil { - tx.Statement = tx.Statement.clone() - tx.Statement.DB = tx - } else { - tx.Statement = &Statement{ - DB: tx, - Clauses: map[string]clause.Clause{}, - ConnPool: tx.ConnPool, - } - } - + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx tx.Statement.Context = config.Context } @@ -160,7 +150,7 @@ func (db *DB) Session(config *Session) *DB { } if config.WithConditions { - tx.clone = 3 + tx.clone = 2 } if config.DryRun { @@ -200,10 +190,7 @@ func (db *DB) Set(key string, value interface{}) *DB { // Get get value with key from current db instance's context func (db *DB) Get(key string) (interface{}, bool) { - if db.Statement != nil { - return db.Statement.Settings.Load(key) - } - return nil, false + return db.Statement.Settings.Load(key) } // InstanceSet store value with key into current db instance's context @@ -215,10 +202,7 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB { // InstanceGet get value with key from current db instance's context func (db *DB) InstanceGet(key string) (interface{}, bool) { - if db.Statement != nil { - return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) - } - return nil, false + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { @@ -282,22 +266,18 @@ func (db *DB) getInstance() *DB { if db.clone > 0 { tx := &DB{Config: db.Config} - switch db.clone { - case 1: // clone with new statement + if db.clone == 1 { + // clone with new statement tx.Statement = &Statement{ DB: tx, ConnPool: db.Statement.ConnPool, Context: db.Statement.Context, Clauses: map[string]clause.Clause{}, } - case 2: // with old statement, generate new statement for future call, used to pass to callbacks - db.clone = 1 - tx.Statement = db.Statement - case 3: // with clone statement - if db.Statement != nil { - tx.Statement = db.Statement.clone() - tx.Statement.DB = tx - } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx } return tx diff --git a/migrator/migrator.go b/migrator/migrator.go index afef65c3..18b2593d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -62,7 +62,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " UNIQUE" } - if field.HasDefaultValue { + if field.HasDefaultValue && field.DefaultValue != "" { if field.DataType == schema.String { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) diff --git a/schema/field_test.go b/schema/field_test.go index cc4b53fc..0936c0d1 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, diff --git a/schema/schema.go b/schema/schema.go index 9e05303a..d2c4d08b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -26,7 +26,7 @@ type Schema struct { Fields []*Field FieldsByName map[string]*Field FieldsByDBName map[string]*Field - FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database + FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships CreateClauses []clause.Interface QueryClauses []clause.Interface @@ -153,23 +153,14 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) schema.FieldsByName[field.Name] = field if v != nil && v.PrimaryKey { - if schema.PrioritizedPrimaryField == v { - schema.PrioritizedPrimaryField = nil - } - for idx, f := range schema.PrimaryFields { if f == v { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) - } else if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = f } } } if field.PrimaryKey { - if schema.PrioritizedPrimaryField == nil { - schema.PrioritizedPrimaryField = field - } schema.PrimaryFields = append(schema.PrimaryFields, field) } } @@ -192,21 +183,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } + for _, field := range schema.PrimaryFields { schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } - schema.FieldsWithDefaultDBValue = map[string]*Field{} - for db, field := range schema.FieldsByDBName { + for _, field := range schema.FieldsByDBName { if field.HasDefaultValue && field.DefaultValueInterface == nil { - schema.FieldsWithDefaultDBValue[db] = field + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } - if schema.PrioritizedPrimaryField != nil { - switch schema.PrioritizedPrimaryField.DataType { + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.DataType { case Int, Uint: - schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + field.HasDefaultValue = true } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 82f07fa8..4ec7ff0c 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, @@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { // check fields fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, diff --git a/statement.go b/statement.go index e3f324b9..2c814547 100644 --- a/statement.go +++ b/statement.go @@ -226,6 +226,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con if sql == "" && len(args) == 0 { return } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { + // looks like a where condition return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} } else if len(args) == 1 { return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} @@ -242,12 +243,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con switch v := arg.(type) { case clause.Expression: conds = append(conds, v) - case *DB: - if v.Statement != nil { - if cs, ok := v.Statement.Clauses["WHERE"]; ok { - conds = append(conds, cs.Expression) - } - } case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) @@ -326,7 +321,6 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ - DB: stmt.DB, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, @@ -357,37 +351,3 @@ func (stmt *Statement) clone() *Statement { return newStmt } - -func (stmt *Statement) reinit() { - // stmt.Table = "" - // stmt.Model = nil - // stmt.Selects = nil - // stmt.Omits = nil - // stmt.ConnPool = stmt.DB.Config.ConnPool - // stmt.Context = context.Background() - // stmt.RaiseErrorOnNotFound = false - - // for k := range stmt.Clauses { - // delete(stmt.Clauses, k) - // } - - // for k := range stmt.Joins { - // delete(stmt.Joins, k) - // } - - // for k := range stmt.Preloads { - // delete(stmt.Preloads, k) - // } - - // stmt.Settings.Range(func(k, _ interface{}) bool { - // stmt.Settings.Delete(k) - // return true - // }) - - // stmt.Schema = nil - if !stmt.DB.DryRun { - stmt.SQL.Reset() - stmt.Vars = nil - stmt.NamedVars = nil - } -} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go new file mode 100644 index 00000000..c6ce93a2 --- /dev/null +++ b/tests/benchmark_test.go @@ -0,0 +1,44 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func BenchmarkCreate(b *testing.B) { + var user = *GetUser("bench", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + } +} + +func BenchmarkFind(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Find(&User{}, "id = ?", user.ID) + } +} + +func BenchmarkUpdate(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Model(&user).Updates(map[string]interface{}{"Age": x}) + } +} + +func BenchmarkDelete(b *testing.B) { + var user = *GetUser("find", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + DB.Delete(&user) + } +} diff --git a/tests/go.mod b/tests/go.mod index de58a0de..3c2dfc6c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v1.0.0 - gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 + gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf gorm.io/gorm v0.0.0-00010101000000-000000000000 )