Merge branch 'galeone-master'

This commit is contained in:
Jinzhu 2014-12-13 10:46:25 +08:00
commit 5c9d3bb352
10 changed files with 42 additions and 27 deletions

View File

@ -34,10 +34,15 @@ func Create(scope *Scope) {
} }
} }
returningKey := "*"
if scope.PrimaryKey() != "" {
returningKey = scope.PrimaryKey()
}
if len(columns) == 0 { if len(columns) == 0 {
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
scope.QuotedTableName(), scope.QuotedTableName(),
scope.Dialect().ReturningStr(scope.PrimaryKey()), scope.Dialect().ReturningStr(scope.TableName(), returningKey),
)) ))
} else { } else {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
@ -45,7 +50,7 @@ func Create(scope *Scope) {
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(columns, ","), strings.Join(columns, ","),
strings.Join(sqls, ","), strings.Join(sqls, ","),
scope.Dialect().ReturningStr(scope.PrimaryKey()), scope.Dialect().ReturningStr(scope.TableName(), returningKey),
)) ))
} }
@ -55,18 +60,22 @@ func Create(scope *Scope) {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err = result.LastInsertId() id, err = result.LastInsertId()
if scope.Err(err) == nil { if scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil { scope.db.RowsAffected, _ = result.RowsAffected()
scope.db.RowsAffected = count
}
} }
} }
} else { } else {
if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil { if scope.PrimaryKey() == "" {
scope.db.RowsAffected = 1 if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil {
scope.db.RowsAffected, _ = results.RowsAffected()
}
} else {
if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) == nil {
scope.db.RowsAffected = 1
}
} }
} }
if !scope.HasError() && scope.PrimaryKeyZero() { if scope.PrimaryKey() != "" && !scope.HasError() && scope.PrimaryKeyZero() {
scope.SetColumn(scope.PrimaryKey(), id) scope.SetColumn(scope.PrimaryKey(), id)
} }
} }

View File

@ -64,7 +64,7 @@ func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
} }
} }
func (s *commonDialect) ReturningStr(key string) string { func (s *commonDialect) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -56,10 +56,18 @@ func TestCreate(t *testing.T) {
} }
} }
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"} animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil { if DB.Save(&animal).Error != nil {
t.Errorf("No error should happen when create an record without std primary key") t.Errorf("No error should happen when create a record without std primary key")
} }
if animal.Counter == 0 { if animal.Counter == 0 {

View File

@ -14,7 +14,7 @@ type Dialect interface {
HasTop() bool HasTop() bool
SqlTag(value reflect.Value, size int) string SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(key string) string ReturningStr(tableName, key string) string
SelectFromDummyTable() string SelectFromDummyTable() string
Quote(key string) string Quote(key string) string
HasTable(scope *Scope, tableName string) bool HasTable(scope *Scope, tableName string) bool

View File

@ -15,19 +15,11 @@ func runMigration() {
DB.Exec(fmt.Sprintf("drop table %v;", table)) DB.Exec(fmt.Sprintf("drop table %v;", table))
} }
values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}} values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}}
for _, value := range values { for _, value := range values {
DB.DropTable(value) DB.DropTable(value)
} }
if err := DB.CreateTable(&Animal{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err := DB.CreateTable(User{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
if err := DB.AutoMigrate(values...).Error; err != nil { if err := DB.AutoMigrate(values...).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
} }

View File

@ -64,7 +64,7 @@ func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
} }
} }
func (s *mssql) ReturningStr(key string) string { func (s *mssql) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -2,8 +2,8 @@ package gorm
import ( import (
"fmt" "fmt"
"strings"
"reflect" "reflect"
"strings"
) )
type mysql struct{} type mysql struct{}
@ -64,7 +64,7 @@ func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
} }
} }
func (s *mysql) ReturningStr(key string) string { func (s *mysql) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -4,8 +4,8 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"github.com/lib/pq/hstore" "github.com/lib/pq/hstore"
"reflect"
) )
type postgres struct { type postgres struct {
@ -65,8 +65,8 @@ func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
} }
} }
func (s *postgres) ReturningStr(key string) string { func (s *postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING \"%v\"", key) return fmt.Sprintf("RETURNING %v.%v", tableName, key)
} }
func (s *postgres) SelectFromDummyTable() string { func (s *postgres) SelectFromDummyTable() string {

View File

@ -56,7 +56,7 @@ func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
} }
} }
func (s *sqlite3) ReturningStr(key string) string { func (s *sqlite3) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -134,6 +134,12 @@ type Animal struct {
UpdatedAt time.Time UpdatedAt time.Time
} }
type JoinTable struct {
From uint64
To uint64
Time time.Time `sql:"default: null"`
}
type Post struct { type Post struct {
Id int64 Id int64
CategoryId sql.NullInt64 CategoryId sql.NullInt64