From af080e677317015c36070227e889c2943f92752a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Mar 2020 08:39:42 +0800 Subject: [PATCH] Fix primary key tag --- callbacks.go | 2 -- chainable_api.go | 3 ++- clause/from.go | 41 -------------------------------- clause/joins.go | 44 +++++++++++++++++++++++++++++++---- dialects/mssql/mssql.go | 2 +- dialects/mysql/mysql.go | 6 ++++- dialects/mysql/mysql_test.go | 2 +- dialects/postgres/postgres.go | 2 +- logger/sql.go | 2 +- schema/field.go | 2 +- statement.go | 14 ++++------- tests/model.go | 2 +- tests/tests.go | 2 +- 13 files changed, 58 insertions(+), 66 deletions(-) diff --git a/callbacks.go b/callbacks.go index e1b2b410..78f1192e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -90,8 +90,6 @@ func (p *processor) Execute(db *DB) { } if stmt := db.Statement; stmt != nil { - db.RowsAffected = stmt.RowsAffected - db.Logger.Trace(curTime, func() (string, int64) { return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected }, db.Error) diff --git a/chainable_api.go b/chainable_api.go index 432caa4f..7a6e8b7c 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -108,13 +108,14 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } +// Where add conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) return } -// Not add NOT condition +// Not add NOT conditions func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) diff --git a/clause/from.go b/clause/from.go index 5e8c5d25..59b0bfaf 100644 --- a/clause/from.go +++ b/clause/from.go @@ -6,23 +6,6 @@ type From struct { Joins []Join } -type JoinType string - -const ( - CrossJoin JoinType = "CROSS" - InnerJoin = "INNER" - LeftJoin = "LEFT" - RightJoin = "RIGHT" -) - -// Join join clause for from -type Join struct { - Type JoinType - Table Table - ON Where - Using []string -} - // Name from clause name func (from From) Name() string { return "FROM" @@ -48,30 +31,6 @@ func (from From) Build(builder Builder) { } } -func (join Join) Build(builder Builder) { - if join.Type != "" { - builder.WriteString(string(join.Type)) - builder.WriteByte(' ') - } - - builder.WriteString("JOIN ") - builder.WriteQuoted(join.Table) - - if len(join.ON.Exprs) > 0 { - builder.WriteString(" ON ") - join.ON.Build(builder) - } else if len(join.Using) > 0 { - builder.WriteString(" USING (") - for idx, c := range join.Using { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(c) - } - builder.WriteByte(')') - } -} - // MergeClause merge from clause func (from From) MergeClause(clause *Clause) { if v, ok := clause.Expression.(From); ok { diff --git a/clause/joins.go b/clause/joins.go index 4983d6fd..a78bde39 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -1,8 +1,42 @@ package clause -// Joins joins clause -type Joins struct { - Name string - Query string - Vars []interface{} +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin = "INNER" + LeftJoin = "LEFT" + RightJoin = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string +} + +func (join Join) Build(builder Builder) { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } } diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index 8cf1e2e2..e5bc7dd2 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -70,7 +70,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { sqlType = "bigint" } - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { return sqlType + " IDENTITY(1,1)" } return sqlType diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index 514dfc14..af796847 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -71,7 +71,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { sqlType += " unsigned" } - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { sqlType += " AUTO_INCREMENT" } return sqlType @@ -94,6 +94,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return fmt.Sprintf("varchar(%d)", size) case schema.Time: precision := "" + if field.Precision == 0 { + field.Precision = 3 + } + if field.Precision > 0 { precision = fmt.Sprintf("(%d)", field.Precision) } diff --git a/dialects/mysql/mysql_test.go b/dialects/mysql/mysql_test.go index 5bc1debd..cb3b240a 100644 --- a/dialects/mysql/mysql_test.go +++ b/dialects/mysql/mysql_test.go @@ -16,7 +16,7 @@ var ( ) func init() { - dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" if os.Getenv("GORM_DSN") != "" { dsn = os.Getenv("GORM_DSN") } diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index c2ddd82c..7589025d 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -60,7 +60,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { case schema.Bool: return "boolean" case schema.Int, schema.Uint: - if field.AutoIncrement { + if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField { switch { case field.Size < 16: return "smallserial" diff --git a/logger/sql.go b/logger/sql.go index cb50ccf6..41c514fd 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -33,7 +33,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v if v.IsZero() { vars[idx] = escaper + "0000-00-00 00:00:00" + escaper } else { - vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper } case []byte: if isPrintable(v) { diff --git a/schema/field.go b/schema/field.go index c6de669d..ee1baf3c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -219,7 +219,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if field.Size == 0 { - switch fieldValue.Kind() { + switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: field.Size = 64 case reflect.Int8, reflect.Uint8: diff --git a/statement.go b/statement.go index fb3599ec..298a4c56 100644 --- a/statement.go +++ b/statement.go @@ -28,17 +28,15 @@ type Statement struct { ConnPool ConnPool Schema *schema.Schema Context context.Context - Error error - RowsAffected int64 RaiseErrorOnNotFound bool SQL strings.Builder Vars []interface{} NamedVars []sql.NamedArg } -// StatementOptimizer statement optimizer interface -type StatementOptimizer interface { - OptimizeStatement(*Statement) +// StatementModifier statement modifier interface +type StatementModifier interface { + ModifyStatement(*Statement) } // Write write string @@ -144,8 +142,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { // AddClause add clause func (stmt *Statement) AddClause(v clause.Interface) { - if optimizer, ok := v.(StatementOptimizer); ok { - optimizer.OptimizeStatement(stmt) + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) } c, ok := stmt.Clauses[v.Name()] @@ -255,8 +253,6 @@ func (stmt *Statement) reinit() { stmt.ConnPool = stmt.DB.Config.ConnPool stmt.Schema = nil stmt.Context = context.Background() - stmt.Error = nil - stmt.RowsAffected = 0 stmt.RaiseErrorOnNotFound = false stmt.SQL.Reset() diff --git a/tests/model.go b/tests/model.go index b2d5efe1..4d686a57 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,7 +21,7 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID uint + ManagerID *uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` diff --git a/tests/tests.go b/tests/tests.go index 33013032..c26d743e 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) { }} if err := db.Create(&users).Error; err != nil { - t.Errorf("errors happened when create users: %v", err) + t.Fatal("errors happened when create users: %v", err) } t.Run("First", func(t *testing.T) {