mirror of https://github.com/go-gorm/gorm.git
Refactor codebase and add benchmark test
This commit is contained in:
parent
163200d05f
commit
1490a062db
|
@ -105,8 +105,11 @@ func (p *processor) Execute(db *DB) {
|
|||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
|
||||
stmt.reinit()
|
||||
// db.Config.statementPool.Put(stmt)
|
||||
if !stmt.DB.DryRun {
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
stmt.NamedVars = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
if called := fc(db.Statement.Dest, tx); !called {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx)
|
||||
}
|
||||
case reflect.Struct:
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,9 +10,7 @@ import (
|
|||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var called bool
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
called = true
|
||||
|
@ -27,18 +25,7 @@ func BeforeCreate(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
return called
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,8 +54,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
|
@ -90,7 +76,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
|
@ -122,19 +107,17 @@ func CreateWithReturning(db *gorm.DB) {
|
|||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
idx int
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for dbName, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx != 0 {
|
||||
for idx, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(dbName)
|
||||
idx++
|
||||
db.Statement.WriteQuoted(field.DBName)
|
||||
}
|
||||
|
||||
if !db.DryRun {
|
||||
|
@ -149,10 +132,11 @@ func CreateWithReturning(db *gorm.DB) {
|
|||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
db.RowsAffected++
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range fields {
|
||||
|
@ -161,12 +145,10 @@ func CreateWithReturning(db *gorm.DB) {
|
|||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
err = rows.Scan(values...)
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
@ -182,9 +164,7 @@ func CreateWithReturning(db *gorm.DB) {
|
|||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var called bool
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
called = true
|
||||
|
@ -199,18 +179,7 @@ func AfterCreate(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
return called
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,7 +199,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|||
)
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if stmt.Schema.FieldsWithDefaultDBValue[db] == nil {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
}
|
||||
|
@ -257,13 +226,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|||
}
|
||||
}
|
||||
|
||||
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[db]) == 0 {
|
||||
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
|
||||
if len(defaultValueFieldsHavingValue[field.DBName]) == 0 {
|
||||
defaultValueFieldsHavingValue[field.DBName] = make([]interface{}, stmt.ReflectValue.Len())
|
||||
}
|
||||
defaultValueFieldsHavingValue[db][i] = v
|
||||
defaultValueFieldsHavingValue[field.DBName][i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -294,10 +263,10 @@ func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
|
|||
}
|
||||
}
|
||||
|
||||
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,27 +10,14 @@ import (
|
|||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -86,26 +73,12 @@ func Delete(db *gorm.DB) {
|
|||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -188,26 +188,12 @@ func Preload(db *gorm.DB) {
|
|||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
if db.Statement.Schema.AfterFind {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(gorm.AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,9 +30,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
|||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var called bool
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(gorm.BeforeSaveInterface); ok {
|
||||
called = true
|
||||
|
@ -46,19 +44,9 @@ func BeforeUpdate(db *gorm.DB) {
|
|||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,9 +87,7 @@ func Update(db *gorm.DB) {
|
|||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
tx := db.Session(&gorm.Session{})
|
||||
callMethod := func(value interface{}) bool {
|
||||
var called bool
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(gorm.AfterSaveInterface); ok {
|
||||
called = true
|
||||
|
@ -116,18 +102,7 @@ func AfterUpdate(db *gorm.DB) {
|
|||
}
|
||||
}
|
||||
return called
|
||||
}
|
||||
|
||||
if ok := callMethod(db.Statement.Dest); !ok {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
callMethod(db.Statement.ReflectValue.Index(i).Addr().Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
callMethod(db.Statement.ReflectValue.Addr().Interface())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
40
gorm.go
40
gorm.go
|
@ -25,9 +25,10 @@ type Config struct {
|
|||
NowFunc func() time.Time
|
||||
// DryRun generate sql without execute
|
||||
DryRun bool
|
||||
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
|
@ -93,8 +94,8 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
err = dialector.Initialize(db)
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
|
@ -104,16 +105,14 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if db.Statement == nil {
|
||||
db.Statement = &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Context: context.Background(),
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if err == nil && !config.DisableAutomaticPing {
|
||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||
err = pinger.Ping()
|
||||
}
|
||||
|
@ -138,17 +137,8 @@ func (db *DB) Session(config *Session) *DB {
|
|||
)
|
||||
|
||||
if config.Context != nil {
|
||||
if tx.Statement != nil {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
} else {
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
ConnPool: tx.ConnPool,
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
|
@ -160,7 +150,7 @@ func (db *DB) Session(config *Session) *DB {
|
|||
}
|
||||
|
||||
if config.WithConditions {
|
||||
tx.clone = 3
|
||||
tx.clone = 2
|
||||
}
|
||||
|
||||
if config.DryRun {
|
||||
|
@ -200,10 +190,7 @@ func (db *DB) Set(key string, value interface{}) *DB {
|
|||
|
||||
// Get get value with key from current db instance's context
|
||||
func (db *DB) Get(key string) (interface{}, bool) {
|
||||
if db.Statement != nil {
|
||||
return db.Statement.Settings.Load(key)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// InstanceSet store value with key into current db instance's context
|
||||
|
@ -215,10 +202,7 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB {
|
|||
|
||||
// InstanceGet get value with key from current db instance's context
|
||||
func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
if db.Statement != nil {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
|
@ -282,23 +266,19 @@ func (db *DB) getInstance() *DB {
|
|||
if db.clone > 0 {
|
||||
tx := &DB{Config: db.Config}
|
||||
|
||||
switch db.clone {
|
||||
case 1: // clone with new statement
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
|
||||
db.clone = 1
|
||||
tx.Statement = db.Statement
|
||||
case 3: // with clone statement
|
||||
if db.Statement != nil {
|
||||
} else {
|
||||
// with clone statement
|
||||
tx.Statement = db.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||
expr.SQL += " UNIQUE"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue {
|
||||
if field.HasDefaultValue && field.DefaultValue != "" {
|
||||
if field.DataType == schema.String {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValue}}
|
||||
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValue)
|
||||
|
|
|
@ -235,7 +235,7 @@ func TestParseFieldWithPermission(t *testing.T) {
|
|||
}
|
||||
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true},
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true},
|
||||
{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
|
||||
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
|
||||
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},
|
||||
|
|
|
@ -26,7 +26,7 @@ type Schema struct {
|
|||
Fields []*Field
|
||||
FieldsByName map[string]*Field
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
CreateClauses []clause.Interface
|
||||
QueryClauses []clause.Interface
|
||||
|
@ -153,23 +153,14 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
schema.FieldsByName[field.Name] = field
|
||||
|
||||
if v != nil && v.PrimaryKey {
|
||||
if schema.PrioritizedPrimaryField == v {
|
||||
schema.PrioritizedPrimaryField = nil
|
||||
}
|
||||
|
||||
for idx, f := range schema.PrimaryFields {
|
||||
if f == v {
|
||||
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
|
||||
} else if schema.PrioritizedPrimaryField == nil {
|
||||
schema.PrioritizedPrimaryField = f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field.PrimaryKey {
|
||||
if schema.PrioritizedPrimaryField == nil {
|
||||
schema.PrioritizedPrimaryField = field
|
||||
}
|
||||
schema.PrimaryFields = append(schema.PrimaryFields, field)
|
||||
}
|
||||
}
|
||||
|
@ -192,21 +183,27 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
}
|
||||
}
|
||||
|
||||
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
}
|
||||
|
||||
for _, field := range schema.PrimaryFields {
|
||||
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
|
||||
}
|
||||
|
||||
schema.FieldsWithDefaultDBValue = map[string]*Field{}
|
||||
for db, field := range schema.FieldsByDBName {
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
schema.FieldsWithDefaultDBValue[db] = field
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
}
|
||||
|
||||
if schema.PrioritizedPrimaryField != nil {
|
||||
switch schema.PrioritizedPrimaryField.DataType {
|
||||
if field := schema.PrioritizedPrimaryField; field != nil {
|
||||
switch field.DataType {
|
||||
case Int, Uint:
|
||||
schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField
|
||||
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
field.HasDefaultValue = true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
|||
|
||||
// check fields
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64},
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true},
|
||||
{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
|
||||
{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
|
||||
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
|
||||
|
@ -125,7 +125,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
|||
|
||||
// check fields
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64},
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true},
|
||||
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
|
||||
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
|
||||
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
|
||||
|
|
42
statement.go
42
statement.go
|
@ -226,6 +226,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||
if sql == "" && len(args) == 0 {
|
||||
return
|
||||
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
|
||||
} else if len(args) == 1 {
|
||||
return []clause.Expression{clause.Eq{Column: sql, Value: args[0]}}
|
||||
|
@ -242,12 +243,6 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con
|
|||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
if v.Statement != nil {
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
conds = append(conds, cs.Expression)
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
for i, j := range v {
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
|
@ -326,7 +321,6 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
|||
|
||||
func (stmt *Statement) clone() *Statement {
|
||||
newStmt := &Statement{
|
||||
DB: stmt.DB,
|
||||
Table: stmt.Table,
|
||||
Model: stmt.Model,
|
||||
Dest: stmt.Dest,
|
||||
|
@ -357,37 +351,3 @@ func (stmt *Statement) clone() *Statement {
|
|||
|
||||
return newStmt
|
||||
}
|
||||
|
||||
func (stmt *Statement) reinit() {
|
||||
// stmt.Table = ""
|
||||
// stmt.Model = nil
|
||||
// stmt.Selects = nil
|
||||
// stmt.Omits = nil
|
||||
// stmt.ConnPool = stmt.DB.Config.ConnPool
|
||||
// stmt.Context = context.Background()
|
||||
// stmt.RaiseErrorOnNotFound = false
|
||||
|
||||
// for k := range stmt.Clauses {
|
||||
// delete(stmt.Clauses, k)
|
||||
// }
|
||||
|
||||
// for k := range stmt.Joins {
|
||||
// delete(stmt.Joins, k)
|
||||
// }
|
||||
|
||||
// for k := range stmt.Preloads {
|
||||
// delete(stmt.Preloads, k)
|
||||
// }
|
||||
|
||||
// stmt.Settings.Range(func(k, _ interface{}) bool {
|
||||
// stmt.Settings.Delete(k)
|
||||
// return true
|
||||
// })
|
||||
|
||||
// stmt.Schema = nil
|
||||
if !stmt.DB.DryRun {
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
stmt.NamedVars = nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func BenchmarkCreate(b *testing.B) {
|
||||
var user = *GetUser("bench", Config{})
|
||||
|
||||
for x := 0; x < b.N; x++ {
|
||||
user.ID = 0
|
||||
DB.Create(&user)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFind(b *testing.B) {
|
||||
var user = *GetUser("find", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
for x := 0; x < b.N; x++ {
|
||||
DB.Find(&User{}, "id = ?", user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdate(b *testing.B) {
|
||||
var user = *GetUser("find", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
for x := 0; x < b.N; x++ {
|
||||
DB.Model(&user).Updates(map[string]interface{}{"Age": x})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDelete(b *testing.B) {
|
||||
var user = *GetUser("find", Config{})
|
||||
|
||||
for x := 0; x < b.N; x++ {
|
||||
user.ID = 0
|
||||
DB.Create(&user)
|
||||
DB.Delete(&user)
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ require (
|
|||
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
|
||||
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
|
||||
gorm.io/driver/sqlite v1.0.0
|
||||
gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2
|
||||
gorm.io/driver/sqlserver v0.0.0-20200605135528-04ae0f7a15bf
|
||||
gorm.io/gorm v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue