mirror of https://github.com/go-gorm/gorm.git
Test CreateWithNoGORMPrimayKey
This commit is contained in:
parent
b3b19a5577
commit
1546f8a4a1
|
@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||
if insertID, err := result.LastInsertId(); err == nil {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
|
|
|
@ -68,26 +68,30 @@ func Create(db *gorm.DB) {
|
|||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for rows.Next() {
|
||||
// for idx, field := range fields {
|
||||
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
// }
|
||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||
|
||||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
if err := rows.Scan(values); err != nil {
|
||||
db.AddError(err)
|
||||
for rows.Next() {
|
||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
db.RowsAffected++
|
||||
}
|
||||
case reflect.Struct:
|
||||
// for idx, field := range fields {
|
||||
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
// }
|
||||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values))
|
||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
|||
}
|
||||
|
||||
func outputInserted(db *gorm.DB) {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||
db.Statement.WriteString(" OUTPUT INSERTED.")
|
||||
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
|
||||
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||
db.Statement.WriteString(" OUTPUT ")
|
||||
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteString(",")
|
||||
}
|
||||
db.Statement.WriteString(" INSERTED.")
|
||||
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||
createTableSQL += ","
|
||||
}
|
||||
|
||||
if !hasPrimaryKeyInDataType {
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := []interface{}{}
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
|
|
|
@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) {
|
|||
AssertEqual(t, newUser.CreatedAt, curTime)
|
||||
AssertEqual(t, newUser.UpdatedAt, curTime)
|
||||
}
|
||||
|
||||
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
|
||||
type JoinTable struct {
|
||||
UserID uint
|
||||
FriendID uint
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&JoinTable{})
|
||||
if err := DB.AutoMigrate(&JoinTable{}); err != nil {
|
||||
t.Errorf("no error should happen when auto migrate, but got %v", err)
|
||||
}
|
||||
|
||||
jt := JoinTable{UserID: 1, FriendID: 2}
|
||||
err := DB.Create(&jt).Error
|
||||
if err != nil {
|
||||
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) {
|
|||
{"name1", "value1"},
|
||||
{"name2", "value2"},
|
||||
},
|
||||
Role: Role{Name: "admin"},
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
|
@ -91,6 +92,7 @@ type ScannerValuerStruct struct {
|
|||
Num Num
|
||||
Strings StringsSlice
|
||||
Structs StructsSlice
|
||||
Role Role
|
||||
}
|
||||
|
||||
type EncryptedData []byte
|
||||
|
@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error {
|
|||
return errors.New("not supported")
|
||||
}
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Name string `gorm:"size:256"`
|
||||
}
|
||||
|
||||
func (role *Role) Scan(value interface{}) error {
|
||||
if b, ok := value.([]uint8); ok {
|
||||
role.Name = string(b)
|
||||
} else {
|
||||
role.Name = value.(string)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (role Role) Value() (driver.Value, error) {
|
||||
return role.Name, nil
|
||||
}
|
||||
|
||||
func (role Role) IsAdmin() bool {
|
||||
return role.Name == "admin"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue