Test CreateWithNoGORMPrimayKey

This commit is contained in:
Jinzhu 2020-05-31 12:52:49 +08:00
parent b3b19a5577
commit 1546f8a4a1
5 changed files with 71 additions and 20 deletions

View File

@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
if db.Statement.Schema != nil {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() {

View File

@ -68,26 +68,30 @@ func Create(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for rows.Next() {
// for idx, field := range fields {
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
// }
if len(db.Statement.Schema.PrimaryFields) > 0 {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
if err := rows.Scan(values); err != nil {
db.AddError(err)
for rows.Next() {
for idx, field := range db.Statement.Schema.PrimaryFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
}
case reflect.Struct:
// for idx, field := range fields {
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
// }
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
if len(db.Statement.Schema.PrimaryFields) > 0 {
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
for idx, field := range db.Statement.Schema.PrimaryFields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values))
db.AddError(rows.Scan(values...))
}
}
}
} else {
@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
}
func outputInserted(db *gorm.DB) {
if db.Statement.Schema.PrioritizedPrimaryField != nil {
db.Statement.WriteString(" OUTPUT INSERTED.")
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
if len(db.Statement.Schema.PrimaryFields) > 0 {
db.Statement.WriteString(" OUTPUT ")
for idx, field := range db.Statement.Schema.PrimaryFields {
if idx > 0 {
db.Statement.WriteString(",")
}
db.Statement.WriteString(" INSERTED.")
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
}
}
}

View File

@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += ","
}
if !hasPrimaryKeyInDataType {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields {

View File

@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) {
AssertEqual(t, newUser.CreatedAt, curTime)
AssertEqual(t, newUser.UpdatedAt, curTime)
}
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
type JoinTable struct {
UserID uint
FriendID uint
}
DB.Migrator().DropTable(&JoinTable{})
if err := DB.AutoMigrate(&JoinTable{}); err != nil {
t.Errorf("no error should happen when auto migrate, but got %v", err)
}
jt := JoinTable{UserID: 1, FriendID: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}

View File

@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) {
{"name1", "value1"},
{"name2", "value2"},
},
Role: Role{Name: "admin"},
}
if err := DB.Create(&data).Error; err != nil {
@ -91,6 +92,7 @@ type ScannerValuerStruct struct {
Num Num
Strings StringsSlice
Structs StructsSlice
Role Role
}
type EncryptedData []byte
@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error {
return errors.New("not supported")
}
}
type Role struct {
Name string `gorm:"size:256"`
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}