mirror of https://github.com/go-gorm/gorm.git
Support SQL expression
This commit is contained in:
parent
eb480cc085
commit
10340e6ad7
115
model_struct.go
115
model_struct.go
|
@ -63,56 +63,6 @@ type Relationship struct {
|
|||
JoinTable string
|
||||
}
|
||||
|
||||
func (scope *Scope) generateSqlTag(field *StructField) {
|
||||
var sqlType string
|
||||
structType := field.Struct.Type
|
||||
if structType.Kind() == reflect.Ptr {
|
||||
structType = structType.Elem()
|
||||
}
|
||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
||||
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
||||
|
||||
if value, ok := sqlSettings["TYPE"]; ok {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
||||
if value, ok := sqlSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + "DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
reflectValue = value
|
||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
||||
getScannerValue(reflectValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(reflectValue)
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
var size = 255
|
||||
|
||||
if value, ok := sqlSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(value)
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||
} else {
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
field.SqlTag = sqlType
|
||||
} else {
|
||||
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
}
|
||||
|
||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
||||
|
||||
|
@ -341,3 +291,68 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
|||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
||||
func (scope *Scope) generateSqlTag(field *StructField) {
|
||||
var sqlType string
|
||||
structType := field.Struct.Type
|
||||
if structType.Kind() == reflect.Ptr {
|
||||
structType = structType.Elem()
|
||||
}
|
||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
||||
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
||||
|
||||
if value, ok := sqlSettings["TYPE"]; ok {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
||||
if value, ok := sqlSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + "DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
reflectValue = value
|
||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
||||
getScannerValue(reflectValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(reflectValue)
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
var size = 255
|
||||
|
||||
if value, ok := sqlSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(value)
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||
} else {
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
field.SqlTag = sqlType
|
||||
} else {
|
||||
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
}
|
||||
|
||||
func parseTagSetting(str string) map[string]string {
|
||||
tags := strings.Split(str, ";")
|
||||
setting := map[string]string{}
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if len(v) == 2 {
|
||||
setting[k] = v[1]
|
||||
} else {
|
||||
setting[k] = k
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
|
12
scope.go
12
scope.go
|
@ -204,8 +204,16 @@ func (scope *Scope) CallMethodWithErrorCheck(name string) {
|
|||
|
||||
// AddToVars add value as sql's vars, gorm will escape them
|
||||
func (scope *Scope) AddToVars(value interface{}) string {
|
||||
scope.SqlVars = append(scope.SqlVars, value)
|
||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
||||
if expr, ok := value.(*expr); ok {
|
||||
exp := expr.expr
|
||||
for _, arg := range expr.args {
|
||||
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
return exp
|
||||
} else {
|
||||
scope.SqlVars = append(scope.SqlVars, value)
|
||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
||||
}
|
||||
}
|
||||
|
||||
// TableName get table name
|
||||
|
|
|
@ -307,17 +307,31 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
|||
return values, true
|
||||
}
|
||||
|
||||
var hasExpr bool
|
||||
fields := scope.Fields()
|
||||
for key, value := range values {
|
||||
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
|
||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||
if !equalAsString(field.Field.Interface(), value) {
|
||||
if _, ok := value.(*expr); ok {
|
||||
hasExpr = true
|
||||
} else if !equalAsString(field.Field.Interface(), value) {
|
||||
hasUpdate = true
|
||||
field.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasExpr {
|
||||
var updateMap = map[string]interface{}{}
|
||||
for key, value := range fields {
|
||||
if v, ok := values[key]; ok {
|
||||
updateMap[key] = v
|
||||
} else {
|
||||
updateMap[key] = value.Field.Interface()
|
||||
}
|
||||
}
|
||||
return updateMap, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@ package gorm_test
|
|||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
|
@ -67,6 +69,17 @@ func TestUpdate(t *testing.T) {
|
|||
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
|
||||
t.Error("RowsAffected should be correct when do batch update")
|
||||
}
|
||||
|
||||
DB.First(&product4, product4.Id)
|
||||
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100-50 {
|
||||
t.Errorf("Update with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Update with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||
|
@ -148,6 +161,16 @@ func TestUpdates(t *testing.T) {
|
|||
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
|
||||
t.Errorf("product2's code should be updated")
|
||||
}
|
||||
|
||||
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100 {
|
||||
t.Errorf("Updates with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Updates with expression should update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateColumn(t *testing.T) {
|
||||
|
@ -172,4 +195,14 @@ func TestUpdateColumn(t *testing.T) {
|
|||
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("updatedAt should not be updated with update column")
|
||||
}
|
||||
|
||||
DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
|
||||
var product5 Product
|
||||
DB.First(&product5, product4.Id)
|
||||
if product5.Price != product4.Price+100-50 {
|
||||
t.Errorf("UpdateColumn with expression")
|
||||
}
|
||||
if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("UpdateColumn with expression should not update UpdatedAt")
|
||||
}
|
||||
}
|
||||
|
|
20
utils.go
20
utils.go
|
@ -38,17 +38,11 @@ func ToDBName(name string) string {
|
|||
return s
|
||||
}
|
||||
|
||||
func parseTagSetting(str string) map[string]string {
|
||||
tags := strings.Split(str, ";")
|
||||
setting := map[string]string{}
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if len(v) == 2 {
|
||||
setting[k] = v[1]
|
||||
} else {
|
||||
setting[k] = k
|
||||
}
|
||||
}
|
||||
return setting
|
||||
type expr struct {
|
||||
expr string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func Expr(expression string, args ...interface{}) *expr {
|
||||
return &expr{expr: expression, args: args}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue