mirror of https://github.com/go-gorm/gorm.git
Create with Select
This commit is contained in:
parent
da7830ea50
commit
ad251b9deb
|
@ -24,13 +24,23 @@ func Create(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
// set create sql
|
// set create sql
|
||||||
var sqls, columns []string
|
var sqls, columns []string
|
||||||
for _, field := range scope.Fields() {
|
fields := scope.Fields()
|
||||||
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
|
for _, field := range fields {
|
||||||
|
if scope.ValidField(field) {
|
||||||
|
if field.IsNormal {
|
||||||
|
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||||
if !field.IsBlank || !field.HasDefaultValue {
|
if !field.IsBlank || !field.HasDefaultValue {
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
|
if relationField := fields[relationship.ForeignDBName]; !scope.ValidField(relationField) {
|
||||||
|
columns = append(columns, scope.Quote(relationField.DBName))
|
||||||
|
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
returningKey := "*"
|
returningKey := "*"
|
||||||
|
|
|
@ -12,7 +12,7 @@ func CommitOrRollbackTransaction(scope *Scope) {
|
||||||
|
|
||||||
func SaveBeforeAssociations(scope *Scope) {
|
func SaveBeforeAssociations(scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !field.IsBlank && !field.IsIgnored {
|
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
value := field.Field
|
value := field.Field
|
||||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
||||||
|
@ -26,7 +26,7 @@ func SaveBeforeAssociations(scope *Scope) {
|
||||||
|
|
||||||
func SaveAfterAssociations(scope *Scope) {
|
func SaveAfterAssociations(scope *Scope) {
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if !field.IsBlank && !field.IsIgnored {
|
if scope.ValidField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if relationship := field.Relationship; relationship != nil &&
|
if relationship := field.Relationship; relationship != nil &&
|
||||||
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||||
value := field.Field
|
value := field.Field
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -121,3 +122,22 @@ func TestAnonymousField(t *testing.T) {
|
||||||
t.Errorf("Should be able to get anonymous field")
|
t.Errorf("Should be able to get anonymous field")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectCreate(t *testing.T) {
|
||||||
|
user := getPreparedUser("user1", "select_create")
|
||||||
|
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&user2, user.Id)
|
||||||
|
|
||||||
|
if user2.Name != user.Name || user2.Age == user.Age {
|
||||||
|
t.Errorf("Should only create users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(user2.CreditCard.ID)
|
||||||
|
if user2.BillingAddressID.Int64 == 0 || user2.ShippingAddressId != 0 ||
|
||||||
|
user2.CreditCard.ID == 0 || len(user2.Emails) == 0 {
|
||||||
|
t.Errorf("Should only create users with name column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
14
main_test.go
14
main_test.go
|
@ -89,8 +89,12 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetTable(t *testing.T) {
|
func TestSetTable(t *testing.T) {
|
||||||
if DB.Table("users").Pluck("age", &[]int{}).Error != nil {
|
DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
|
||||||
t.Errorf("No errors should happen if set table for pluck")
|
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
|
||||||
|
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
|
||||||
|
|
||||||
|
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
|
||||||
|
t.Errorf("No errors should happen if set table for pluck", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []User
|
var users []User
|
||||||
|
@ -115,9 +119,11 @@ func TestSetTable(t *testing.T) {
|
||||||
t.Errorf("Query from specified table")
|
t.Errorf("Query from specified table")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Save(getPreparedUser("normal_user", "reset_table"))
|
||||||
|
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
|
||||||
var user1, user2, user3 User
|
var user1, user2, user3 User
|
||||||
DB.First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
|
DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
|
||||||
if (user1.Name == user2.Name) || (user1.Name != user3.Name) {
|
if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
|
||||||
t.Errorf("unset specified table with blank string")
|
t.Errorf("unset specified table with blank string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,29 +1,9 @@
|
||||||
package gorm_test
|
package gorm_test
|
||||||
|
|
||||||
import (
|
import "testing"
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getPreloadUser(name string) User {
|
func getPreloadUser(name string) *User {
|
||||||
var company Company
|
return getPreparedUser(name, "Preload")
|
||||||
DB.Where(Company{Name: "preload"}).FirstOrCreate(&company)
|
|
||||||
|
|
||||||
return User{
|
|
||||||
Name: name,
|
|
||||||
Role: Role{"Preload"},
|
|
||||||
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
|
||||||
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
|
||||||
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
|
||||||
Emails: []Email{
|
|
||||||
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
|
||||||
},
|
|
||||||
Company: company,
|
|
||||||
Languages: []Language{
|
|
||||||
{Name: fmt.Sprintf("lang_1_%v", name)},
|
|
||||||
{Name: fmt.Sprintf("lang_2_%v", name)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUserHasPreloadData(user User, t *testing.T) {
|
func checkUserHasPreloadData(user User, t *testing.T) {
|
||||||
|
@ -64,7 +44,7 @@ func checkUserHasPreloadData(user User, t *testing.T) {
|
||||||
|
|
||||||
func TestPreload(t *testing.T) {
|
func TestPreload(t *testing.T) {
|
||||||
user1 := getPreloadUser("user1")
|
user1 := getPreloadUser("user1")
|
||||||
DB.Save(&user1)
|
DB.Save(user1)
|
||||||
|
|
||||||
preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
Preload("CreditCard").Preload("Emails").Preload("Company")
|
Preload("CreditCard").Preload("Emails").Preload("Company")
|
||||||
|
@ -73,10 +53,10 @@ func TestPreload(t *testing.T) {
|
||||||
checkUserHasPreloadData(user, t)
|
checkUserHasPreloadData(user, t)
|
||||||
|
|
||||||
user2 := getPreloadUser("user2")
|
user2 := getPreloadUser("user2")
|
||||||
DB.Save(&user2)
|
DB.Save(user2)
|
||||||
|
|
||||||
user3 := getPreloadUser("user3")
|
user3 := getPreloadUser("user3")
|
||||||
DB.Save(&user3)
|
DB.Save(user3)
|
||||||
|
|
||||||
var users []User
|
var users []User
|
||||||
preloadDB.Find(&users)
|
preloadDB.Find(&users)
|
||||||
|
|
39
scope.go
39
scope.go
|
@ -333,3 +333,42 @@ func (scope *Scope) CommitOrRollback() *Scope {
|
||||||
}
|
}
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) SelectAttrs() (attrs []string) {
|
||||||
|
for _, value := range scope.Search.selects {
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
attrs = append(attrs, str)
|
||||||
|
} else if strs, ok := value.([]interface{}); ok {
|
||||||
|
for _, str := range strs {
|
||||||
|
attrs = append(attrs, fmt.Sprintf("%v", str))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) OmitAttrs() []string {
|
||||||
|
return scope.Search.omits
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) ValidField(field *Field) bool {
|
||||||
|
selectAttrs := scope.SelectAttrs()
|
||||||
|
omitAttrs := scope.OmitAttrs()
|
||||||
|
|
||||||
|
if len(selectAttrs) > 0 {
|
||||||
|
for _, attr := range selectAttrs {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attr := range omitAttrs {
|
||||||
|
if field.Name == attr || field.DBName == attr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return !field.IsIgnored
|
||||||
|
}
|
||||||
|
|
12
search.go
12
search.go
|
@ -89,18 +89,6 @@ func (s *search) Omit(columns ...string) *search {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *search) SelectAttrs() (attrs []string) {
|
|
||||||
for key, value := range s.selects {
|
|
||||||
attrs = append(attrs, key)
|
|
||||||
attrs = append(attrs, value.([]string)...)
|
|
||||||
}
|
|
||||||
return attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *search) OmitAttrs() []string {
|
|
||||||
return s.omits
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *search) Limit(value interface{}) *search {
|
func (s *search) Limit(value interface{}) *search {
|
||||||
s.limit = s.getInterfaceAsSql(value)
|
s.limit = s.getInterfaceAsSql(value)
|
||||||
return s
|
return s
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
@ -194,3 +195,25 @@ func (nt NullTime) Value() (driver.Value, error) {
|
||||||
}
|
}
|
||||||
return nt.Time, nil
|
return nt.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getPreparedUser(name string, role string) *User {
|
||||||
|
var company Company
|
||||||
|
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
Name: name,
|
||||||
|
Age: 20,
|
||||||
|
Role: Role{role},
|
||||||
|
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||||
|
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||||
|
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||||
|
Emails: []Email{
|
||||||
|
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||||
|
},
|
||||||
|
Company: company,
|
||||||
|
Languages: []Language{
|
||||||
|
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||||
|
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue