Review and Refactor

This commit is contained in:
Jinzhu 2015-02-17 22:55:14 +08:00
parent 38cbff9c79
commit 0b32041135
18 changed files with 264 additions and 381 deletions

View File

@ -9,7 +9,6 @@ import (
type Association struct {
Scope *Scope
PrimaryKey interface{}
PrimaryType interface{}
Column string
Error error
Field *Field
@ -158,11 +157,11 @@ func (association *Association) Count() int {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
if relationship.ForeignType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
countScope.Count(&count)
} else if relationship.Kind == "belongs_to" {
if v, err := scope.FieldValueByName(association.Column); err == nil {
if v, ok := scope.FieldByName(association.Column); ok {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
}

View File

@ -26,13 +26,12 @@ func Create(scope *Scope) {
var sqls, columns []string
for _, field := range scope.Fields() {
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
if field.DefaultValue != nil && field.IsBlank {
continue
}
if !field.IsBlank || field.DefaultValue == nil {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}
}
returningKey := "*"
primaryField := scope.PrimaryKeyField()

View File

@ -41,21 +41,19 @@ func Update(scope *Scope) {
if !scope.HasError() {
var sqls []string
updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
if ok {
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
} else {
for _, field := range scope.Fields() {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
if field.DefaultValue != nil && field.IsBlank {
continue
}
if !field.IsPrimaryKey && field.IsNormal {
if !field.IsBlank || field.DefaultValue == nil {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
}
}
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
@ -68,8 +66,7 @@ func Update(scope *Scope) {
}
func AfterUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethod("AfterUpdate")
scope.CallMethod("AfterSave")
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
)
type commonDialect struct{}
@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
}
return "VARCHAR(65532)"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:

View File

@ -3,11 +3,8 @@ package gorm
import (
"fmt"
"reflect"
"time"
)
var timeType = reflect.TypeOf(time.Time{})
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool

View File

@ -12,7 +12,7 @@ type Field struct {
Field reflect.Value
}
func (field *Field) Set(value interface{}) (err error) {
func (field *Field) Set(value interface{}) error {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) {
}
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
scanner.Scan(v.Interface())
} else {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
}
} else {
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else {
return errors.New("could not convert argument")
}
}
field.IsBlank = isBlank(field.Field)
return
return nil
}
// Fields get value's fields
@ -44,8 +54,14 @@ func (scope *Scope) Fields() map[string]*Field {
fields := map[string]*Field{}
structFields := scope.GetStructFields()
indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField)
if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
}
scope.fields = fields
@ -53,15 +69,12 @@ func (scope *Scope) Fields() map[string]*Field {
return scope.fields
}
func (scope *Scope) getField(structField *StructField) *Field {
field := Field{StructField: structField}
indirectValue := scope.IndirectValue()
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
}
field.IsBlank = isBlank(indirectValue)
return &field
return field
}

View File

@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
func (logger Logger) Print(v ...interface{}) {
if len(v) > 1 {
level := v[0]
func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 {
level := values[0]
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1])
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
messages := []interface{}{source, currentTime}
if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0))
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...))
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, v[2:]...)
messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m")
}
logger.Println(messages...)

82
main.go
View File

@ -30,7 +30,6 @@ type DB struct {
logMode int
logger logger
dialect Dialect
tagIdentifier string
singularTable bool
source string
values map[string]interface{}
@ -39,12 +38,12 @@ type DB struct {
func Open(dialect string, args ...interface{}) (DB, error) {
var db DB
var err error
var source string
var dbSql sqlCommon
if len(args) == 0 {
err = errors.New("invalid database source")
}
} else {
var source string
var dbSql sqlCommon
switch value := args[0].(type) {
case string:
@ -61,11 +60,16 @@ func Open(dialect string, args ...interface{}) (DB, error) {
dbSql = value
}
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql",
logger: defaultLogger, callback: DefaultCallback, source: source,
values: map[string]interface{}{}}
db.db = dbSql
db = DB{
dialect: NewDialect(dialect),
logger: defaultLogger,
callback: DefaultCallback,
source: source,
values: map[string]interface{}{},
db: dbSql,
}
db.parent = &db
}
return db, err
}
@ -84,7 +88,7 @@ func (s *DB) New() *DB {
return clone
}
// Return the underlying sql.DB or sql.Tx instance.
// CommonDB Return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
@ -96,16 +100,12 @@ func (s *DB) Callback() *callback {
return s.parent.callback
}
func (s *DB) SetTagIdentifier(str string) {
s.parent.tagIdentifier = str
}
func (s *DB) SetLogger(l logger) {
s.parent.logger = l
}
func (s *DB) LogMode(b bool) *DB {
if b {
func (s *DB) LogMode(enable bool) *DB {
if enable {
s.logMode = 2
} else {
s.logMode = 1
@ -113,8 +113,8 @@ func (s *DB) LogMode(b bool) *DB {
return s
}
func (s *DB) SingularTable(b bool) {
s.parent.singularTable = b
func (s *DB) SingularTable(enable bool) {
s.parent.singularTable = enable
}
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
@ -158,11 +158,10 @@ func (s *DB) Joins(query string) *DB {
}
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
c := s
for _, f := range funcs {
c = f(c)
s = f(s)
}
return c
return s
}
func (s *DB) Unscoped() *DB {
@ -179,16 +178,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone()
newScope.Search.limit(1)
newScope.Search = newScope.Search.clone().limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone()
newScope.Search.limit(1)
newScope.Search = newScope.Search.clone().limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
@ -213,10 +210,9 @@ func (s *DB) Scan(dest interface{}) *DB {
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone()
r := c.First(out, where...)
if r.Error != nil {
if !r.RecordNotFound() {
return r
if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
@ -227,10 +223,9 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone()
r := c.First(out, where...)
if r.Error != nil {
if !r.RecordNotFound() {
return r
if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.AssignAttrs) > 0 {
@ -418,25 +413,24 @@ func (s *DB) RemoveIndex(indexName string) *DB {
}
func (s *DB) Association(column string) *Association {
var err error
scope := s.clone().NewScope(s.Value)
primaryKey := scope.PrimaryKeyValue()
primaryType := scope.TableName()
if reflect.DeepEqual(reflect.ValueOf(primaryKey), reflect.Zero(reflect.ValueOf(primaryKey).Type())) {
scope.Err(errors.New("primary key can't be nil"))
}
var field *Field
var ok bool
if field, ok = scope.FieldByName(column); ok {
if primaryField := scope.PrimaryKeyField(); primaryField.IsBlank {
err = errors.New("primary key can't be nil")
} else {
if field, ok := scope.FieldByName(column); ok {
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
scope.Err(fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()))
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
}
} else {
scope.Err(fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column))
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
}
}
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, PrimaryType: primaryType, Field: field}
return &Association{Error: err}
}
func (s *DB) Preload(column string, conditions ...interface{}) *DB {

View File

@ -133,11 +133,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
// Set fields
for i := 0; i < scopeType.NumField(); i++ {
fieldStruct := scopeType.Field(i)
if !ast.IsExported(fieldStruct.Name) {
continue
}
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
field := &StructField{
Struct: fieldStruct,
Name: fieldStruct.Name,
@ -265,6 +261,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
}
for _, field := range modelStruct.StructFields {
if field.IsNormal {

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
)
type mssql struct{}
@ -36,7 +37,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
}
return "text"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "datetime2"
}
default:

View File

@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
)
type mysql struct{}
@ -36,7 +37,7 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
}
return "longtext"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"time"
"github.com/lib/pq/hstore"
)
@ -40,7 +41,7 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
}
return "text"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "timestamp with time zone"
}
case reflect.Map:

View File

@ -7,7 +7,7 @@ import (
"reflect"
)
func getFieldValue(value reflect.Value, field string) interface{} {
func getRealValue(value reflect.Value, field string) interface{} {
result := reflect.Indirect(value).FieldByName(field).Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
@ -20,26 +20,14 @@ func equalAsString(a interface{}, b interface{}) bool {
}
func Preload(scope *Scope) {
// Get Fields
var fields map[string]*Field
var isSlice bool
if scope.IndirectValue().Kind() == reflect.Slice {
isSlice = true
typ := scope.IndirectValue().Type().Elem()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
elem := reflect.New(typ).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
} else {
fields = scope.Fields()
}
fields := scope.Fields()
isSlice := scope.IndirectValue().Kind() == reflect.Slice
if scope.Search.Preload != nil {
for key, conditions := range scope.Search.Preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Field)
results := makeSlice(field.Struct.Type)
relation := field.Relationship
primaryName := scope.PrimaryKeyField().Name
associationPrimaryKey := scope.New(results).PrimaryKeyField().Name
@ -53,10 +41,10 @@ func Preload(scope *Scope) {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getFieldValue(result, relation.ForeignFieldName)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getFieldValue(objects.Index(j), primaryName), value) {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
@ -72,11 +60,11 @@ func Preload(scope *Scope) {
if isSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getFieldValue(result, relation.ForeignFieldName)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, primaryName), value) {
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
@ -92,11 +80,11 @@ func Preload(scope *Scope) {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getFieldValue(result, associationPrimaryKey)
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getFieldValue(object, relation.ForeignFieldName), value) {
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
@ -116,9 +104,8 @@ func Preload(scope *Scope) {
}
}
func makeSlice(value reflect.Value) interface{} {
typ := value.Type()
if value.Kind() == reflect.Slice {
func makeSlice(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
sliceType := reflect.SliceOf(typ)
@ -132,8 +119,7 @@ func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{})
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
value := values.Index(i)
primaryKeys = append(primaryKeys, reflect.Indirect(value).FieldByName(column).Interface())
primaryKeys = append(primaryKeys, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()}

View File

@ -111,7 +111,8 @@ func (scope *Scope) PrimaryKey() string {
// PrimaryKeyZero check the primary key is blank or not
func (scope *Scope) PrimaryKeyZero() bool {
return isBlank(reflect.ValueOf(scope.PrimaryKeyValue()))
field := scope.PrimaryKeyField()
return field == nil || field.IsBlank
}
// PrimaryKeyValue get the primary key's value
@ -125,35 +126,23 @@ func (scope *Scope) PrimaryKeyValue() interface{} {
// HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool {
for _, field := range scope.GetStructFields() {
if !field.IsIgnored {
if field.Name == column || field.DBName == column {
if field.IsNormal && (field.Name == column || field.DBName == column) {
return true
}
}
}
return false
}
// FieldValueByName to get column's value and existence
func (scope *Scope) FieldValueByName(name string) (interface{}, error) {
return FieldValueByName(name, scope.Value)
}
// SetColumn to set the column's value
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
if field, ok := column.(*Field); ok {
return field.Set(value)
} else if dbName, ok := column.(string); ok {
if scope.Value == nil {
return errors.New("scope value must not be nil for string columns")
}
if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value)
}
dbName = ToSnake(dbName)
if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value)
}
@ -204,45 +193,11 @@ func (scope *Scope) AddToVars(value interface{}) string {
}
// TableName get table name
func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName
}
if scope.Value == nil {
scope.Err(errors.New("can't get table name"))
return ""
}
data := scope.IndirectValue()
if data.Kind() == reflect.Slice {
elem := data.Type().Elem()
if elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}
data = reflect.New(elem).Elem()
}
if fm := data.MethodByName("TableName"); fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if result, ok := v[0].Interface().(string); ok {
return result
}
}
}
str := ToSnake(data.Type().Name())
if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys {
if reg.MatchString(str) {
return reg.ReplaceAllString(str, pluralMapValues[index])
}
}
}
return str
return scope.GetModelStruct().TableName
}
func (scope *Scope) QuotedTableName() string {
@ -284,8 +239,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())
if !scope.HasError() {
result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...)
if scope.Err(err) == nil {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil {
scope.db.RowsAffected = count
}

View File

@ -25,10 +25,8 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
} else if value != "" {
str = fmt.Sprintf("(%v)", value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value.Int64))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
@ -71,12 +69,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSql string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
case string:
// is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), id)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value)
@ -84,15 +84,14 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(scope.PrimaryKey()), value)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v not in (?))", scope.Quote(scope.PrimaryKey()))
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
} else {
return ""
}
return ""
case map[string]interface{}:
var sqls []string
for key, value := range value {
@ -157,16 +156,10 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
return
}
func (scope *Scope) where(where ...interface{}) {
if len(where) > 0 {
scope.Search = scope.Search.clone().where(where[0], where[1:]...)
}
}
func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
primaryConditions = append(primaryConditions, sql)
}
@ -317,41 +310,19 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
}
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
data := scope.IndirectValue()
if !data.CanAddr() {
if !scope.IndirectValue().CanAddr() {
return values, true
}
fields := scope.Fields()
for key, value := range values {
if field, ok := scope.FieldByName(SnakeToUpperCamel(key)); ok && field.Field.IsValid() {
func() {
defer func() {
if err := recover(); err != nil {
hasUpdate = true
field.Set(value)
}
}()
if field.Field.Interface() != value {
switch field.Field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
if s, ok := value.(string); ok {
i, err := strconv.Atoi(s)
if scope.Err(err) == nil {
value = i
}
}
if field.Field.Int() != reflect.ValueOf(value).Int() {
hasUpdate = true
field.Set(value)
}
default:
if field, ok := fields[ToSnake(key)]; ok && field.Field.IsValid() {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if !equalAsString(field.Field.Interface(), value) {
hasUpdate = true
field.Set(value)
}
}
}()
}
}
return
@ -464,7 +435,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
if field.Relationship != nil && field.Relationship.JoinTable != "" {
if !scope.Dialect().HasTable(scope, field.Relationship.JoinTable) {
newScope := scope.db.NewScope("")
primaryKeySqlType := scope.Dialect().SqlTag(reflect.ValueOf(scope.PrimaryKeyValue()), 255)
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryKeyField().Field, 255)
newScope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)",
field.Relationship.JoinTable,
strings.Join([]string{
@ -523,14 +494,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = `
ALTER TABLE %s
ADD CONSTRAINT %s
FOREIGN KEY (%s)
REFERENCES %s
ON DELETE %s
ON UPDATE %s;
`
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, table, keyName, field, dest, onDelete, onUpdate)).Exec()
}

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct{}
@ -35,7 +36,7 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
}
return "text"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:

View File

@ -2,9 +2,6 @@ package gorm
import (
"bytes"
"errors"
"fmt"
"reflect"
"strings"
"sync"
)
@ -26,23 +23,6 @@ func (s *safeMap) Get(key string) string {
return s.m[key]
}
func FieldValueByName(name string, value interface{}) (i interface{}, err error) {
data := reflect.Indirect(reflect.ValueOf(value))
name = SnakeToUpperCamel(name)
if data.Kind() == reflect.Struct {
if field := data.FieldByName(name); field.IsValid() {
i = field.Interface()
} else {
return nil, fmt.Errorf("struct has no field with name %s", name)
}
} else {
return nil, errors.New("value must be of kind struct")
}
return
}
func newSafeMap() *safeMap {
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
}

View File

@ -2,18 +2,16 @@ package gorm
import (
"fmt"
"os"
"reflect"
"regexp"
"runtime"
"strings"
)
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""