diff --git a/model_struct.go b/model_struct.go index 9e7e64df..0df487fb 100644 --- a/model_struct.go +++ b/model_struct.go @@ -63,56 +63,6 @@ type Relationship struct { JoinTable string } -func (scope *Scope) generateSqlTag(field *StructField) { - var sqlType string - structType := field.Struct.Type - if structType.Kind() == reflect.Ptr { - structType = structType.Elem() - } - reflectValue := reflect.Indirect(reflect.New(structType)) - sqlSettings := parseTagSetting(field.Tag.Get("sql")) - - if value, ok := sqlSettings["TYPE"]; ok { - sqlType = value - } - - additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] - if value, ok := sqlSettings["DEFAULT"]; ok { - additionalType = additionalType + "DEFAULT " + value - } - - if field.IsScanner { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - reflectValue = value - if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { - getScannerValue(reflectValue.Field(0)) - } - } - getScannerValue(reflectValue) - } - - if sqlType == "" { - var size = 255 - - if value, ok := sqlSettings["SIZE"]; ok { - size, _ = strconv.Atoi(value) - } - - if field.IsPrimaryKey { - sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size) - } else { - sqlType = scope.Dialect().SqlTag(reflectValue, size) - } - } - - if strings.TrimSpace(additionalType) == "" { - field.SqlTag = sqlType - } else { - field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType) - } -} - var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} @@ -341,3 +291,68 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct { func (scope *Scope) GetStructFields() (fields []*StructField) { return scope.GetModelStruct().StructFields } + +func (scope *Scope) generateSqlTag(field *StructField) { + var sqlType string + structType := field.Struct.Type + if structType.Kind() == reflect.Ptr { + structType = structType.Elem() + } + reflectValue := reflect.Indirect(reflect.New(structType)) + sqlSettings := parseTagSetting(field.Tag.Get("sql")) + + if value, ok := sqlSettings["TYPE"]; ok { + sqlType = value + } + + additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] + if value, ok := sqlSettings["DEFAULT"]; ok { + additionalType = additionalType + "DEFAULT " + value + } + + if field.IsScanner { + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + reflectValue = value + if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { + getScannerValue(reflectValue.Field(0)) + } + } + getScannerValue(reflectValue) + } + + if sqlType == "" { + var size = 255 + + if value, ok := sqlSettings["SIZE"]; ok { + size, _ = strconv.Atoi(value) + } + + if field.IsPrimaryKey { + sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size) + } else { + sqlType = scope.Dialect().SqlTag(reflectValue, size) + } + } + + if strings.TrimSpace(additionalType) == "" { + field.SqlTag = sqlType + } else { + field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType) + } +} + +func parseTagSetting(str string) map[string]string { + tags := strings.Split(str, ";") + setting := map[string]string{} + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if len(v) == 2 { + setting[k] = v[1] + } else { + setting[k] = k + } + } + return setting +} diff --git a/scope.go b/scope.go index bddf7472..1191a637 100644 --- a/scope.go +++ b/scope.go @@ -204,8 +204,16 @@ func (scope *Scope) CallMethodWithErrorCheck(name string) { // AddToVars add value as sql's vars, gorm will escape them func (scope *Scope) AddToVars(value interface{}) string { - scope.SqlVars = append(scope.SqlVars, value) - return scope.Dialect().BinVar(len(scope.SqlVars)) + if expr, ok := value.(*expr); ok { + exp := expr.expr + for _, arg := range expr.args { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } + return exp + } else { + scope.SqlVars = append(scope.SqlVars, value) + return scope.Dialect().BinVar(len(scope.SqlVars)) + } } // TableName get table name diff --git a/scope_private.go b/scope_private.go index 31c7809c..6636ee83 100644 --- a/scope_private.go +++ b/scope_private.go @@ -307,17 +307,31 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore return values, true } + var hasExpr bool fields := scope.Fields() for key, value := range values { if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { - if !equalAsString(field.Field.Interface(), value) { + if _, ok := value.(*expr); ok { + hasExpr = true + } else if !equalAsString(field.Field.Interface(), value) { hasUpdate = true field.Set(value) } } } } + if hasExpr { + var updateMap = map[string]interface{}{} + for key, value := range fields { + if v, ok := values[key]; ok { + updateMap[key] = v + } else { + updateMap[key] = value.Field.Interface() + } + } + return updateMap, true + } return } diff --git a/update_test.go b/update_test.go index e1227983..9b66fc01 100644 --- a/update_test.go +++ b/update_test.go @@ -3,6 +3,8 @@ package gorm_test import ( "testing" "time" + + "github.com/jinzhu/gorm" ) func TestUpdate(t *testing.T) { @@ -67,6 +69,17 @@ func TestUpdate(t *testing.T) { if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { t.Error("RowsAffected should be correct when do batch update") } + + DB.First(&product4, product4.Id) + DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100-50 { + t.Errorf("Update with expression") + } + if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("Update with expression should update UpdatedAt") + } } func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { @@ -148,6 +161,16 @@ func TestUpdates(t *testing.T) { if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { t.Errorf("product2's code should be updated") } + + DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100 { + t.Errorf("Updates with expression") + } + if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("Updates with expression should update UpdatedAt") + } } func TestUpdateColumn(t *testing.T) { @@ -172,4 +195,14 @@ func TestUpdateColumn(t *testing.T) { if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { t.Errorf("updatedAt should not be updated with update column") } + + DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) + var product5 Product + DB.First(&product5, product4.Id) + if product5.Price != product4.Price+100-50 { + t.Errorf("UpdateColumn with expression") + } + if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateColumn with expression should not update UpdatedAt") + } } diff --git a/utils.go b/utils.go index 3298eb7d..ca7e04e8 100644 --- a/utils.go +++ b/utils.go @@ -38,17 +38,11 @@ func ToDBName(name string) string { return s } -func parseTagSetting(str string) map[string]string { - tags := strings.Split(str, ";") - setting := map[string]string{} - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) == 2 { - setting[k] = v[1] - } else { - setting[k] = k - } - } - return setting +type expr struct { + expr string + args []interface{} +} + +func Expr(expression string, args ...interface{}) *expr { + return &expr{expr: expression, args: args} }