diff --git a/callback_create.go b/callback_create.go index 1510122f..4734d65e 100644 --- a/callback_create.go +++ b/callback_create.go @@ -34,10 +34,15 @@ func Create(scope *Scope) { } } + returningKey := "*" + if scope.PrimaryKey() != "" { + returningKey = scope.PrimaryKey() + } + if len(columns) == 0 { scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", scope.QuotedTableName(), - scope.Dialect().ReturningStr(scope.PrimaryKey()), + scope.Dialect().ReturningStr(scope.TableName(), returningKey), )) } else { scope.Raw(fmt.Sprintf( @@ -45,7 +50,7 @@ func Create(scope *Scope) { scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(sqls, ","), - scope.Dialect().ReturningStr(scope.PrimaryKey()), + scope.Dialect().ReturningStr(scope.TableName(), returningKey), )) } @@ -55,18 +60,22 @@ func Create(scope *Scope) { if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { id, err = result.LastInsertId() if scope.Err(err) == nil { - if count, err := result.RowsAffected(); err == nil { - scope.db.RowsAffected = count - } + scope.db.RowsAffected, _ = result.RowsAffected() } } } else { - if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { - scope.db.RowsAffected = 1 + if scope.PrimaryKey() == "" { + if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil { + scope.db.RowsAffected, _ = results.RowsAffected() + } + } else { + if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { + scope.db.RowsAffected = 1 + } } } - if !scope.HasError() && scope.PrimaryKeyZero() { + if scope.PrimaryKey() != "" && !scope.HasError() && scope.PrimaryKeyZero() { scope.SetColumn(scope.PrimaryKey(), id) } } diff --git a/common_dialect.go b/common_dialect.go index 9acedf78..11e00876 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -64,7 +64,7 @@ func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *commonDialect) ReturningStr(key string) string { +func (s *commonDialect) ReturningStr(tableName, key string) string { return "" } diff --git a/create_test.go b/create_test.go index f72933b4..a5b46630 100644 --- a/create_test.go +++ b/create_test.go @@ -56,10 +56,18 @@ func TestCreate(t *testing.T) { } } +func TestCreateWithNoGORMPrimayKey(t *testing.T) { + jt := JoinTable{From: 1, To: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} + func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { animal := Animal{Name: "Ferdinand"} if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") + t.Errorf("No error should happen when create a record without std primary key") } if animal.Counter == 0 { diff --git a/dialect.go b/dialect.go index 64ffc7f0..42b5c742 100644 --- a/dialect.go +++ b/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { HasTop() bool SqlTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string - ReturningStr(key string) string + ReturningStr(tableName, key string) string SelectFromDummyTable() string Quote(key string) string HasTable(scope *Scope, tableName string) bool diff --git a/migration_test.go b/migration_test.go index 28176f85..1a7ae6f2 100644 --- a/migration_test.go +++ b/migration_test.go @@ -15,19 +15,11 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}} + values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}} for _, value := range values { DB.DropTable(value) } - if err := DB.CreateTable(&Animal{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - if err := DB.CreateTable(User{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - if err := DB.AutoMigrate(values...).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } diff --git a/mssql.go b/mssql.go index 5fe45f0a..9317c1e7 100644 --- a/mssql.go +++ b/mssql.go @@ -64,7 +64,7 @@ func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *mssql) ReturningStr(key string) string { +func (s *mssql) ReturningStr(tableName, key string) string { return "" } diff --git a/mysql.go b/mysql.go index 9f4c525b..0dbc3d1d 100644 --- a/mysql.go +++ b/mysql.go @@ -2,8 +2,8 @@ package gorm import ( "fmt" - "strings" "reflect" + "strings" ) type mysql struct{} @@ -64,7 +64,7 @@ func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *mysql) ReturningStr(key string) string { +func (s *mysql) ReturningStr(tableName, key string) string { return "" } diff --git a/postgres.go b/postgres.go index d279b859..1cc8507d 100644 --- a/postgres.go +++ b/postgres.go @@ -4,8 +4,8 @@ import ( "database/sql" "database/sql/driver" "fmt" - "reflect" "github.com/lib/pq/hstore" + "reflect" ) type postgres struct { @@ -65,8 +65,8 @@ func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *postgres) ReturningStr(key string) string { - return fmt.Sprintf("RETURNING \"%v\"", key) +func (s *postgres) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", tableName, key) } func (s *postgres) SelectFromDummyTable() string { diff --git a/sqlite3.go b/sqlite3.go index e1e5d2d9..266b4642 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -56,7 +56,7 @@ func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *sqlite3) ReturningStr(key string) string { +func (s *sqlite3) ReturningStr(tableName, key string) string { return "" } diff --git a/structs_test.go b/structs_test.go index 621dfda4..acd4cc74 100644 --- a/structs_test.go +++ b/structs_test.go @@ -134,6 +134,12 @@ type Animal struct { UpdatedAt time.Time } +type JoinTable struct { + From uint64 + To uint64 + Time time.Time `sql:"default: null"` +} + type Post struct { Id int64 CategoryId sql.NullInt64