Support SQL expression

This commit is contained in:
Jinzhu 2015-02-24 22:06:35 +08:00
parent eb480cc085
commit 10340e6ad7
5 changed files with 130 additions and 66 deletions

View File

@ -63,56 +63,6 @@ type Relationship struct {
JoinTable string JoinTable string
} }
func (scope *Scope) generateSqlTag(field *StructField) {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
if value, ok := sqlSettings["TYPE"]; ok {
sqlType = value
}
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
if value, ok := sqlSettings["DEFAULT"]; ok {
additionalType = additionalType + "DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}
if sqlType == "" {
var size = 255
if value, ok := sqlSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
if field.IsPrimaryKey {
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else {
sqlType = scope.Dialect().SqlTag(reflectValue, size)
}
}
if strings.TrimSpace(additionalType) == "" {
field.SqlTag = sqlType
} else {
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
}
}
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
@ -341,3 +291,68 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
func (scope *Scope) GetStructFields() (fields []*StructField) { func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields return scope.GetModelStruct().StructFields
} }
func (scope *Scope) generateSqlTag(field *StructField) {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
if value, ok := sqlSettings["TYPE"]; ok {
sqlType = value
}
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
if value, ok := sqlSettings["DEFAULT"]; ok {
additionalType = additionalType + "DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}
if sqlType == "" {
var size = 255
if value, ok := sqlSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
if field.IsPrimaryKey {
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else {
sqlType = scope.Dialect().SqlTag(reflectValue, size)
}
}
if strings.TrimSpace(additionalType) == "" {
field.SqlTag = sqlType
} else {
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
}
}
func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
setting[k] = v[1]
} else {
setting[k] = k
}
}
return setting
}

View File

@ -204,9 +204,17 @@ func (scope *Scope) CallMethodWithErrorCheck(name string) {
// AddToVars add value as sql's vars, gorm will escape them // AddToVars add value as sql's vars, gorm will escape them
func (scope *Scope) AddToVars(value interface{}) string { func (scope *Scope) AddToVars(value interface{}) string {
if expr, ok := value.(*expr); ok {
exp := expr.expr
for _, arg := range expr.args {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
}
return exp
} else {
scope.SqlVars = append(scope.SqlVars, value) scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BinVar(len(scope.SqlVars)) return scope.Dialect().BinVar(len(scope.SqlVars))
} }
}
// TableName get table name // TableName get table name
func (scope *Scope) TableName() string { func (scope *Scope) TableName() string {

View File

@ -307,17 +307,31 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
return values, true return values, true
} }
var hasExpr bool
fields := scope.Fields() fields := scope.Fields()
for key, value := range values { for key, value := range values {
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() { if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if !equalAsString(field.Field.Interface(), value) { if _, ok := value.(*expr); ok {
hasExpr = true
} else if !equalAsString(field.Field.Interface(), value) {
hasUpdate = true hasUpdate = true
field.Set(value) field.Set(value)
} }
} }
} }
} }
if hasExpr {
var updateMap = map[string]interface{}{}
for key, value := range fields {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = value.Field.Interface()
}
}
return updateMap, true
}
return return
} }

View File

@ -3,6 +3,8 @@ package gorm_test
import ( import (
"testing" "testing"
"time" "time"
"github.com/jinzhu/gorm"
) )
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
@ -67,6 +69,17 @@ func TestUpdate(t *testing.T) {
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
t.Error("RowsAffected should be correct when do batch update") t.Error("RowsAffected should be correct when do batch update")
} }
DB.First(&product4, product4.Id)
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 {
t.Errorf("Update with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("Update with expression should update UpdatedAt")
}
} }
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
@ -148,6 +161,16 @@ func TestUpdates(t *testing.T) {
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
t.Errorf("product2's code should be updated") t.Errorf("product2's code should be updated")
} }
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100 {
t.Errorf("Updates with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("Updates with expression should update UpdatedAt")
}
} }
func TestUpdateColumn(t *testing.T) { func TestUpdateColumn(t *testing.T) {
@ -172,4 +195,14 @@ func TestUpdateColumn(t *testing.T) {
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated with update column") t.Errorf("updatedAt should not be updated with update column")
} }
DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 {
t.Errorf("UpdateColumn with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateColumn with expression should not update UpdatedAt")
}
} }

View File

@ -38,17 +38,11 @@ func ToDBName(name string) string {
return s return s
} }
func parseTagSetting(str string) map[string]string { type expr struct {
tags := strings.Split(str, ";") expr string
setting := map[string]string{} args []interface{}
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
setting[k] = v[1]
} else {
setting[k] = k
} }
}
return setting func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args}
} }