Test CreateWithNoGORMPrimayKey

This commit is contained in:
Jinzhu 2020-05-31 12:52:49 +08:00
parent b3b19a5577
commit 1546f8a4a1
5 changed files with 71 additions and 20 deletions

View File

@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil { if err == nil {
if db.Statement.Schema != nil { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
if insertID, err := result.LastInsertId(); err == nil { if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {

View File

@ -68,26 +68,30 @@ func Create(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for rows.Next() { if len(db.Statement.Schema.PrimaryFields) > 0 {
// for idx, field := range fields { values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
// }
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() for rows.Next() {
if err := rows.Scan(values); err != nil { for idx, field := range db.Statement.Schema.PrimaryFields {
db.AddError(err) values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
} }
db.RowsAffected++
} }
case reflect.Struct: case reflect.Struct:
// for idx, field := range fields { if len(db.Statement.Schema.PrimaryFields) > 0 {
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
// }
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
if rows.Next() { for idx, field := range db.Statement.Schema.PrimaryFields {
db.RowsAffected++ values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
db.AddError(rows.Scan(values)) }
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
} }
} }
} else { } else {
@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
} }
func outputInserted(db *gorm.DB) { func outputInserted(db *gorm.DB) {
if db.Statement.Schema.PrioritizedPrimaryField != nil { if len(db.Statement.Schema.PrimaryFields) > 0 {
db.Statement.WriteString(" OUTPUT INSERTED.") db.Statement.WriteString(" OUTPUT ")
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) for idx, field := range db.Statement.Schema.PrimaryFields {
if idx > 0 {
db.Statement.WriteString(",")
}
db.Statement.WriteString(" INSERTED.")
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
}
} }
} }

View File

@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += "," createTableSQL += ","
} }
if !hasPrimaryKeyInDataType { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?," createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{} primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {

View File

@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) {
AssertEqual(t, newUser.CreatedAt, curTime) AssertEqual(t, newUser.CreatedAt, curTime)
AssertEqual(t, newUser.UpdatedAt, curTime) AssertEqual(t, newUser.UpdatedAt, curTime)
} }
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
type JoinTable struct {
UserID uint
FriendID uint
}
DB.Migrator().DropTable(&JoinTable{})
if err := DB.AutoMigrate(&JoinTable{}); err != nil {
t.Errorf("no error should happen when auto migrate, but got %v", err)
}
jt := JoinTable{UserID: 1, FriendID: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}

View File

@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) {
{"name1", "value1"}, {"name1", "value1"},
{"name2", "value2"}, {"name2", "value2"},
}, },
Role: Role{Name: "admin"},
} }
if err := DB.Create(&data).Error; err != nil { if err := DB.Create(&data).Error; err != nil {
@ -91,6 +92,7 @@ type ScannerValuerStruct struct {
Num Num Num Num
Strings StringsSlice Strings StringsSlice
Structs StructsSlice Structs StructsSlice
Role Role
} }
type EncryptedData []byte type EncryptedData []byte
@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error {
return errors.New("not supported") return errors.New("not supported")
} }
} }
type Role struct {
Name string `gorm:"size:256"`
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}