Merge branch 'master' into v1.0_dev

This commit is contained in:
Jinzhu 2016-01-16 11:40:08 +08:00
commit c84e787b1d
7 changed files with 68 additions and 22 deletions

View File

@ -600,7 +600,9 @@ func TestRelated(t *testing.T) {
Company: Company{Name: "company1"}, Company: Company{Name: "company1"},
} }
DB.Save(&user) if err := DB.Save(&user).Error; err != nil {
t.Errorf("No error should happen when saving user")
}
if user.CreditCard.ID == 0 { if user.CreditCard.ID == 0 {
t.Errorf("After user save, credit card should have id") t.Errorf("After user save, credit card should have id")

View File

@ -13,7 +13,7 @@ type Field struct {
Field reflect.Value Field reflect.Value
} }
func (field *Field) Set(value interface{}) error { func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() { if !field.Field.IsValid() {
return errors.New("field value not valid") return errors.New("field value not valid")
} }
@ -27,15 +27,25 @@ func (field *Field) Set(value interface{}) error {
reflectValue = reflect.ValueOf(value) reflectValue = reflect.ValueOf(value)
} }
fieldValue := field.Field
if reflectValue.IsValid() { if reflectValue.IsValid() {
if reflectValue.Type().ConvertibleTo(field.Field.Type()) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type())) fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if err := scanner.Scan(reflectValue.Interface()); err != nil {
return err
}
} else { } else {
return fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), field.Field.Type()) if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
}
fieldValue = fieldValue.Elem()
}
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err = scanner.Scan(reflectValue.Interface())
} else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
}
} }
} else { } else {
field.Field.Set(reflect.Zero(field.Field.Type())) field.Field.Set(reflect.Zero(field.Field.Type()))

View File

@ -7,7 +7,6 @@ import (
"strings" "strings"
) )
// Preload preload relations callback
func Preload(scope *Scope) { func Preload(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() { if scope.Search.preload == nil || scope.HasError() {
return return

View File

@ -207,10 +207,12 @@ func TestSearchWithStruct(t *testing.T) {
} }
func TestSearchWithMap(t *testing.T) { func TestSearchWithMap(t *testing.T) {
companyID := 1
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3) user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: now.MustParse("2020-1-1"), CompanyID: &companyID}
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
var user User var user User
DB.First(&user, map[string]interface{}{"name": user1.Name}) DB.First(&user, map[string]interface{}{"name": user1.Name})
@ -234,6 +236,21 @@ func TestSearchWithMap(t *testing.T) {
if len(users) != 1 { if len(users) != 1 {
t.Errorf("Search all records with inline map") t.Errorf("Search all records with inline map")
} }
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil})
if len(users) != 0 {
t.Errorf("Search all records with inline map containing null value finding 0 records")
}
DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil})
if len(users) != 1 {
t.Errorf("Search all records with inline map containing null value finding 1 record")
}
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID})
if len(users) != 1 {
t.Errorf("Search all records with inline multiple value map")
}
} }
func TestSearchWithEmptyChain(t *testing.T) { func TestSearchWithEmptyChain(t *testing.T) {
@ -377,10 +394,14 @@ func TestNot(t *testing.T) {
DB.Create(getPreparedUser("user1", "not")) DB.Create(getPreparedUser("user1", "not"))
DB.Create(getPreparedUser("user2", "not")) DB.Create(getPreparedUser("user2", "not"))
DB.Create(getPreparedUser("user3", "not")) DB.Create(getPreparedUser("user3", "not"))
DB.Create(getPreparedUser("user4", "not"))
user4 := getPreparedUser("user4", "not")
user4.Company = Company{}
DB.Create(user4)
DB := DB.Where("role = ?", "not") DB := DB.Where("role = ?", "not")
var users1, users2, users3, users4, users5, users6, users7, users8 []User var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User
if DB.Find(&users1).RowsAffected != 4 { if DB.Find(&users1).RowsAffected != 4 {
t.Errorf("should find 4 not users") t.Errorf("should find 4 not users")
} }
@ -423,15 +444,20 @@ func TestNot(t *testing.T) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users's name not equal 3")
} }
DB.Not("name", []string{"user3"}).Find(&users7) DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
if len(users1)-len(users7) != int(name3Count) { if len(users1)-len(users7) != 2 { // not user3 or user4
t.Errorf("Should find all user's name not equal to 3 who do not have company id")
}
DB.Not("name", []string{"user3"}).Find(&users8)
if len(users1)-len(users8) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users's name not equal 3")
} }
var name2Count int64 var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users8) DB.Not("name", []string{"user3", "user2"}).Find(&users9)
if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) { if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3") t.Errorf("Should find all users's name not equal 3")
} }
} }

View File

@ -32,7 +32,11 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range value { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key)))
}
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
@ -98,7 +102,11 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range value { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key)))
}
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:

View File

@ -28,7 +28,7 @@ type User struct {
CreditCard CreditCard CreditCard CreditCard
Latitude float64 Latitude float64
Languages []Language `gorm:"many2many:user_languages;"` Languages []Language `gorm:"many2many:user_languages;"`
CompanyID int64 CompanyID *int
Company Company Company Company
Role Role
PasswordHash []byte PasswordHash []byte

View File

@ -87,9 +87,10 @@ func toString(str interface{}) string {
return strings.Join(results, "_") return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok { } else if bytes, ok := str.([]byte); ok {
return string(bytes) return string(bytes)
} else { } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
return fmt.Sprintf("%v", str) return fmt.Sprintf("%v", reflectValue.Interface())
} }
return ""
} }
func makeSlice(elemType reflect.Type) interface{} { func makeSlice(elemType reflect.Type) interface{} {