diff --git a/callbacks/create.go b/callbacks/create.go index ac63c89b..f558d7ae 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -64,7 +64,7 @@ func Create(config *Config) func(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { - if db.Statement.Schema != nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { if insertID, err := result.LastInsertId(); err == nil { switch db.Statement.ReflectValue.Kind() { diff --git a/dialects/mssql/create.go b/dialects/mssql/create.go index c85997fb..ebdeeab0 100644 --- a/dialects/mssql/create.go +++ b/dialects/mssql/create.go @@ -68,26 +68,30 @@ func Create(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for rows.Next() { - // for idx, field := range fields { - // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - // } + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() - if err := rows.Scan(values); err != nil { - db.AddError(err) + for rows.Next() { + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) } - db.RowsAffected++ } case reflect.Struct: - // for idx, field := range fields { - // values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - // } - values := db.Statement.Schema.PrioritizedPrimaryField.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + if len(db.Statement.Schema.PrimaryFields) > 0 { + values := make([]interface{}, len(db.Statement.Schema.PrimaryFields)) - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values)) + for idx, field := range db.Statement.Schema.PrimaryFields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } } } } else { @@ -177,8 +181,14 @@ func MergeCreate(db *gorm.DB, onConflict clause.OnConflict) { } func outputInserted(db *gorm.DB) { - if db.Statement.Schema.PrioritizedPrimaryField != nil { - db.Statement.WriteString(" OUTPUT INSERTED.") - db.Statement.WriteQuoted(db.Statement.Schema.PrioritizedPrimaryField.DBName) + if len(db.Statement.Schema.PrimaryFields) > 0 { + db.Statement.WriteString(" OUTPUT ") + for idx, field := range db.Statement.Schema.PrimaryFields { + if idx > 0 { + db.Statement.WriteString(",") + } + db.Statement.WriteString(" INSERTED.") + db.Statement.AddVar(db.Statement, clause.Column{Name: field.DBName}) + } } } diff --git a/migrator/migrator.go b/migrator/migrator.go index 5a06beb1..4e0f28b5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -149,7 +149,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += "," } - if !hasPrimaryKeyInDataType { + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," primaryKeys := []interface{}{} for _, field := range stmt.Schema.PrimaryFields { diff --git a/tests/create_test.go b/tests/create_test.go index a3b3b598..6421ca34 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -245,3 +245,21 @@ func TestCreateWithNowFuncOverride(t *testing.T) { AssertEqual(t, newUser.CreatedAt, curTime) AssertEqual(t, newUser.UpdatedAt, curTime) } + +func TestCreateWithNoGORMPrimayKey(t *testing.T) { + type JoinTable struct { + UserID uint + FriendID uint + } + + DB.Migrator().DropTable(&JoinTable{}) + if err := DB.AutoMigrate(&JoinTable{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + jt := JoinTable{UserID: 1, FriendID: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 88e7e12e..04c91ab2 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -34,6 +34,7 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, + Role: Role{Name: "admin"}, } if err := DB.Create(&data).Error; err != nil { @@ -91,6 +92,7 @@ type ScannerValuerStruct struct { Num Num Strings StringsSlice Structs StructsSlice + Role Role } type EncryptedData []byte @@ -176,3 +178,24 @@ func (l *StructsSlice) Scan(input interface{}) error { return errors.New("not supported") } } + +type Role struct { + Name string `gorm:"size:256"` +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +}