Refactor codebase and add benchmark test

This commit is contained in:
Jinzhu 2020-06-05 21:23:20 +08:00
parent 163200d05f
commit 1490a062db
14 changed files with 168 additions and 260 deletions

View File

@ -105,8 +105,11 @@ func (p *processor) Execute(db *DB) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error) }, db.Error)
stmt.reinit() if !stmt.DB.DryRun {
// db.Config.statementPool.Put(stmt) stmt.SQL.Reset()
stmt.Vars = nil
stmt.NamedVars = nil
}
} }
} }

21
callbacks/callmethod.go Normal file
View File

@ -0,0 +1,21 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
)
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{})
if called := fc(db.Statement.Dest, tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx)
}
case reflect.Struct:
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
}
}
}

View File

@ -10,9 +10,7 @@ import (
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var called bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(gorm.BeforeSaveInterface); ok {
called = true called = true
@ -27,18 +25,7 @@ func BeforeCreate(db *gorm.DB) {
} }
} }
return called return called
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }
@ -67,28 +54,26 @@ func Create(config *Config) func(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { if insertID, err := result.LastInsertId(); err == nil {
if insertID, err := result.LastInsertId(); err == nil { switch db.Statement.ReflectValue.Kind() {
switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array:
case reflect.Slice, reflect.Array: if config.LastInsertIDReversed {
if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) insertID--
insertID-- }
} } else {
} else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) insertID++
insertID++
}
} }
case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
} }
} else { case reflect.Struct:
db.AddError(err) db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
} }
} else {
db.AddError(err)
} }
} }
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
@ -122,19 +107,17 @@ func CreateWithReturning(db *gorm.DB) {
db.Statement.WriteString(" RETURNING ") db.Statement.WriteString(" RETURNING ")
var ( var (
idx int
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
) )
for dbName, field := range sch.FieldsWithDefaultDBValue { for idx, field := range sch.FieldsWithDefaultDBValue {
if idx != 0 { if idx > 0 {
db.Statement.WriteByte(',') db.Statement.WriteByte(',')
} }
fields[idx] = field fields[idx] = field
db.Statement.WriteQuoted(dbName) db.Statement.WriteQuoted(field.DBName)
idx++
} }
if !db.DryRun { if !db.DryRun {
@ -149,10 +132,11 @@ func CreateWithReturning(db *gorm.DB) {
for idx, field := range fields { for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
} }
db.RowsAffected++
if err := rows.Scan(values...); err != nil { if err := rows.Scan(values...); err != nil {
db.AddError(err) db.AddError(err)
} }
db.RowsAffected++
} }
case reflect.Struct: case reflect.Struct:
for idx, field := range fields { for idx, field := range fields {
@ -161,12 +145,10 @@ func CreateWithReturning(db *gorm.DB) {
if rows.Next() { if rows.Next() {
db.RowsAffected++ db.RowsAffected++
err = rows.Scan(values...) db.AddError(rows.Scan(values...))
} }
} }
} } else {
if err != nil {
db.AddError(err) db.AddError(err)
} }
} }
@ -182,9 +164,7 @@ func CreateWithReturning(db *gorm.DB) {
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var called bool
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(gorm.AfterSaveInterface); ok {
called = true called = true
@ -199,18 +179,7 @@ func AfterCreate(db *gorm.DB) {
} }
} }
return called return called
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }
@ -230,7 +199,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
) )
for _, db := range stmt.Schema.DBNames { for _, db := range stmt.Schema.DBNames {
if stmt.Schema.FieldsWithDefaultDBValue[db] == nil { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: db}) values.Columns = append(values.Columns, clause.Column{Name: db})
} }
@ -257,13 +226,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
} }
} }
for db, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero { if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[db]) == 0 { if len(defaultValueFieldsHavingValue[field.DBName]) == 0 {
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len())
} }
defaultValueFieldsHavingValue[db][i] = v defaultValueFieldsHavingValue[field.DBName][i] = v
} }
} }
} }
@ -294,10 +263,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
} }
} }
for db, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: db}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], v) values.Values[0] = append(values.Values[0], v)
} }
} }

View File

@ -10,27 +10,14 @@ import (
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) bool {
callMethod := func(value interface{}) bool { if i, ok := value.(gorm.BeforeDeleteInterface); ok {
if db.Statement.Schema.BeforeDelete { db.AddError(i.BeforeDelete(tx))
if i, ok := value.(gorm.BeforeDeleteInterface); ok { return true
db.AddError(i.BeforeDelete(tx))
return true
}
} }
return false
}
if ok := callMethod(db.Statement.Dest); !ok { return false
switch db.Statement.ReflectValue.Kind() { })
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }
@ -86,26 +73,12 @@ func Delete(db *gorm.DB) {
func AfterDelete(db *gorm.DB) { func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) bool {
callMethod := func(value interface{}) bool { if i, ok := value.(gorm.AfterDeleteInterface); ok {
if db.Statement.Schema.AfterDelete { db.AddError(i.AfterDelete(tx))
if i, ok := value.(gorm.AfterDeleteInterface); ok { return true
db.AddError(i.AfterDelete(tx))
return true
}
} }
return false return false
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }

View File

@ -188,26 +188,12 @@ func Preload(db *gorm.DB) {
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) bool {
callMethod := func(value interface{}) bool { if i, ok := value.(gorm.AfterFindInterface); ok {
if db.Statement.Schema.AfterFind { db.AddError(i.AfterFind(tx))
if i, ok := value.(gorm.AfterFindInterface); ok { return true
db.AddError(i.AfterFind(tx))
return true
}
} }
return false return false
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }

View File

@ -30,9 +30,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var called bool
if db.Statement.Schema.BeforeSave { if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok { if i, ok := value.(gorm.BeforeSaveInterface); ok {
called = true called = true
@ -46,19 +44,9 @@ func BeforeUpdate(db *gorm.DB) {
db.AddError(i.BeforeUpdate(tx)) db.AddError(i.BeforeUpdate(tx))
} }
} }
return called
}
if ok := callMethod(db.Statement.Dest); !ok { return called
switch db.Statement.ReflectValue.Kind() { })
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }
@ -99,9 +87,7 @@ func Update(db *gorm.DB) {
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
tx := db.Session(&gorm.Session{}) callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
callMethod := func(value interface{}) bool {
var called bool
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok { if i, ok := value.(gorm.AfterSaveInterface); ok {
called = true called = true
@ -116,18 +102,7 @@ func AfterUpdate(db *gorm.DB) {
} }
} }
return called return called
} })
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Addr().Interface())
}
}
} }
} }

62
gorm.go
View File

@ -25,9 +25,10 @@ type Config struct {
NowFunc func() time.Time NowFunc func() time.Time
// DryRun generate sql without execute // DryRun generate sql without execute
DryRun bool DryRun bool
// PrepareStmt executes the given query in cached statement // PrepareStmt executes the given query in cached statement
PrepareStmt bool PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// ClauseBuilders clause builder // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
@ -93,8 +94,8 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
config.ClauseBuilders = map[string]clause.ClauseBuilder{} config.ClauseBuilders = map[string]clause.ClauseBuilder{}
} }
if dialector != nil { if config.Dialector != nil {
err = dialector.Initialize(db) err = config.Dialector.Initialize(db)
} }
if config.PrepareStmt { if config.PrepareStmt {
@ -104,16 +105,14 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
} }
} }
if db.Statement == nil { db.Statement = &Statement{
db.Statement = &Statement{ DB: db,
DB: db, ConnPool: db.ConnPool,
ConnPool: db.ConnPool, Context: context.Background(),
Context: context.Background(), Clauses: map[string]clause.Clause{},
Clauses: map[string]clause.Clause{},
}
} }
if err == nil { if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping() err = pinger.Ping()
} }
@ -138,17 +137,8 @@ func (db *DB) Session(config *Session) *DB {
) )
if config.Context != nil { if config.Context != nil {
if tx.Statement != nil { tx.Statement = tx.Statement.clone()
tx.Statement = tx.Statement.clone() tx.Statement.DB = tx
tx.Statement.DB = tx
} else {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
ConnPool: tx.ConnPool,
}
}
tx.Statement.Context = config.Context tx.Statement.Context = config.Context
} }
@ -160,7 +150,7 @@ func (db *DB) Session(config *Session) *DB {
} }
if config.WithConditions { if config.WithConditions {
tx.clone = 3 tx.clone = 2
} }
if config.DryRun { if config.DryRun {
@ -200,10 +190,7 @@ func (db *DB) Set(key string, value interface{}) *DB {
// Get get value with key from current db instance's context // Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) { func (db *DB) Get(key string) (interface{}, bool) {
if db.Statement != nil { return db.Statement.Settings.Load(key)
return db.Statement.Settings.Load(key)
}
return nil, false
} }
// InstanceSet store value with key into current db instance's context // InstanceSet store value with key into current db instance's context
@ -215,10 +202,7 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB {
// InstanceGet get value with key from current db instance's context // InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) { func (db *DB) InstanceGet(key string) (interface{}, bool) {
if db.Statement != nil { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
return nil, false
} }
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
@ -282,22 +266,18 @@ func (db *DB) getInstance() *DB {
if db.clone > 0 { if db.clone > 0 {
tx := &DB{Config: db.Config} tx := &DB{Config: db.Config}
switch db.clone { if db.clone == 1 {
case 1: // clone with new statement // clone with new statement
tx.Statement = &Statement{ tx.Statement = &Statement{
DB: tx, DB: tx,
ConnPool: db.Statement.ConnPool, ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context, Context: db.Statement.Context,
Clauses: map[string]clause.Clause{}, Clauses: map[string]clause.Clause{},
} }
case 2: // with old statement, generate new statement for future call, used to pass to callbacks } else {
db.clone = 1 // with clone statement
tx.Statement = db.Statement tx.Statement = db.Statement.clone()
case 3: // with clone statement tx.Statement.DB = tx
if db.Statement != nil {
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
} }
return tx return tx

View File

@ -62,7 +62,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL += " UNIQUE" expr.SQL += " UNIQUE"
} }
if field.HasDefaultValue { if field.HasDefaultValue && field.DefaultValue != "" {
if field.DataType == schema.String { if field.DataType == schema.String {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}} defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue) m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue)

View File

@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) {
} }
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true},
{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},

View File

@ -26,7 +26,7 @@ type Schema struct {
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
CreateClauses []clause.Interface CreateClauses []clause.Interface
QueryClauses []clause.Interface QueryClauses []clause.Interface
@ -153,23 +153,14 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
if v != nil && v.PrimaryKey { if v != nil && v.PrimaryKey {
if schema.PrioritizedPrimaryField == v {
schema.PrioritizedPrimaryField = nil
}
for idx, f := range schema.PrimaryFields { for idx, f := range schema.PrimaryFields {
if f == v { if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
} else if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = f
} }
} }
} }
if field.PrimaryKey { if field.PrimaryKey {
if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = field
}
schema.PrimaryFields = append(schema.PrimaryFields, field) schema.PrimaryFields = append(schema.PrimaryFields, field)
} }
} }
@ -192,21 +183,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
} }
} }
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
}
for _, field := range schema.PrimaryFields { for _, field := range schema.PrimaryFields {
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
} }
schema.FieldsWithDefaultDBValue = map[string]*Field{} for _, field := range schema.FieldsByDBName {
for db, field := range schema.FieldsByDBName {
if field.HasDefaultValue && field.DefaultValueInterface == nil { if field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue[db] = field schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
} }
} }
if schema.PrioritizedPrimaryField != nil { if field := schema.PrioritizedPrimaryField; field != nil {
switch schema.PrioritizedPrimaryField.DataType { switch field.DataType {
case Int, Uint: case Int, Uint:
schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
} }
} }

View File

@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
// check fields // check fields
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64}, {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true},
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
// check fields // check fields
fields := []schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},

View File

@ -226,6 +226,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
if sql == "" && len(args) == 0 { if sql == "" && len(args) == 0 {
return return
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") { } else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}} return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
} else if len(args) == 1 { } else if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}} return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
@ -242,12 +243,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
switch v := arg.(type) { switch v := arg.(type) {
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB:
if v.Statement != nil {
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
conds = append(conds, cs.Expression)
}
}
case map[interface{}]interface{}: case map[interface{}]interface{}:
for i, j := range v { for i, j := range v {
conds = append(conds, clause.Eq{Column: i, Value: j}) conds = append(conds, clause.Eq{Column: i, Value: j})
@ -326,7 +321,6 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
func (stmt *Statement) clone() *Statement { func (stmt *Statement) clone() *Statement {
newStmt := &Statement{ newStmt := &Statement{
DB: stmt.DB,
Table: stmt.Table, Table: stmt.Table,
Model: stmt.Model, Model: stmt.Model,
Dest: stmt.Dest, Dest: stmt.Dest,
@ -357,37 +351,3 @@ func (stmt *Statement) clone() *Statement {
return newStmt return newStmt
} }
func (stmt *Statement) reinit() {
// stmt.Table = ""
// stmt.Model = nil
// stmt.Selects = nil
// stmt.Omits = nil
// stmt.ConnPool = stmt.DB.Config.ConnPool
// stmt.Context = context.Background()
// stmt.RaiseErrorOnNotFound = false
// for k := range stmt.Clauses {
// delete(stmt.Clauses, k)
// }
// for k := range stmt.Joins {
// delete(stmt.Joins, k)
// }
// for k := range stmt.Preloads {
// delete(stmt.Preloads, k)
// }
// stmt.Settings.Range(func(k, _ interface{}) bool {
// stmt.Settings.Delete(k)
// return true
// })
// stmt.Schema = nil
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
stmt.NamedVars = nil
}
}

44
tests/benchmark_test.go Normal file
View File

@ -0,0 +1,44 @@
package tests_test
import (
"testing"
. "gorm.io/gorm/utils/tests"
)
func BenchmarkCreate(b *testing.B) {
var user = *GetUser("bench", Config{})
for x := 0; x < b.N; x++ {
user.ID = 0
DB.Create(&user)
}
}
func BenchmarkFind(b *testing.B) {
var user = *GetUser("find", Config{})
DB.Create(&user)
for x := 0; x < b.N; x++ {
DB.Find(&User{}, "id = ?", user.ID)
}
}
func BenchmarkUpdate(b *testing.B) {
var user = *GetUser("find", Config{})
DB.Create(&user)
for x := 0; x < b.N; x++ {
DB.Model(&user).Updates(map[string]interface{}{"Age": x})
}
}
func BenchmarkDelete(b *testing.B) {
var user = *GetUser("find", Config{})
for x := 0; x < b.N; x++ {
user.ID = 0
DB.Create(&user)
DB.Delete(&user)
}
}

View File

@ -7,7 +7,7 @@ require (
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
gorm.io/driver/sqlite v1.0.0 gorm.io/driver/sqlite v1.0.0
gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf
gorm.io/gorm v0.0.0-00010101000000-000000000000 gorm.io/gorm v0.0.0-00010101000000-000000000000
) )