Review and Refactor

This commit is contained in:
Jinzhu 2015-02-17 22:55:14 +08:00
parent 38cbff9c79
commit 0b32041135
18 changed files with 264 additions and 381 deletions

View File

@ -9,7 +9,6 @@ import (
type Association struct { type Association struct {
Scope *Scope Scope *Scope
PrimaryKey interface{} PrimaryKey interface{}
PrimaryType interface{}
Column string Column string
Error error Error error
Field *Field Field *Field
@ -158,11 +157,11 @@ func (association *Association) Count() int {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey) countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
if relationship.ForeignType != "" { if relationship.ForeignType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType) countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
} }
countScope.Count(&count) countScope.Count(&count)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
if v, err := scope.FieldValueByName(association.Column); err == nil { if v, ok := scope.FieldByName(association.Column); ok {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count) scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
} }

View File

@ -26,13 +26,12 @@ func Create(scope *Scope) {
var sqls, columns []string var sqls, columns []string
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) { if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
if field.DefaultValue != nil && field.IsBlank { if !field.IsBlank || field.DefaultValue == nil {
continue
}
columns = append(columns, scope.Quote(field.DBName)) columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface())) sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
} }
} }
}
returningKey := "*" returningKey := "*"
primaryField := scope.PrimaryKeyField() primaryField := scope.PrimaryKeyField()

View File

@ -41,21 +41,19 @@ func Update(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
var sqls []string var sqls []string
updateAttrs, ok := scope.InstanceGet("gorm:update_attrs") if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
if ok {
for key, value := range updateAttrs.(map[string]interface{}) { for key, value := range updateAttrs.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
} }
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored { if !field.IsPrimaryKey && field.IsNormal {
if field.DefaultValue != nil && field.IsBlank { if !field.IsBlank || field.DefaultValue == nil {
continue
}
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} }
} }
}
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v", "UPDATE %v SET %v %v",
@ -68,8 +66,7 @@ func Update(scope *Scope) {
} }
func AfterUpdate(scope *Scope) { func AfterUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column") if _, ok := scope.Get("gorm:update_column"); !ok {
if !ok {
scope.CallMethod("AfterUpdate") scope.CallMethod("AfterUpdate")
scope.CallMethod("AfterSave") scope.CallMethod("AfterSave")
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
) )
type commonDialect struct{} type commonDialect struct{}
@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
} }
return "VARCHAR(65532)" return "VARCHAR(65532)"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP" return "TIMESTAMP"
} }
default: default:

View File

@ -3,11 +3,8 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"time"
) )
var timeType = reflect.TypeOf(time.Time{})
type Dialect interface { type Dialect interface {
BinVar(i int) string BinVar(i int) string
SupportLastInsertId() bool SupportLastInsertId() bool

View File

@ -12,7 +12,7 @@ type Field struct {
Field reflect.Value Field reflect.Value
} }
func (field *Field) Set(value interface{}) (err error) { func (field *Field) Set(value interface{}) error {
if !field.Field.IsValid() { if !field.Field.IsValid() {
return errors.New("field value not valid") return errors.New("field value not valid")
} }
@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) {
} }
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
scanner.Scan(v.Interface())
} else {
scanner.Scan(value) scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) { }
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type())) } else {
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else { } else {
return errors.New("could not convert argument") return errors.New("could not convert argument")
} }
}
field.IsBlank = isBlank(field.Field) field.IsBlank = isBlank(field.Field)
return nil
return
} }
// Fields get value's fields // Fields get value's fields
@ -44,8 +54,14 @@ func (scope *Scope) Fields() map[string]*Field {
fields := map[string]*Field{} fields := map[string]*Field{}
structFields := scope.GetStructFields() structFields := scope.GetStructFields()
indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range structFields { for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField) if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
} }
scope.fields = fields scope.fields = fields
@ -53,15 +69,12 @@ func (scope *Scope) Fields() map[string]*Field {
return scope.fields return scope.fields
} }
func (scope *Scope) getField(structField *StructField) *Field { func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := Field{StructField: structField} field := &Field{StructField: structField}
indirectValue := scope.IndirectValue()
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
for _, name := range structField.Names { for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name) indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
} }
field.Field = indirectValue field.Field = indirectValue
}
field.IsBlank = isBlank(indirectValue) field.IsBlank = isBlank(indirectValue)
return &field return field
} }

View File

@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
// Format log // Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
func (logger Logger) Print(v ...interface{}) { func (logger Logger) Print(values ...interface{}) {
if len(v) > 1 { if len(values) > 1 {
level := v[0] level := values[0]
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1]) source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
messages := []interface{}{source, currentTime} messages := []interface{}{source, currentTime}
if level == "sql" { if level == "sql" {
// duration // duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0)) messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql // sql
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...)) messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...))
} else { } else {
messages = append(messages, "\033[31;1m") messages = append(messages, "\033[31;1m")
messages = append(messages, v[2:]...) messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m") messages = append(messages, "\033[0m")
} }
logger.Println(messages...) logger.Println(messages...)

82
main.go
View File

@ -30,7 +30,6 @@ type DB struct {
logMode int logMode int
logger logger logger logger
dialect Dialect dialect Dialect
tagIdentifier string
singularTable bool singularTable bool
source string source string
values map[string]interface{} values map[string]interface{}
@ -39,12 +38,12 @@ type DB struct {
func Open(dialect string, args ...interface{}) (DB, error) { func Open(dialect string, args ...interface{}) (DB, error) {
var db DB var db DB
var err error var err error
var source string
var dbSql sqlCommon
if len(args) == 0 { if len(args) == 0 {
err = errors.New("invalid database source") err = errors.New("invalid database source")
} } else {
var source string
var dbSql sqlCommon
switch value := args[0].(type) { switch value := args[0].(type) {
case string: case string:
@ -61,11 +60,16 @@ func Open(dialect string, args ...interface{}) (DB, error) {
dbSql = value dbSql = value
} }
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", db = DB{
logger: defaultLogger, callback: DefaultCallback, source: source, dialect: NewDialect(dialect),
values: map[string]interface{}{}} logger: defaultLogger,
db.db = dbSql callback: DefaultCallback,
source: source,
values: map[string]interface{}{},
db: dbSql,
}
db.parent = &db db.parent = &db
}
return db, err return db, err
} }
@ -84,7 +88,7 @@ func (s *DB) New() *DB {
return clone return clone
} }
// Return the underlying sql.DB or sql.Tx instance. // CommonDB Return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow // Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code. // coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon { func (s *DB) CommonDB() sqlCommon {
@ -96,16 +100,12 @@ func (s *DB) Callback() *callback {
return s.parent.callback return s.parent.callback
} }
func (s *DB) SetTagIdentifier(str string) {
s.parent.tagIdentifier = str
}
func (s *DB) SetLogger(l logger) { func (s *DB) SetLogger(l logger) {
s.parent.logger = l s.parent.logger = l
} }
func (s *DB) LogMode(b bool) *DB { func (s *DB) LogMode(enable bool) *DB {
if b { if enable {
s.logMode = 2 s.logMode = 2
} else { } else {
s.logMode = 1 s.logMode = 1
@ -113,8 +113,8 @@ func (s *DB) LogMode(b bool) *DB {
return s return s
} }
func (s *DB) SingularTable(b bool) { func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = b s.parent.singularTable = enable
} }
func (s *DB) Where(query interface{}, args ...interface{}) *DB { func (s *DB) Where(query interface{}, args ...interface{}) *DB {
@ -158,11 +158,10 @@ func (s *DB) Joins(query string) *DB {
} }
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
c := s
for _, f := range funcs { for _, f := range funcs {
c = f(c) s = f(s)
} }
return c return s
} }
func (s *DB) Unscoped() *DB { func (s *DB) Unscoped() *DB {
@ -179,16 +178,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone() newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC"). return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
} }
func (s *DB) Last(out interface{}, where ...interface{}) *DB { func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out) newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone() newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC"). return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
} }
@ -213,10 +210,9 @@ func (s *DB) Scan(dest interface{}) *DB {
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
r := c.First(out, where...) if result := c.First(out, where...); result.Error != nil {
if r.Error != nil { if !result.RecordNotFound() {
if !r.RecordNotFound() { return result
return r
} }
c.NewScope(out).inlineCondition(where...).initialize() c.NewScope(out).inlineCondition(where...).initialize()
} else { } else {
@ -227,10 +223,9 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone() c := s.clone()
r := c.First(out, where...) if result := c.First(out, where...); result.Error != nil {
if r.Error != nil { if !result.RecordNotFound() {
if !r.RecordNotFound() { return result
return r
} }
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.AssignAttrs) > 0 { } else if len(c.search.AssignAttrs) > 0 {
@ -418,25 +413,24 @@ func (s *DB) RemoveIndex(indexName string) *DB {
} }
func (s *DB) Association(column string) *Association { func (s *DB) Association(column string) *Association {
var err error
scope := s.clone().NewScope(s.Value) scope := s.clone().NewScope(s.Value)
primaryKey := scope.PrimaryKeyValue() if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
primaryType := scope.TableName() err = errors.New("primary key can't be nil")
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) { } else {
scope.Err(errors.New("primary key can't be nil")) if field, ok := scope.FieldByName(column); ok {
}
var field *Field
var ok bool
if field, ok = scope.FieldByName(column); ok {
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())) err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
} }
} else { } else {
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)) err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
}
} }
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field} return &Association{Error: err}
} }
func (s *DB) Preload(column string, conditions ...interface{}) *DB { func (s *DB) Preload(column string, conditions ...interface{}) *DB {

View File

@ -133,11 +133,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// Set fields // Set fields
for i := 0; i < scopeType.NumField(); i++ { for i := 0; i < scopeType.NumField(); i++ {
fieldStruct := scopeType.Field(i) if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
if !ast.IsExported(fieldStruct.Name) {
continue
}
field := &StructField{ field := &StructField{
Struct: fieldStruct, Struct: fieldStruct,
Name: fieldStruct.Name, Name: fieldStruct.Name,
@ -265,6 +261,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
modelStruct.StructFields = append(modelStruct.StructFields, field) modelStruct.StructFields = append(modelStruct.StructFields, field)
} }
}
for _, field := range modelStruct.StructFields { for _, field := range modelStruct.StructFields {
if field.IsNormal { if field.IsNormal {

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
) )
type mssql struct{} type mssql struct{}
@ -36,7 +37,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if _, ok := value.Interface().(time.Time); ok {
return "datetime2" return "datetime2"
} }
default: default:

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
) )
type mysql struct{} type mysql struct{}
@ -36,7 +37,7 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
} }
return "longtext" return "longtext"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if _, ok := value.Interface().(time.Time); ok {
return "timestamp NULL" return "timestamp NULL"
} }
default: default:

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"time"
"github.com/lib/pq/hstore" "github.com/lib/pq/hstore"
) )
@ -40,7 +41,7 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if _, ok := value.Interface().(time.Time); ok {
return "timestamp with time zone" return "timestamp with time zone"
} }
case reflect.Map: case reflect.Map:

View File

@ -7,7 +7,7 @@ import (
"reflect" "reflect"
) )
func getFieldValue(value reflect.Value, field string) interface{} { func getRealValue(value reflect.Value, field string) interface{} {
result := reflect.Indirect(value).FieldByName(field).Interface() result := reflect.Indirect(value).FieldByName(field).Interface()
if r, ok := result.(driver.Valuer); ok { if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value() result, _ = r.Value()
@ -20,26 +20,14 @@ func equalAsString(a interface{}, b interface{}) bool {
} }
func Preload(scope *Scope) { func Preload(scope *Scope) {
// Get Fields fields := scope.Fields()
var fields map[string]*Field isSlice := scope.IndirectValue().Kind() == reflect.Slice
var isSlice bool
if scope.IndirectValue().Kind() == reflect.Slice {
isSlice = true
typ := scope.IndirectValue().Type().Elem()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
elem := reflect.New(typ).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
} else {
fields = scope.Fields()
}
if scope.Search.Preload != nil { if scope.Search.Preload != nil {
for key, conditions := range scope.Search.Preload { for key, conditions := range scope.Search.Preload {
for _, field := range fields { for _, field := range fields {
if field.Name == key && field.Relationship != nil { if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Field) results := makeSlice(field.Struct.Type)
relation := field.Relationship relation := field.Relationship
primaryName := scope.PrimaryKeyField().Name primaryName := scope.PrimaryKeyField().Name
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
@ -53,10 +41,10 @@ func Preload(scope *Scope) {
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if isSlice { if isSlice {
value := getFieldValue(result, relation.ForeignFieldName) value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) { if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break break
} }
@ -72,11 +60,11 @@ func Preload(scope *Scope) {
if isSlice { if isSlice {
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
value := getFieldValue(result, relation.ForeignFieldName) value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, primaryName), value) { if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name) f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result)) f.Set(reflect.Append(f, result))
break break
@ -92,11 +80,11 @@ func Preload(scope *Scope) {
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if isSlice { if isSlice {
value := getFieldValue(result, associationPrimaryKey) value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) { if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
} }
@ -116,9 +104,8 @@ func Preload(scope *Scope) {
} }
} }
func makeSlice(value reflect.Value) interface{} { func makeSlice(typ reflect.Type) interface{} {
typ := value.Type() if typ.Kind() == reflect.Slice {
if value.Kind() == reflect.Slice {
typ = typ.Elem() typ = typ.Elem()
} }
sliceType := reflect.SliceOf(typ) sliceType := reflect.SliceOf(typ)
@ -132,8 +119,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
switch values.Kind() { switch values.Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
value := values.Index(i) primaryKeys = append(primaryKeys, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
} }
case reflect.Struct: case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()} return []interface{}{values.FieldByName(column).Interface()}

View File

@ -111,7 +111,8 @@ func (scope *Scope) PrimaryKey() string {
// PrimaryKeyZero check the primary key is blank or not // PrimaryKeyZero check the primary key is blank or not
func (scope *Scope) PrimaryKeyZero() bool { func (scope *Scope) PrimaryKeyZero() bool {
return isBlank(reflect.ValueOf(scope.PrimaryKeyValue())) field := scope.PrimaryKeyField()
return field == nil || field.IsBlank
} }
// PrimaryKeyValue get the primary key's value // PrimaryKeyValue get the primary key's value
@ -125,35 +126,23 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
// HasColumn to check if has column // HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool { func (scope *Scope) HasColumn(column string) bool {
for _, field := range scope.GetStructFields() { for _, field := range scope.GetStructFields() {
if !field.IsIgnored { if field.IsNormal && (field.Name == column || field.DBName == column) {
if field.Name == column || field.DBName == column {
return true return true
} }
} }
}
return false return false
} }
// FieldValueByName to get column's value and existence
func (scope *Scope) FieldValueByName(name string) (interface{}, error) {
return FieldValueByName(name, scope.Value)
}
// SetColumn to set the column's value // SetColumn to set the column's value
func (scope *Scope) SetColumn(column interface{}, value interface{}) error { func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
if field, ok := column.(*Field); ok { if field, ok := column.(*Field); ok {
return field.Set(value) return field.Set(value)
} else if dbName, ok := column.(string); ok { } else if dbName, ok := column.(string); ok {
if scope.Value == nil {
return errors.New("scope value must not be nil for string columns")
}
if field, ok := scope.Fields()[dbName]; ok { if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value) return field.Set(value)
} }
dbName = ToSnake(dbName) dbName = ToSnake(dbName)
if field, ok := scope.Fields()[dbName]; ok { if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value) return field.Set(value)
} }
@ -204,45 +193,11 @@ func (scope *Scope) AddToVars(value interface{}) string {
} }
// TableName get table name // TableName get table name
func (scope *Scope) TableName() string { func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 { if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName return scope.Search.TableName
} }
return scope.GetModelStruct().TableName
if scope.Value == nil {
scope.Err(errors.New("can't get table name"))
return ""
}
data := scope.IndirectValue()
if data.Kind() == reflect.Slice {
elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}
data = reflect.New(elem).Elem()
}
if fm := data.MethodByName("TableName"); fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if result, ok := v[0].Interface().(string); ok {
return result
}
}
}
str := ToSnake(data.Type().Name())
if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys {
if reg.MatchString(str) {
return reg.ReplaceAllString(str, pluralMapValues[index])
}
}
}
return str
} }
func (scope *Scope) QuotedTableName() string { func (scope *Scope) QuotedTableName() string {
@ -284,8 +239,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc()) defer scope.Trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...) if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil { if count, err := result.RowsAffected(); err == nil {
scope.db.RowsAffected = count scope.db.RowsAffected = count
} }

View File

@ -25,10 +25,8 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
} else if value != "" { } else if value != "" {
str = fmt.Sprintf("(%v)", value) str = fmt.Sprintf("(%v)", value)
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value)) return scope.primaryCondition(scope.AddToVars(value))
case sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value.Int64))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey())) str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value} clause["args"] = []interface{}{value}
@ -71,12 +69,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSql string var notEqualSql string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) { switch value := clause["query"].(type) {
case string: case string:
// is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value) str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value) notEqualSql = fmt.Sprintf("NOT (%v)", value)
@ -84,15 +84,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value) return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 { if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey())) str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
clause["args"] = []interface{}{value} clause["args"] = []interface{}{value}
} else {
return ""
} }
return ""
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range value { for key, value := range value {
@ -157,16 +156,10 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
return return
} }
func (scope *Scope) where(where ...interface{}) {
if len(where) > 0 {
scope.Search = scope.Search.clone().where(where[0], where[1:]...)
}
}
func (scope *Scope) whereSql() (sql string) { func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") { if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
primaryConditions = append(primaryConditions, sql) primaryConditions = append(primaryConditions, sql)
} }
@ -317,41 +310,19 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
} }
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
data := scope.IndirectValue() if !scope.IndirectValue().CanAddr() {
if !data.CanAddr() {
return values, true return values, true
} }
fields := scope.Fields()
for key, value := range values { for key, value := range values {
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() { if field, ok := fields[ToSnake(key)]; ok && field.Field.IsValid() {
func() { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
defer func() { if !equalAsString(field.Field.Interface(), value) {
if err := recover(); err != nil {
hasUpdate = true
field.Set(value)
}
}()
if field.Field.Interface() != value {
switch field.Field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if s, ok := value.(string); ok {
i, err := strconv.Atoi(s)
if scope.Err(err) == nil {
value = i
}
}
if field.Field.Int() != reflect.ValueOf(value).Int() {
hasUpdate = true
field.Set(value)
}
default:
hasUpdate = true hasUpdate = true
field.Set(value) field.Set(value)
} }
} }
}()
} }
} }
return return
@ -464,7 +435,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
if field.Relationship != nil && field.Relationship.JoinTable != "" { if field.Relationship != nil && field.Relationship.JoinTable != "" {
if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) { if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) {
newScope := scope.db.NewScope("") newScope := scope.db.NewScope("")
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255) primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
field.Relationship.JoinTable, field.Relationship.JoinTable,
strings.Join([]string{ strings.Join([]string{
@ -523,14 +494,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table = scope.TableName() var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = ` var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
ALTER TABLE %s
ADD CONSTRAINT %s
FOREIGN KEY (%s)
REFERENCES %s
ON DELETE %s
ON UPDATE %s;
`
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
} }

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"time"
) )
type sqlite3 struct{} type sqlite3 struct{}
@ -35,7 +36,7 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if _, ok := value.Interface().(time.Time); ok {
return "datetime" return "datetime"
} }
default: default:

View File

@ -2,9 +2,6 @@ package gorm
import ( import (
"bytes" "bytes"
"errors"
"fmt"
"reflect"
"strings" "strings"
"sync" "sync"
) )
@ -26,23 +23,6 @@ func (s *safeMap) Get(key string) string {
return s.m[key] return s.m[key]
} }
func FieldValueByName(name string, value interface{}) (i interface{}, err error) {
data := reflect.Indirect(reflect.ValueOf(value))
name = SnakeToUpperCamel(name)
if data.Kind() == reflect.Struct {
if field := data.FieldByName(name); field.IsValid() {
i = field.Interface()
} else {
return nil, fmt.Errorf("struct has no field with name %s", name)
}
} else {
return nil, errors.New("value must be of kind struct")
}
return
}
func newSafeMap() *safeMap { func newSafeMap() *safeMap {
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
} }

View File

@ -2,18 +2,16 @@ package gorm
import ( import (
"fmt" "fmt"
"os"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
"strings"
) )
func fileWithLineNum() string { func fileWithLineNum() string {
for i := 2; i < 15; i++ { for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i) _, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) return fmt.Sprintf("%v:%v", file, line)
} }
} }
return "" return ""