mirror of https://github.com/go-gorm/gorm.git
Refactor Scope
This commit is contained in:
parent
9c7ff3d211
commit
953c347ba7
|
@ -28,27 +28,29 @@ func (association *Association) Find(value interface{}) *Association {
|
|||
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
field := scope.IndirectValue().FieldByName(association.Column)
|
||||
field := association.Field
|
||||
fieldType := field.Field.Type()
|
||||
|
||||
for _, value := range values {
|
||||
reflectvalue := reflect.ValueOf(value)
|
||||
if reflectvalue.Kind() == reflect.Ptr {
|
||||
if reflectvalue.Elem().Kind() == reflect.Struct {
|
||||
if field.Type().Elem().Kind() == reflect.Ptr {
|
||||
field.Set(reflect.Append(field, reflectvalue))
|
||||
} else if field.Type().Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.Append(field, reflectvalue.Elem()))
|
||||
if fieldType.Elem().Kind() == reflect.Ptr {
|
||||
field.Set(reflect.Append(field.Field, reflectvalue))
|
||||
} else if fieldType.Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.Append(field.Field, reflectvalue.Elem()))
|
||||
}
|
||||
} else if reflectvalue.Elem().Kind() == reflect.Slice {
|
||||
if field.Type().Elem().Kind() == reflect.Ptr {
|
||||
field.Set(reflect.AppendSlice(field, reflectvalue))
|
||||
} else if field.Type().Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.AppendSlice(field, reflectvalue.Elem()))
|
||||
if fieldType.Elem().Kind() == reflect.Ptr {
|
||||
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
|
||||
} else if fieldType.Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.AppendSlice(field.Field, reflectvalue.Elem()))
|
||||
}
|
||||
}
|
||||
} else if reflectvalue.Kind() == reflect.Struct && field.Type().Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.Append(field, reflectvalue))
|
||||
} else if reflectvalue.Kind() == reflect.Slice && field.Type().Elem() == reflectvalue.Type().Elem() {
|
||||
field.Set(reflect.AppendSlice(field, reflectvalue))
|
||||
} else if reflectvalue.Kind() == reflect.Struct && fieldType.Elem().Kind() == reflect.Struct {
|
||||
field.Set(reflect.Append(field.Field, reflectvalue))
|
||||
} else if reflectvalue.Kind() == reflect.Slice && fieldType.Elem() == reflectvalue.Type().Elem() {
|
||||
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
|
||||
} else {
|
||||
association.err(errors.New("invalid association type"))
|
||||
}
|
||||
|
@ -107,7 +109,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||
relationship := association.Field.Relationship
|
||||
scope := association.Scope
|
||||
if relationship.Kind == "many_to_many" {
|
||||
field := scope.IndirectValue().FieldByName(association.Column)
|
||||
field := association.Field.Field
|
||||
|
||||
oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
|
||||
association.Append(values...)
|
||||
|
@ -154,7 +156,7 @@ func (association *Association) Count() int {
|
|||
count := -1
|
||||
relationship := association.Field.Relationship
|
||||
scope := association.Scope
|
||||
field := scope.IndirectValue().FieldByName(association.Column)
|
||||
field := association.Field.Field
|
||||
fieldValue := field.Interface()
|
||||
newScope := scope.New(fieldValue)
|
||||
|
||||
|
|
|
@ -158,10 +158,12 @@ func TestManyToMany(t *testing.T) {
|
|||
languageA := Language{Name: "AA"}
|
||||
DB.Save(&languageA)
|
||||
DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)
|
||||
|
||||
languageC := Language{Name: "CC"}
|
||||
DB.Save(&languageC)
|
||||
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
|
||||
DB.Model(&User{Id: user.Id}).Association("Languages").Append([]Language{{Name: "DD"}, {Name: "EE"}})
|
||||
|
||||
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
|
||||
|
||||
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ func Create(scope *Scope) {
|
|||
for _, field := range scope.Fields() {
|
||||
if field.IsNormal && (!field.IsPrimaryKey || !scope.PrimaryKeyZero()) {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(field.Value))
|
||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,20 +19,18 @@ func SaveBeforeAssociations(scope *Scope) {
|
|||
if !field.IsBlank && !field.IsIgnored {
|
||||
relationship := field.Relationship
|
||||
if relationship != nil && relationship.Kind == "belongs_to" {
|
||||
value := reflect.ValueOf(field.Value)
|
||||
value := field.Field
|
||||
newDB := scope.NewDB()
|
||||
|
||||
if value.CanAddr() {
|
||||
scope.Err(newDB.Save(value.Addr().Interface()).Error)
|
||||
} else {
|
||||
if !value.CanAddr() {
|
||||
// If can't take address, then clone the value and set it back
|
||||
value = reflect.New(reflect.ValueOf(field.Value).Type()).Elem()
|
||||
for _, f := range newDB.NewScope(field.Value).Fields() {
|
||||
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
value = reflect.New(value.Type()).Elem()
|
||||
for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
|
||||
value.FieldByName(f.Name).Set(reflect.ValueOf(f.Field.Interface()))
|
||||
}
|
||||
scope.Err(newDB.Save(value.Addr().Interface()).Error)
|
||||
scope.SetColumn(field.Name, value.Interface())
|
||||
}
|
||||
scope.Err(newDB.Save(value.Addr().Interface()).Error)
|
||||
|
||||
if relationship.ForeignKey != "" {
|
||||
scope.SetColumn(relationship.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue())
|
||||
|
@ -48,7 +46,7 @@ func SaveAfterAssociations(scope *Scope) {
|
|||
relationship := field.Relationship
|
||||
if relationship != nil &&
|
||||
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||
value := reflect.ValueOf(field.Value)
|
||||
value := field.Field
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
|
@ -89,14 +87,14 @@ func SaveAfterAssociations(scope *Scope) {
|
|||
newDB := scope.NewDB()
|
||||
if value.CanAddr() {
|
||||
if relationship.ForeignKey != "" {
|
||||
newDB.NewScope(field.Value).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
|
||||
newDB.NewScope(value.Addr().Interface()).SetColumn(relationship.ForeignKey, scope.PrimaryKeyValue())
|
||||
}
|
||||
scope.Err(newDB.Save(field.Value).Error)
|
||||
scope.Err(newDB.Save(value.Addr().Interface()).Error)
|
||||
} else {
|
||||
destValue := reflect.New(reflect.TypeOf(field.Value)).Elem()
|
||||
destValue := reflect.New(field.Field.Type()).Elem()
|
||||
|
||||
for _, f := range newDB.NewScope(field.Value).Fields() {
|
||||
destValue.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
for _, f := range newDB.NewScope(field.Field.Interface()).Fields() {
|
||||
destValue.FieldByName(f.Name).Set(f.Field)
|
||||
}
|
||||
|
||||
elem := destValue.Addr().Interface()
|
||||
|
|
|
@ -49,7 +49,7 @@ func Update(scope *Scope) {
|
|||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
26
field.go
26
field.go
|
@ -17,7 +17,6 @@ type Field struct {
|
|||
Name string
|
||||
DBName string
|
||||
Field reflect.Value
|
||||
Value interface{}
|
||||
Tag reflect.StructTag
|
||||
Relationship *relationship
|
||||
IsNormal bool
|
||||
|
@ -26,12 +25,29 @@ type Field struct {
|
|||
IsPrimaryKey bool
|
||||
}
|
||||
|
||||
func (f *Field) IsScanner() bool {
|
||||
_, isScanner := reflect.New(reflect.ValueOf(f.Value).Type()).Interface().(sql.Scanner)
|
||||
func (field *Field) IsScanner() bool {
|
||||
_, isScanner := reflect.New(field.Field.Type()).Interface().(sql.Scanner)
|
||||
return isScanner
|
||||
}
|
||||
|
||||
func (f *Field) IsTime() bool {
|
||||
_, isTime := f.Value.(time.Time)
|
||||
func (field *Field) IsTime() bool {
|
||||
_, isTime := field.Field.Interface().(time.Time)
|
||||
return isTime
|
||||
}
|
||||
|
||||
func (field *Field) Set(value interface{}) (result bool) {
|
||||
if field.Field.IsValid() && field.Field.CanAddr() {
|
||||
result = true
|
||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||
scanner.Scan(value)
|
||||
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
|
||||
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
|
||||
} else {
|
||||
result = false
|
||||
}
|
||||
}
|
||||
if result {
|
||||
field.IsBlank = isBlank(field.Field)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
33
scope.go
33
scope.go
|
@ -150,13 +150,17 @@ func (scope *Scope) FieldValueByName(name string) (interface{}, bool) {
|
|||
}
|
||||
|
||||
// SetColumn to set the column's value
|
||||
func (scope *Scope) SetColumn(column string, value interface{}) bool {
|
||||
if scope.Value == nil {
|
||||
return false
|
||||
}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == column || field.DBName == column {
|
||||
return setFieldValue(field.Field, value)
|
||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) bool {
|
||||
if field, ok := column.(*Field); ok {
|
||||
return field.Set(value)
|
||||
} else if str, ok := column.(string); ok {
|
||||
if scope.Value == nil {
|
||||
return false
|
||||
}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == str || field.DBName == str {
|
||||
return field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -267,11 +271,9 @@ func (scope *Scope) CombinedConditionSql() string {
|
|||
}
|
||||
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
if scope.Value != nil {
|
||||
if scope.IndirectValue().Kind() == reflect.Struct {
|
||||
if f, ok := scope.IndirectValue().Type().FieldByName(SnakeToUpperCamel(name)); ok {
|
||||
return scope.fieldFromStruct(f, true)[0], true
|
||||
}
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name {
|
||||
return field, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
|
@ -285,7 +287,6 @@ func (scope *Scope) fieldFromStruct(fieldStruct reflect.StructField, withRelatio
|
|||
value := scope.IndirectValue().FieldByName(fieldStruct.Name)
|
||||
indirectValue := reflect.Indirect(value)
|
||||
field.Field = value
|
||||
field.Value = value.Interface()
|
||||
field.IsBlank = isBlank(value)
|
||||
|
||||
// Search for primary key tag identifier
|
||||
|
@ -416,9 +417,9 @@ func (scope *Scope) Fields(noRelations ...bool) map[string]*Field {
|
|||
}
|
||||
}
|
||||
|
||||
// if withRelation {
|
||||
// scope.fields = fields
|
||||
// }
|
||||
if withRelation {
|
||||
scope.fields = fields
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
|||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
|
@ -103,7 +103,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
|
|||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
|
@ -264,17 +264,17 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
}
|
||||
|
||||
for key, value := range values {
|
||||
if field := data.FieldByName(SnakeToUpperCamel(key)); field.IsValid() {
|
||||
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
|
||||
func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
hasUpdate = true
|
||||
setFieldValue(field, value)
|
||||
field.Set(value)
|
||||
}
|
||||
}()
|
||||
|
||||
if field.Interface() != value {
|
||||
switch field.Kind() {
|
||||
if field.Field.Interface() != value {
|
||||
switch field.Field.Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
if s, ok := value.(string); ok {
|
||||
i, err := strconv.Atoi(s)
|
||||
|
@ -283,13 +283,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
}
|
||||
}
|
||||
|
||||
if field.Int() != reflect.ValueOf(value).Int() {
|
||||
if field.Field.Int() != reflect.ValueOf(value).Int() {
|
||||
hasUpdate = true
|
||||
setFieldValue(field, value)
|
||||
field.Set(value)
|
||||
}
|
||||
default:
|
||||
hasUpdate = true
|
||||
setFieldValue(field, value)
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -324,8 +324,8 @@ func (scope *Scope) sqlTagForField(field *Field) (typ string) {
|
|||
additionalType = additionalType + "DEFAULT " + value
|
||||
}
|
||||
|
||||
value := field.Value
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
value := field.Field.Interface()
|
||||
reflectValue := field.Field
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice:
|
||||
|
|
4
utils.go
4
utils.go
|
@ -68,7 +68,7 @@ func ToSnake(u string) string {
|
|||
}
|
||||
|
||||
s := strings.ToLower(buf.String())
|
||||
go smap.Set(u, s)
|
||||
smap.Set(u, s)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ func SnakeToUpperCamel(s string) string {
|
|||
}
|
||||
|
||||
u := buf.String()
|
||||
go umap.Set(s, u)
|
||||
umap.Set(s, u)
|
||||
return u
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -11,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func fileWithLineNum() string {
|
||||
for i := 1; i < 15; i++ {
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
|
||||
|
@ -20,21 +19,6 @@ func fileWithLineNum() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func setFieldValue(field reflect.Value, value interface{}) (result bool) {
|
||||
result = false
|
||||
if field.IsValid() && field.CanAddr() {
|
||||
result = true
|
||||
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
|
||||
scanner.Scan(value)
|
||||
} else if reflect.TypeOf(value).ConvertibleTo(field.Type()) {
|
||||
field.Set(reflect.ValueOf(value).Convert(field.Type()))
|
||||
} else {
|
||||
result = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
@ -82,7 +66,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
|
|||
scope := Scope{Value: values}
|
||||
for _, field := range scope.Fields() {
|
||||
if !field.IsBlank {
|
||||
attrs[field.DBName] = field.Value
|
||||
attrs[field.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue