Create with Select

This commit is contained in:
Jinzhu 2015-03-12 17:47:31 +08:00
parent da7830ea50
commit ad251b9deb
8 changed files with 115 additions and 49 deletions

View File

@ -24,13 +24,23 @@ func Create(scope *Scope) {
if !scope.HasError() {
// set create sql
var sqls, columns []string
for _, field := range scope.Fields() {
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
fields := scope.Fields()
for _, field := range fields {
if scope.ValidField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.ValidField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
}
}
}
}
returningKey := "*"

View File

@ -12,7 +12,7 @@ func CommitOrRollbackTransaction(scope *Scope) {
func SaveBeforeAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
@ -26,7 +26,7 @@ func SaveBeforeAssociations(scope *Scope) {
func SaveAfterAssociations(scope *Scope) {
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field

View File

@ -1,6 +1,7 @@
package gorm_test
import (
"fmt"
"reflect"
"testing"
"time"
@ -121,3 +122,22 @@ func TestAnonymousField(t *testing.T) {
t.Errorf("Should be able to get anonymous field")
}
}
func TestSelectCreate(t *testing.T) {
user := getPreparedUser("user1", "select_create")
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user)
var user2 User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&user2, user.Id)
if user2.Name != user.Name || user2.Age == user.Age {
t.Errorf("Should only create users with name column")
}
fmt.Println(user2.CreditCard.ID)
if user2.BillingAddressID.Int64 == 0 || user2.ShippingAddressId != 0 ||
user2.CreditCard.ID == 0 || len(user2.Emails) == 0 {
t.Errorf("Should only create users with name column")
}
}

View File

@ -89,8 +89,12 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
}
func TestSetTable(t *testing.T) {
if DB.Table("users").Pluck("age", &[]int{}).Error != nil {
t.Errorf("No errors should happen if set table for pluck")
DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error())
}
var users []User
@ -115,9 +119,11 @@ func TestSetTable(t *testing.T) {
t.Errorf("Query from specified table")
}
DB.Save(getPreparedUser("normal_user", "reset_table"))
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
var user1, user2, user3 User
DB.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if (user1.Name == user2.Name) || (user1.Name != user3.Name) {
DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
t.Errorf("unset specified table with blank string")
}
}

View File

@ -1,29 +1,9 @@
package gorm_test
import (
"fmt"
"testing"
)
import "testing"
func getPreloadUser(name string) User {
var company Company
DB.Where(Company{Name: "preload"}).FirstOrCreate(&company)
return User{
Name: name,
Role: Role{"Preload"},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
func getPreloadUser(name string) *User {
return getPreparedUser(name, "Preload")
}
func checkUserHasPreloadData(user User, t *testing.T) {
@ -64,7 +44,7 @@ func checkUserHasPreloadData(user User, t *testing.T) {
func TestPreload(t *testing.T) {
user1 := getPreloadUser("user1")
DB.Save(&user1)
DB.Save(user1)
preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company")
@ -73,10 +53,10 @@ func TestPreload(t *testing.T) {
checkUserHasPreloadData(user, t)
user2 := getPreloadUser("user2")
DB.Save(&user2)
DB.Save(user2)
user3 := getPreloadUser("user3")
DB.Save(&user3)
DB.Save(user3)
var users []User
preloadDB.Find(&users)

View File

@ -333,3 +333,42 @@ func (scope *Scope) CommitOrRollback() *Scope {
}
return scope
}
func (scope *Scope) SelectAttrs() (attrs []string) {
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))
}
}
}
return attrs
}
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}
func (scope *Scope) ValidField(field *Field) bool {
selectAttrs := scope.SelectAttrs()
omitAttrs := scope.OmitAttrs()
if len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range omitAttrs {
if field.Name == attr || field.DBName == attr {
return false
}
}
return !field.IsIgnored
}

View File

@ -89,18 +89,6 @@ func (s *search) Omit(columns ...string) *search {
return s
}
func (s *search) SelectAttrs() (attrs []string) {
for key, value := range s.selects {
attrs = append(attrs, key)
attrs = append(attrs, value.([]string)...)
}
return attrs
}
func (s *search) OmitAttrs() []string {
return s.omits
}
func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value)
return s

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"time"
@ -194,3 +195,25 @@ func (nt NullTime) Value() (driver.Value, error) {
}
return nt.Time, nil
}
func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}