mirror of https://github.com/go-gorm/gorm.git
Test CreateWithNoGORMPrimayKey
This commit is contained in:
parent
b3b19a5577
commit
1546f8a4a1
|
@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
|
||||||
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
|
||||||
if insertID, err := result.LastInsertId(); err == nil {
|
if insertID, err := result.LastInsertId(); err == nil {
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
|
|
|
@ -68,26 +68,30 @@ func Create(db *gorm.DB) {
|
||||||
|
|
||||||
switch db.Statement.ReflectValue.Kind() {
|
switch db.Statement.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for rows.Next() {
|
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
// for idx, field := range fields {
|
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||||
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
|
||||||
// }
|
|
||||||
|
|
||||||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
for rows.Next() {
|
||||||
if err := rows.Scan(values); err != nil {
|
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||||
db.AddError(err)
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
// for idx, field := range fields {
|
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
// values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
values := make([]interface{}, len(db.Statement.Schema.PrimaryFields))
|
||||||
// }
|
|
||||||
values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values))
|
db.AddError(rows.Scan(values...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputInserted(db *gorm.DB) {
|
func outputInserted(db *gorm.DB) {
|
||||||
if db.Statement.Schema.PrioritizedPrimaryField != nil {
|
if len(db.Statement.Schema.PrimaryFields) > 0 {
|
||||||
db.Statement.WriteString(" OUTPUT INSERTED.")
|
db.Statement.WriteString(" OUTPUT ")
|
||||||
db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName)
|
for idx, field := range db.Statement.Schema.PrimaryFields {
|
||||||
|
if idx > 0 {
|
||||||
|
db.Statement.WriteString(",")
|
||||||
|
}
|
||||||
|
db.Statement.WriteString(" INSERTED.")
|
||||||
|
db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
createTableSQL += ","
|
createTableSQL += ","
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasPrimaryKeyInDataType {
|
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||||
createTableSQL += "PRIMARY KEY ?,"
|
createTableSQL += "PRIMARY KEY ?,"
|
||||||
primaryKeys := []interface{}{}
|
primaryKeys := []interface{}{}
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
|
|
|
@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) {
|
||||||
AssertEqual(t, newUser.CreatedAt, curTime)
|
AssertEqual(t, newUser.CreatedAt, curTime)
|
||||||
AssertEqual(t, newUser.UpdatedAt, curTime)
|
AssertEqual(t, newUser.UpdatedAt, curTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
|
||||||
|
type JoinTable struct {
|
||||||
|
UserID uint
|
||||||
|
FriendID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&JoinTable{})
|
||||||
|
if err := DB.AutoMigrate(&JoinTable{}); err != nil {
|
||||||
|
t.Errorf("no error should happen when auto migrate, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jt := JoinTable{UserID: 1, FriendID: 2}
|
||||||
|
err := DB.Create(&jt).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) {
|
||||||
{"name1", "value1"},
|
{"name1", "value1"},
|
||||||
{"name2", "value2"},
|
{"name2", "value2"},
|
||||||
},
|
},
|
||||||
|
Role: Role{Name: "admin"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&data).Error; err != nil {
|
if err := DB.Create(&data).Error; err != nil {
|
||||||
|
@ -91,6 +92,7 @@ type ScannerValuerStruct struct {
|
||||||
Num Num
|
Num Num
|
||||||
Strings StringsSlice
|
Strings StringsSlice
|
||||||
Structs StructsSlice
|
Structs StructsSlice
|
||||||
|
Role Role
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error {
|
||||||
return errors.New("not supported")
|
return errors.New("not supported")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
Name string `gorm:"size:256"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role *Role) Scan(value interface{}) error {
|
||||||
|
if b, ok := value.([]uint8); ok {
|
||||||
|
role.Name = string(b)
|
||||||
|
} else {
|
||||||
|
role.Name = value.(string)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) Value() (driver.Value, error) {
|
||||||
|
return role.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) IsAdmin() bool {
|
||||||
|
return role.Name == "admin"
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue