Fix primary key tag

This commit is contained in:
Jinzhu 2020-03-12 08:39:42 +08:00
parent 9e8a4db36b
commit af080e6773
13 changed files with 58 additions and 66 deletions

View File

@ -90,8 +90,6 @@ func (p *processor) Execute(db *DB) {
}
if stmt := db.Statement; stmt != nil {
db.RowsAffected = stmt.RowsAffected
db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)

View File

@ -108,13 +108,14 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return
}
// Where add conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(query, args...)})
return
}
// Not add NOT condition
// Not add NOT conditions
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(tx.Statement.BuildCondtion(query, args...)...)}})

View File

@ -6,23 +6,6 @@ type From struct {
Joins []Join
}
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
}
// Name from clause name
func (from From) Name() string {
return "FROM"
@ -48,30 +31,6 @@ func (from From) Build(builder Builder) {
}
}
func (join Join) Build(builder Builder) {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}
// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(From); ok {

View File

@ -1,8 +1,42 @@
package clause
// Joins joins clause
type Joins struct {
Name string
Query string
Vars []interface{}
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
// Join join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
}
func (join Join) Build(builder Builder) {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}

View File

@ -70,7 +70,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
sqlType = "bigint"
}
if field.AutoIncrement {
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
return sqlType + " IDENTITY(1,1)"
}
return sqlType

View File

@ -71,7 +71,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
sqlType += " unsigned"
}
if field.AutoIncrement {
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
sqlType += " AUTO_INCREMENT"
}
return sqlType
@ -94,6 +94,10 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return fmt.Sprintf("varchar(%d)", size)
case schema.Time:
precision := ""
if field.Precision == 0 {
field.Precision = 3
}
if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision)
}

View File

@ -16,7 +16,7 @@ var (
)
func init() {
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
if os.Getenv("GORM_DSN") != "" {
dsn = os.Getenv("GORM_DSN")
}

View File

@ -60,7 +60,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
case schema.Bool:
return "boolean"
case schema.Int, schema.Uint:
if field.AutoIncrement {
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
switch {
case field.Size < 16:
return "smallserial"

View File

@ -33,7 +33,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
}
case []byte:
if isPrintable(v) {

View File

@ -219,7 +219,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if field.Size == 0 {
switch fieldValue.Kind() {
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
field.Size = 64
case reflect.Int8, reflect.Uint8:

View File

@ -28,17 +28,15 @@ type Statement struct {
ConnPool ConnPool
Schema *schema.Schema
Context context.Context
Error error
RowsAffected int64
RaiseErrorOnNotFound bool
SQL strings.Builder
Vars []interface{}
NamedVars []sql.NamedArg
}
// StatementOptimizer statement optimizer interface
type StatementOptimizer interface {
OptimizeStatement(*Statement)
// StatementModifier statement modifier interface
type StatementModifier interface {
ModifyStatement(*Statement)
}
// Write write string
@ -144,8 +142,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
// AddClause add clause
func (stmt *Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementOptimizer); ok {
optimizer.OptimizeStatement(stmt)
if optimizer, ok := v.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
c, ok := stmt.Clauses[v.Name()]
@ -255,8 +253,6 @@ func (stmt *Statement) reinit() {
stmt.ConnPool = stmt.DB.Config.ConnPool
stmt.Schema = nil
stmt.Context = context.Background()
stmt.Error = nil
stmt.RowsAffected = 0
stmt.RaiseErrorOnNotFound = false
stmt.SQL.Reset()

View File

@ -21,7 +21,7 @@ type User struct {
Toys []Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company Company
ManagerID uint
ManagerID *uint
Manager *User
Team []User `gorm:"foreignkey:ManagerID"`
Languages []Language `gorm:"many2many:UserSpeak"`

View File

@ -81,7 +81,7 @@ func TestFind(t *testing.T, db *gorm.DB) {
}}
if err := db.Create(&users).Error; err != nil {
t.Errorf("errors happened when create users: %v", err)
t.Fatal("errors happened when create users: %v", err)
}
t.Run("First", func(t *testing.T) {