Fix primary key tag

This commit is contained in:
Jinzhu 2020-03-12 08:39:42 +08:00
parent 9e8a4db36b
commit af080e6773
13 changed files with 58 additions and 66 deletions

View File

@ -90,8 +90,6 @@ func (p *processor) Execute(db *DB) {
} }
if stmt := db.Statement; stmt != nil { if stmt := db.Statement; stmt != nil {
db.RowsAffected = stmt.RowsAffected
db.Logger.Trace(curTime, func() (string, int64) { db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error) }, db.Error)

View File

@ -108,13 +108,14 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return return
} }
// Where add conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)}) tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)})
return return
} }
// Not add NOT condition // Not add NOT conditions
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}}) tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}})

View File

@ -6,23 +6,6 @@ type From struct {
Joins []Join Joins []Join
} }
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
}
// Name from clause name // Name from clause name
func (from From) Name() string { func (from From) Name() string {
return "FROM" return "FROM"
@ -48,30 +31,6 @@ func (from From) Build(builder Builder) {
} }
} }
func (join Join) Build(builder Builder) {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}
// MergeClause merge from clause // MergeClause merge from clause
func (from From) MergeClause(clause *Clause) { func (from From) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(From); ok { if v, ok := clause.Expression.(From); ok {

View File

@ -1,8 +1,42 @@
package clause package clause
// Joins joins clause type JoinType string
type Joins struct {
Name string const (
Query string CrossJoin JoinType = "CROSS"
Vars []interface{} InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
}
func (join Join) Build(builder Builder) {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
} }

View File

@ -70,7 +70,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
sqlType = "bigint" sqlType = "bigint"
} }
if field.AutoIncrement { if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
return sqlType + " IDENTITY(1,1)" return sqlType + " IDENTITY(1,1)"
} }
return sqlType return sqlType

View File

@ -71,7 +71,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
sqlType += " unsigned" sqlType += " unsigned"
} }
if field.AutoIncrement { if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
sqlType += " AUTO_INCREMENT" sqlType += " AUTO_INCREMENT"
} }
return sqlType return sqlType
@ -94,6 +94,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
case schema.Time: case schema.Time:
precision := "" precision := ""
if field.Precision == 0 {
field.Precision = 3
}
if field.Precision > 0 { if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision) precision = fmt.Sprintf("(%d)", field.Precision)
} }

View File

@ -16,7 +16,7 @@ var (
) )
func init() { func init() {
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
if os.Getenv("GORM_DSN") != "" { if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN") dsn = os.Getenv("GORM_DSN")
} }

View File

@ -60,7 +60,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Bool: case schema.Bool:
return "boolean" return "boolean"
case schema.Int, schema.Uint: case schema.Int, schema.Uint:
if field.AutoIncrement { if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
switch { switch {
case field.Size < 16: case field.Size < 16:
return "smallserial" return "smallserial"

View File

@ -33,7 +33,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
if v.IsZero() { if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else { } else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
} }
case []byte: case []byte:
if isPrintable(v) { if isPrintable(v) {

View File

@ -219,7 +219,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if field.Size == 0 { if field.Size == 0 {
switch fieldValue.Kind() { switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
field.Size = 64 field.Size = 64
case reflect.Int8, reflect.Uint8: case reflect.Int8, reflect.Uint8:

View File

@ -28,17 +28,15 @@ type Statement struct {
ConnPool ConnPool ConnPool ConnPool
Schema *schema.Schema Schema *schema.Schema
Context context.Context Context context.Context
Error error
RowsAffected int64
RaiseErrorOnNotFound bool RaiseErrorOnNotFound bool
SQL strings.Builder SQL strings.Builder
Vars []interface{} Vars []interface{}
NamedVars []sql.NamedArg NamedVars []sql.NamedArg
} }
// StatementOptimizer statement optimizer interface // StatementModifier statement modifier interface
type StatementOptimizer interface { type StatementModifier interface {
OptimizeStatement(*Statement) ModifyStatement(*Statement)
} }
// Write write string // Write write string
@ -144,8 +142,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
// AddClause add clause // AddClause add clause
func (stmt *Statement) AddClause(v clause.Interface) { func (stmt *Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok { if optimizer, ok := v.(StatementModifier); ok {
optimizer.OptimizeStatement(stmt) optimizer.ModifyStatement(stmt)
} }
c, ok := stmt.Clauses[v.Name()] c, ok := stmt.Clauses[v.Name()]
@ -255,8 +253,6 @@ func (stmt *Statement) reinit() {
stmt.ConnPool = stmt.DB.Config.ConnPool stmt.ConnPool = stmt.DB.Config.ConnPool
stmt.Schema = nil stmt.Schema = nil
stmt.Context = context.Background() stmt.Context = context.Background()
stmt.Error = nil
stmt.RowsAffected = 0
stmt.RaiseErrorOnNotFound = false stmt.RaiseErrorOnNotFound = false
stmt.SQL.Reset() stmt.SQL.Reset()

View File

@ -21,7 +21,7 @@ type User struct {
Toys []Toy `gorm:"polymorphic:Owner"` Toys []Toy `gorm:"polymorphic:Owner"`
CompanyID *int CompanyID *int
Company Company Company Company
ManagerID uint ManagerID *uint
Manager *User Manager *User
Team []User `gorm:"foreignkey:ManagerID"` Team []User `gorm:"foreignkey:ManagerID"`
Languages []Language `gorm:"many2many:UserSpeak"` Languages []Language `gorm:"many2many:UserSpeak"`

View File

@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) {
}} }}
if err := db.Create(&users).Error; err != nil { if err := db.Create(&users).Error; err != nil {
t.Errorf("errors happened when create users: %v", err) t.Fatal("errors happened when create users: %v", err)
} }
t.Run("First", func(t *testing.T) { t.Run("First", func(t *testing.T) {