From 0fa1335555f40164b4da8c0d888e48863f18ad65 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Mon, 8 Dec 2014 11:33:30 +0100 Subject: [PATCH 1/5] Avoid Errors in postgres when creating a row without a GORM defined primary key (but defined db-side) --- callback_create.go | 11 +++++++++-- create_test.go | 10 +++++++++- migration_test.go | 10 +--------- structs_test.go | 6 ++++++ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/callback_create.go b/callback_create.go index 1510122f..342fb413 100644 --- a/callback_create.go +++ b/callback_create.go @@ -34,10 +34,17 @@ func Create(scope *Scope) { } } + returningField := "" + if scope.PrimaryKey() == "" { + returningField = "*" + } else { + returningField = scope.PrimaryKey() + } + if len(columns) == 0 { scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", scope.QuotedTableName(), - scope.Dialect().ReturningStr(scope.PrimaryKey()), + scope.Dialect().ReturningStr(returningField), )) } else { scope.Raw(fmt.Sprintf( @@ -45,7 +52,7 @@ func Create(scope *Scope) { scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(sqls, ","), - scope.Dialect().ReturningStr(scope.PrimaryKey()), + scope.Dialect().ReturningStr(returningField), )) } diff --git a/create_test.go b/create_test.go index f72933b4..a5b46630 100644 --- a/create_test.go +++ b/create_test.go @@ -56,10 +56,18 @@ func TestCreate(t *testing.T) { } } +func TestCreateWithNoGORMPrimayKey(t *testing.T) { + jt := JoinTable{From: 1, To: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} + func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { animal := Animal{Name: "Ferdinand"} if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") + t.Errorf("No error should happen when create a record without std primary key") } if animal.Counter == 0 { diff --git a/migration_test.go b/migration_test.go index 28176f85..1a7ae6f2 100644 --- a/migration_test.go +++ b/migration_test.go @@ -15,19 +15,11 @@ func runMigration() { DB.Exec(fmt.Sprintf("drop table %v;", table)) } - values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}} + values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}} for _, value := range values { DB.DropTable(value) } - if err := DB.CreateTable(&Animal{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - if err := DB.CreateTable(User{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - if err := DB.AutoMigrate(values...).Error; err != nil { panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) } diff --git a/structs_test.go b/structs_test.go index 621dfda4..acd4cc74 100644 --- a/structs_test.go +++ b/structs_test.go @@ -134,6 +134,12 @@ type Animal struct { UpdatedAt time.Time } +type JoinTable struct { + From uint64 + To uint64 + Time time.Time `sql:"default: null"` +} + type Post struct { Id int64 CategoryId sql.NullInt64 From 21f4de584f1cd93b15bf2ed498e861933ff37ff4 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Mon, 8 Dec 2014 12:03:42 +0100 Subject: [PATCH 2/5] Use tableName.field or tableName.* in returning string --- callback_create.go | 10 +++++----- common_dialect.go | 2 +- dialect.go | 2 +- mssql.go | 2 +- mysql.go | 4 ++-- postgres.go | 6 +++--- sqlite3.go | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/callback_create.go b/callback_create.go index 342fb413..8e76e855 100644 --- a/callback_create.go +++ b/callback_create.go @@ -34,17 +34,17 @@ func Create(scope *Scope) { } } - returningField := "" + returningKey := "*" if scope.PrimaryKey() == "" { - returningField = "*" + returningKey = "*" } else { - returningField = scope.PrimaryKey() + returningKey = scope.PrimaryKey() } if len(columns) == 0 { scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", scope.QuotedTableName(), - scope.Dialect().ReturningStr(returningField), + scope.Dialect().ReturningStr(scope.TableName(), returningKey), )) } else { scope.Raw(fmt.Sprintf( @@ -52,7 +52,7 @@ func Create(scope *Scope) { scope.QuotedTableName(), strings.Join(columns, ","), strings.Join(sqls, ","), - scope.Dialect().ReturningStr(returningField), + scope.Dialect().ReturningStr(scope.TableName(), returningKey), )) } diff --git a/common_dialect.go b/common_dialect.go index 9acedf78..11e00876 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -64,7 +64,7 @@ func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *commonDialect) ReturningStr(key string) string { +func (s *commonDialect) ReturningStr(tableName, key string) string { return "" } diff --git a/dialect.go b/dialect.go index 64ffc7f0..42b5c742 100644 --- a/dialect.go +++ b/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { HasTop() bool SqlTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string - ReturningStr(key string) string + ReturningStr(tableName, key string) string SelectFromDummyTable() string Quote(key string) string HasTable(scope *Scope, tableName string) bool diff --git a/mssql.go b/mssql.go index 5fe45f0a..9317c1e7 100644 --- a/mssql.go +++ b/mssql.go @@ -64,7 +64,7 @@ func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *mssql) ReturningStr(key string) string { +func (s *mssql) ReturningStr(tableName, key string) string { return "" } diff --git a/mysql.go b/mysql.go index 9f4c525b..0dbc3d1d 100644 --- a/mysql.go +++ b/mysql.go @@ -2,8 +2,8 @@ package gorm import ( "fmt" - "strings" "reflect" + "strings" ) type mysql struct{} @@ -64,7 +64,7 @@ func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *mysql) ReturningStr(key string) string { +func (s *mysql) ReturningStr(tableName, key string) string { return "" } diff --git a/postgres.go b/postgres.go index d279b859..1cc8507d 100644 --- a/postgres.go +++ b/postgres.go @@ -4,8 +4,8 @@ import ( "database/sql" "database/sql/driver" "fmt" - "reflect" "github.com/lib/pq/hstore" + "reflect" ) type postgres struct { @@ -65,8 +65,8 @@ func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *postgres) ReturningStr(key string) string { - return fmt.Sprintf("RETURNING \"%v\"", key) +func (s *postgres) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", tableName, key) } func (s *postgres) SelectFromDummyTable() string { diff --git a/sqlite3.go b/sqlite3.go index e1e5d2d9..266b4642 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -56,7 +56,7 @@ func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string { } } -func (s *sqlite3) ReturningStr(key string) string { +func (s *sqlite3) ReturningStr(tableName, key string) string { return "" } From 907ea93b432f8bb236f183209750b31cc7a74d76 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Mon, 8 Dec 2014 19:00:02 +0100 Subject: [PATCH 3/5] Fixes scan when primary key is not defined in gorm (but is defined db-size). Useful in join tables --- callback_create.go | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/callback_create.go b/callback_create.go index 8e76e855..ed2589dd 100644 --- a/callback_create.go +++ b/callback_create.go @@ -35,9 +35,7 @@ func Create(scope *Scope) { } returningKey := "*" - if scope.PrimaryKey() == "" { - returningKey = "*" - } else { + if scope.PrimaryKey() != "" { returningKey = scope.PrimaryKey() } @@ -68,12 +66,24 @@ func Create(scope *Scope) { } } } else { - if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { - scope.db.RowsAffected = 1 + if scope.PrimaryKey() == "" { + if rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...); err != nil { + //extract column name to get fields lenght + if names, columnsErr := rows.Columns(); columnsErr != nil { + ids := make([]interface{}, len(names)) + if scope.Err(rows.Scan(ids...)) == nil { + scope.db.RowsAffected = int64(len(names)) + } + } + } + } else { + if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { + scope.db.RowsAffected = 1 + } } } - if !scope.HasError() && scope.PrimaryKeyZero() { + if scope.PrimaryKey() != "" && !scope.HasError() && scope.PrimaryKeyZero() { scope.SetColumn(scope.PrimaryKey(), id) } } From e30cf3bbd1e6286b8d6c9b6dcb578d405c350b45 Mon Sep 17 00:00:00 2001 From: Paolo Galeone Date: Mon, 8 Dec 2014 20:33:31 +0100 Subject: [PATCH 4/5] Fix number of columns affected. Is always one. The old parameter was the numer of fields affected --- callback_create.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callback_create.go b/callback_create.go index ed2589dd..4c909e1c 100644 --- a/callback_create.go +++ b/callback_create.go @@ -72,7 +72,7 @@ func Create(scope *Scope) { if names, columnsErr := rows.Columns(); columnsErr != nil { ids := make([]interface{}, len(names)) if scope.Err(rows.Scan(ids...)) == nil { - scope.db.RowsAffected = int64(len(names)) + scope.db.RowsAffected = 1 } } } From 5823670269c9eb0d8eeedd2eaddcc06310df767c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 13 Dec 2014 10:46:16 +0800 Subject: [PATCH 5/5] Refactor callbck_create --- callback_create.go | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/callback_create.go b/callback_create.go index 4c909e1c..4734d65e 100644 --- a/callback_create.go +++ b/callback_create.go @@ -60,21 +60,13 @@ func Create(scope *Scope) { if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { id, err = result.LastInsertId() if scope.Err(err) == nil { - if count, err := result.RowsAffected(); err == nil { - scope.db.RowsAffected = count - } + scope.db.RowsAffected, _ = result.RowsAffected() } } } else { if scope.PrimaryKey() == "" { - if rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...); err != nil { - //extract column name to get fields lenght - if names, columnsErr := rows.Columns(); columnsErr != nil { - ids := make([]interface{}, len(names)) - if scope.Err(rows.Scan(ids...)) == nil { - scope.db.RowsAffected = 1 - } - } + if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil { + scope.db.RowsAffected, _ = results.RowsAffected() } } else { if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil {