diff --git a/dialect/dialect.go b/dialect/dialect.go index 8aa3ae31..5a829735 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -1,7 +1,7 @@ package dialect type Dialect interface { - BinVar(i int) string + BinVar() string SupportLastInsertId() bool SqlTag(column interface{}, size int) string PrimaryKeyTag(column interface{}, size int) string diff --git a/dialect/mysql.go b/dialect/mysql.go index 0e71612b..f9e05170 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -8,7 +8,7 @@ import ( type mysql struct{} -func (s *mysql) BinVar(i int) string { +func (s *mysql) BinVar() string { return "?" } diff --git a/dialect/postgres.go b/dialect/postgres.go index 13a0afa0..833a4a09 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -9,8 +9,8 @@ import ( type postgres struct { } -func (s *postgres) BinVar(i int) string { - return fmt.Sprintf("$%v", i) +func (s *postgres) BinVar() string { + return "$%v" } func (s *postgres) SupportLastInsertId() bool { diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index 3dec02ef..84e37c48 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -8,7 +8,7 @@ import ( type sqlite3 struct{} -func (s *sqlite3) BinVar(i int) string { +func (s *sqlite3) BinVar() string { return "?" } diff --git a/do.go b/do.go index 81f0e87f..d70d2c5d 100644 --- a/do.go +++ b/do.go @@ -60,7 +60,7 @@ func (s *Do) setModel(value interface{}) *Do { func (s *Do) addToVars(value interface{}) string { s.sqlVars = append(s.sqlVars, value) - return s.chain.d.dialect.BinVar(len(s.sqlVars)) + return fmt.Sprintf(s.chain.d.dialect.BinVar(), len(s.sqlVars)) } func (s *Do) exec(sqls ...string) (err error) { @@ -209,7 +209,7 @@ func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) m := &Model{data: values, do: s} s.updateAttrs = map[string]interface{}{} for _, field := range m.columnsHasValue("other") { - s.updateAttrs[field.DbName] = field.Value + s.updateAttrs[field.dbName] = field.Value } } @@ -473,7 +473,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -532,7 +532,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -641,7 +641,7 @@ func (s *Do) createTable() *Do { var sqls []string for _, field := range s.model.fields("migration") { if len(field.sqlTag()) > 0 { - sqls = append(sqls, field.DbName+" "+field.sqlTag()) + sqls = append(sqls, field.dbName+" "+field.sqlTag()) } } @@ -697,12 +697,12 @@ func (s *Do) autoMigrate() *Do { for _, field := range s.model.fields("migration") { var column_name, data_type string sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName())) - s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type) + s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.dbName)), s.sqlVars...).Scan(&column_name, &data_type) s.sqlVars = []interface{}{} // If column doesn't exist if len(column_name) == 0 && len(field.sqlTag()) > 0 { - s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag()) + s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.dbName, field.sqlTag()) s.exec() } } @@ -745,7 +745,7 @@ func (s *Do) initializeWithSearchCondition() { case reflect.Struct: m := &Model{data: obj, do: s} for _, field := range m.columnsHasValue("other") { - s.model.setValueByColumn(field.DbName, field.Value, s.value) + s.model.setValueByColumn(field.dbName, field.Value, s.value) } case reflect.Map: for key, value := range obj.(map[string]interface{}) { @@ -756,7 +756,7 @@ func (s *Do) initializeWithSearchCondition() { case interface{}: m := &Model{data: value, do: s} for _, field := range m.columnsHasValue("other") { - s.model.setValueByColumn(field.DbName, field.Value, s.value) + s.model.setValueByColumn(field.dbName, field.Value, s.value) } } } diff --git a/field.go b/field.go index 1c5d399c..c30ef665 100644 --- a/field.go +++ b/field.go @@ -2,85 +2,68 @@ package gorm import ( "database/sql" - "database/sql/driver" - - "time" - + "reflect" "strconv" "strings" - - "reflect" + "time" ) type Field struct { Name string Value interface{} - DbName string - AutoCreateTime bool - AutoUpdateTime bool - IsPrimaryKey bool - structField reflect.StructField - modelValue reflect.Value + model *Model + dbName string + isPrimaryKey bool + autoCreateTime bool + autoUpdateTime bool + foreignKey string beforeAssociation bool afterAssociation bool - foreignKey string - model *Model + reflectValue reflect.Value + structField reflect.StructField } func (f *Field) isBlank() bool { - value := reflect.ValueOf(f.Value) - switch value.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - return value.Int() == 0 - case reflect.String: - return value.String() == "" - case reflect.Slice: - return value.Len() == 0 - case reflect.Struct: - time_value, is_time := f.Value.(time.Time) - if is_time { - return time_value.IsZero() - } else { - _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner) - if is_scanner { - return !value.FieldByName("Valid").Interface().(bool) - } else { - m := &Model{data: value.Interface(), do: f.model.do} - fields := m.columnsHasValue("other") - if len(fields) == 0 { - return true - } - } - } - } - return false + return isBlank(f.reflectValue) } -func (f *Field) sqlTag() string { - column := getInterfaceValue(f.Value) - field_value := reflect.ValueOf(f.Value) - switch field_value.Kind() { +func (f *Field) isScanner() bool { + _, is_scanner := reflect.New(f.reflectValue.Type()).Interface().(sql.Scanner) + return is_scanner +} + +func (f *Field) isTime() bool { + _, is_time := f.Value.(time.Time) + return is_time +} + +func (f *Field) sqlTag() (str string) { + value := f.Value + if f.isScanner() { + value = f.reflectValue.Field(0).Interface() + } + reflect_value := f.reflectValue + + switch reflect_value.Kind() { case reflect.Slice: - return "" + return case reflect.Struct: - _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) - _, is_time := column.(time.Time) - if !is_time && !is_scanner { - return "" + if !f.isTime() && !f.isScanner() { + return } } typ, addational_typ, size := parseSqlTag(f.structField.Tag.Get(tagIdentifier)) if typ == "-" { - return "" + return } if len(typ) == 0 { - if f.IsPrimaryKey { - typ = f.model.do.chain.d.dialect.PrimaryKeyTag(column, size) + if f.isPrimaryKey { + typ = f.model.do.chain.d.dialect.PrimaryKeyTag(value, size) } else { - typ = f.model.do.chain.d.dialect.SqlTag(column, size) + typ = f.model.do.chain.d.dialect.SqlTag(value, size) } } @@ -91,26 +74,23 @@ func (f *Field) sqlTag() string { } func (f *Field) parseAssociation() { - field_value := reflect.ValueOf(f.Value) + reflect_value := f.reflectValue - switch field_value.Kind() { + switch reflect_value.Kind() { case reflect.Slice: foreign_key := f.model.typeName() + "Id" - if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { + if reflect.New(reflect_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() { f.foreignKey = foreign_key } f.afterAssociation = true case reflect.Struct: - _, is_time := f.Value.(time.Time) - _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) - - if !is_scanner && !is_time { - if f.modelValue.FieldByName(f.Name + "Id").IsValid() { + if !f.isTime() && !f.isScanner() { + if f.model.reflectData().FieldByName(f.Name + "Id").IsValid() { f.foreignKey = f.Name + "Id" f.beforeAssociation = true } else { foreign_key := f.model.typeName() + "Id" - if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() { + if reflect.New(reflect_value.Type()).Elem().FieldByName(foreign_key).IsValid() { f.foreignKey = foreign_key } f.afterAssociation = true @@ -147,14 +127,3 @@ func parseSqlTag(str string) (typ string, addational_typ string, size int) { } return } - -func getInterfaceValue(column interface{}) interface{} { - if v, ok := column.(reflect.Value); ok { - column = v.Interface() - } - - if valuer, ok := interface{}(column).(driver.Valuer); ok { - column = reflect.New(reflect.ValueOf(valuer).Field(0).Type()).Elem().Interface() - } - return column -} diff --git a/model.go b/model.go index d4484fcc..da595d71 100644 --- a/model.go +++ b/model.go @@ -14,33 +14,22 @@ type Model struct { _cache_fields map[string][]*Field } -func (m *Model) primaryKeyZero() bool { - return m.primaryKeyValue() <= 0 +func (m *Model) reflectData() reflect.Value { + return reflect.Indirect(reflect.ValueOf(m.data)) } -func (m *Model) primaryKeyValue() int64 { - if m.data == nil { - return -1 - } - data := reflect.Indirect(reflect.ValueOf(m.data)) +func (m *Model) primaryKeyZero() bool { + return isBlank(reflect.ValueOf(m.primaryKeyValue())) +} - switch data.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: - return 0 - default: - value := data.FieldByName(m.primaryKey()) - - if value.IsValid() { - switch value.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - return value.Int() - default: - return 0 - } - } else { - return 0 +func (m *Model) primaryKeyValue() interface{} { + if data := m.reflectData(); data.Kind() == reflect.Struct { + field := data.FieldByName(m.primaryKey()) + if data.FieldByName(m.primaryKey()).IsValid() { + return field.Interface() } } + return 0 } func (m *Model) primaryKey() string { @@ -56,7 +45,7 @@ func (m *Model) fields(operation string) (fields []*Field) { return m._cache_fields[operation] } - indirect_value := reflect.Indirect(reflect.ValueOf(m.data)) + indirect_value := m.reflectData() if !indirect_value.IsValid() { return } @@ -67,30 +56,29 @@ func (m *Model) fields(operation string) (fields []*Field) { if !p.Anonymous && ast.IsExported(p.Name) { var field Field field.Name = p.Name - field.DbName = toSnake(p.Name) - field.IsPrimaryKey = m.primaryKeyDb() == field.DbName + field.dbName = toSnake(p.Name) + field.isPrimaryKey = m.primaryKeyDb() == field.dbName value := indirect_value.FieldByName(p.Name) - time_value, is_time := value.Interface().(time.Time) field.model = m - field.modelValue = indirect_value - if is_time { - field.AutoCreateTime = "created_at" == field.DbName - field.AutoUpdateTime = "updated_at" == field.DbName + if time_value, is_time := value.Interface().(time.Time); is_time { + field.autoCreateTime = "created_at" == field.dbName + field.autoUpdateTime = "updated_at" == field.dbName switch operation { case "create": - if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() { + if (field.autoCreateTime || field.autoUpdateTime) && time_value.IsZero() { value.Set(reflect.ValueOf(time.Now())) } case "update": - if field.AutoUpdateTime { + if field.autoUpdateTime { value.Set(reflect.ValueOf(time.Now())) } } } field.structField = p + field.reflectValue = value field.Value = value.Interface() fields = append(fields, &field) } @@ -117,17 +105,16 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results return values, true } - data := reflect.Indirect(reflect.ValueOf(m.data)) + data := m.reflectData() for key, value := range values { - field := data.FieldByName(snakeToUpperCamel(key)) - if field.IsValid() { + if field := data.FieldByName(snakeToUpperCamel(key)); field.IsValid() { if field.Interface() != value { switch field.Kind() { case reflect.Int, reflect.Int32, reflect.Int64: if field.Int() != reflect.ValueOf(value).Int() { any_updated = true + field.SetInt(reflect.ValueOf(value).Int()) } - field.SetInt(reflect.ValueOf(value).Int()) default: any_updated = true field.Set(reflect.ValueOf(value)) @@ -145,51 +132,46 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results func (m *Model) columnsAndValues(operation string) map[string]interface{} { results := map[string]interface{}{} - if m.data != nil { - for _, field := range m.fields(operation) { - if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) { - results[field.DbName] = field.Value - } + for _, field := range m.fields(operation) { + if !field.isPrimaryKey && (len(field.sqlTag()) > 0) { + results[field.dbName] = field.Value } } return results } func (m *Model) hasColumn(name string) bool { - if m.data == nil { - return false - } + data := m.reflectData() - data := reflect.Indirect(reflect.ValueOf(m.data)) - if data.Kind() == reflect.Slice { - return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() - } else { + if data.Kind() == reflect.Struct { return data.FieldByName(name).IsValid() + } else if data.Kind() == reflect.Slice { + return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() } + return false } func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) { - if m.data != nil { - data := reflect.Indirect(reflect.ValueOf(m.data)) - if data.Kind() == reflect.Slice { - has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() - is_slice = true - } else { - if has_column = data.FieldByName(name).IsValid(); has_column { - value = data.FieldByName(name).Interface() - } + data := m.reflectData() + + if data.Kind() == reflect.Struct { + if has_column = data.FieldByName(name).IsValid(); has_column { + value = data.FieldByName(name).Interface() } + } else if data.Kind() == reflect.Slice { + has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() + is_slice = true } return } func (m *Model) typeName() string { - typ := reflect.Indirect(reflect.ValueOf(m.data)).Type() + typ := m.reflectData().Type() if typ.Kind() == reflect.Slice { - typ = typ.Elem() + return typ.Elem().Name() + } else { + return typ.Name() } - - return typ.Name() } func (m *Model) tableName() (str string) { @@ -198,8 +180,8 @@ func (m *Model) tableName() (str string) { return } - fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName") - if fm.IsValid() { + data := m.reflectData() + if fm := data.MethodByName("TableName"); fm.IsValid() { if v := fm.Call([]reflect.Value{}); len(v) > 0 { if result, ok := v[0].Interface().(string); ok { return result @@ -227,8 +209,7 @@ func (m *Model) callMethod(method string) { return } - fm := reflect.ValueOf(m.data).MethodByName(method) - if fm.IsValid() { + if fm := reflect.ValueOf(m.data).MethodByName(method); fm.IsValid() { if v := fm.Call([]reflect.Value{}); len(v) > 0 { if verr, ok := v[0].Interface().(error); ok { m.do.err(verr) @@ -255,7 +236,6 @@ func (m *Model) beforeAssociations() (fields []*Field) { func (m *Model) afterAssociations() (fields []*Field) { for _, field := range m.fields("null") { - field.parseAssociation() if field.afterAssociation && !field.isBlank() { fields = append(fields, field) } diff --git a/db.go b/sql.go similarity index 100% rename from db.go rename to sql.go diff --git a/utils.go b/utils.go index 3ea863d4..ea39a11a 100644 --- a/utils.go +++ b/utils.go @@ -3,13 +3,12 @@ package gorm import ( "bytes" "database/sql" - "errors" + "fmt" "reflect" "strconv" - - "fmt" "strings" + "time" ) func toSnake(s string) string { @@ -87,3 +86,31 @@ func setFieldValue(field reflect.Value, value interface{}) bool { return false } + +func isBlank(value reflect.Value) bool { + switch value.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + return value.Int() == 0 + case reflect.String: + return value.String() == "" + case reflect.Slice: + return value.Len() == 0 + case reflect.Struct: + time_value, is_time := value.Interface().(time.Time) + if is_time { + return time_value.IsZero() + } else { + _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner) + if is_scanner { + return !value.FieldByName("Valid").Interface().(bool) + } else { + m := &Model{data: value.Interface()} + fields := m.columnsHasValue("other") + if len(fields) == 0 { + return true + } + } + } + } + return false +}