mirror of https://github.com/go-gorm/gorm.git
Support SQL expression
This commit is contained in:
parent
eb480cc085
commit
10340e6ad7
115
model_struct.go
115
model_struct.go
|
@ -63,56 +63,6 @@ type Relationship struct {
|
||||||
JoinTable string
|
JoinTable string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) generateSqlTag(field *StructField) {
|
|
||||||
var sqlType string
|
|
||||||
structType := field.Struct.Type
|
|
||||||
if structType.Kind() == reflect.Ptr {
|
|
||||||
structType = structType.Elem()
|
|
||||||
}
|
|
||||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
|
||||||
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
|
||||||
|
|
||||||
if value, ok := sqlSettings["TYPE"]; ok {
|
|
||||||
sqlType = value
|
|
||||||
}
|
|
||||||
|
|
||||||
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
|
||||||
if value, ok := sqlSettings["DEFAULT"]; ok {
|
|
||||||
additionalType = additionalType + "DEFAULT " + value
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsScanner {
|
|
||||||
var getScannerValue func(reflect.Value)
|
|
||||||
getScannerValue = func(value reflect.Value) {
|
|
||||||
reflectValue = value
|
|
||||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
|
||||||
getScannerValue(reflectValue.Field(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
getScannerValue(reflectValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqlType == "" {
|
|
||||||
var size = 255
|
|
||||||
|
|
||||||
if value, ok := sqlSettings["SIZE"]; ok {
|
|
||||||
size, _ = strconv.Atoi(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsPrimaryKey {
|
|
||||||
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
|
||||||
} else {
|
|
||||||
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
|
||||||
field.SqlTag = sqlType
|
|
||||||
} else {
|
|
||||||
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||||
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
||||||
|
|
||||||
|
@ -341,3 +291,68 @@ func (scope *Scope) GetModelStruct(noRelationship ...bool) *ModelStruct {
|
||||||
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
return scope.GetModelStruct().StructFields
|
return scope.GetModelStruct().StructFields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) generateSqlTag(field *StructField) {
|
||||||
|
var sqlType string
|
||||||
|
structType := field.Struct.Type
|
||||||
|
if structType.Kind() == reflect.Ptr {
|
||||||
|
structType = structType.Elem()
|
||||||
|
}
|
||||||
|
reflectValue := reflect.Indirect(reflect.New(structType))
|
||||||
|
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
|
||||||
|
|
||||||
|
if value, ok := sqlSettings["TYPE"]; ok {
|
||||||
|
sqlType = value
|
||||||
|
}
|
||||||
|
|
||||||
|
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
||||||
|
if value, ok := sqlSettings["DEFAULT"]; ok {
|
||||||
|
additionalType = additionalType + "DEFAULT " + value
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.IsScanner {
|
||||||
|
var getScannerValue func(reflect.Value)
|
||||||
|
getScannerValue = func(value reflect.Value) {
|
||||||
|
reflectValue = value
|
||||||
|
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
||||||
|
getScannerValue(reflectValue.Field(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
getScannerValue(reflectValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
var size = 255
|
||||||
|
|
||||||
|
if value, ok := sqlSettings["SIZE"]; ok {
|
||||||
|
size, _ = strconv.Atoi(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||||
|
} else {
|
||||||
|
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
field.SqlTag = sqlType
|
||||||
|
} else {
|
||||||
|
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTagSetting(str string) map[string]string {
|
||||||
|
tags := strings.Split(str, ";")
|
||||||
|
setting := map[string]string{}
|
||||||
|
for _, value := range tags {
|
||||||
|
v := strings.Split(value, ":")
|
||||||
|
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||||
|
if len(v) == 2 {
|
||||||
|
setting[k] = v[1]
|
||||||
|
} else {
|
||||||
|
setting[k] = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return setting
|
||||||
|
}
|
||||||
|
|
8
scope.go
8
scope.go
|
@ -204,9 +204,17 @@ func (scope *Scope) CallMethodWithErrorCheck(name string) {
|
||||||
|
|
||||||
// AddToVars add value as sql's vars, gorm will escape them
|
// AddToVars add value as sql's vars, gorm will escape them
|
||||||
func (scope *Scope) AddToVars(value interface{}) string {
|
func (scope *Scope) AddToVars(value interface{}) string {
|
||||||
|
if expr, ok := value.(*expr); ok {
|
||||||
|
exp := expr.expr
|
||||||
|
for _, arg := range expr.args {
|
||||||
|
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
|
||||||
|
}
|
||||||
|
return exp
|
||||||
|
} else {
|
||||||
scope.SqlVars = append(scope.SqlVars, value)
|
scope.SqlVars = append(scope.SqlVars, value)
|
||||||
return scope.Dialect().BinVar(len(scope.SqlVars))
|
return scope.Dialect().BinVar(len(scope.SqlVars))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TableName get table name
|
// TableName get table name
|
||||||
func (scope *Scope) TableName() string {
|
func (scope *Scope) TableName() string {
|
||||||
|
|
|
@ -307,17 +307,31 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
||||||
return values, true
|
return values, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var hasExpr bool
|
||||||
fields := scope.Fields()
|
fields := scope.Fields()
|
||||||
for key, value := range values {
|
for key, value := range values {
|
||||||
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
|
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
|
||||||
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
|
||||||
if !equalAsString(field.Field.Interface(), value) {
|
if _, ok := value.(*expr); ok {
|
||||||
|
hasExpr = true
|
||||||
|
} else if !equalAsString(field.Field.Interface(), value) {
|
||||||
hasUpdate = true
|
hasUpdate = true
|
||||||
field.Set(value)
|
field.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if hasExpr {
|
||||||
|
var updateMap = map[string]interface{}{}
|
||||||
|
for key, value := range fields {
|
||||||
|
if v, ok := values[key]; ok {
|
||||||
|
updateMap[key] = v
|
||||||
|
} else {
|
||||||
|
updateMap[key] = value.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return updateMap, true
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,8 @@ package gorm_test
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
|
@ -67,6 +69,17 @@ func TestUpdate(t *testing.T) {
|
||||||
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
|
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
|
||||||
t.Error("RowsAffected should be correct when do batch update")
|
t.Error("RowsAffected should be correct when do batch update")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.First(&product4, product4.Id)
|
||||||
|
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100-50 {
|
||||||
|
t.Errorf("Update with expression")
|
||||||
|
}
|
||||||
|
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("Update with expression should update UpdatedAt")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||||
|
@ -148,6 +161,16 @@ func TestUpdates(t *testing.T) {
|
||||||
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
|
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
|
||||||
t.Errorf("product2's code should be updated")
|
t.Errorf("product2's code should be updated")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100 {
|
||||||
|
t.Errorf("Updates with expression")
|
||||||
|
}
|
||||||
|
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("Updates with expression should update UpdatedAt")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateColumn(t *testing.T) {
|
func TestUpdateColumn(t *testing.T) {
|
||||||
|
@ -172,4 +195,14 @@ func TestUpdateColumn(t *testing.T) {
|
||||||
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("updatedAt should not be updated with update column")
|
t.Errorf("updatedAt should not be updated with update column")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100-50 {
|
||||||
|
t.Errorf("UpdateColumn with expression")
|
||||||
|
}
|
||||||
|
if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("UpdateColumn with expression should not update UpdatedAt")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
18
utils.go
18
utils.go
|
@ -38,17 +38,11 @@ func ToDBName(name string) string {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTagSetting(str string) map[string]string {
|
type expr struct {
|
||||||
tags := strings.Split(str, ";")
|
expr string
|
||||||
setting := map[string]string{}
|
args []interface{}
|
||||||
for _, value := range tags {
|
|
||||||
v := strings.Split(value, ":")
|
|
||||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
|
||||||
if len(v) == 2 {
|
|
||||||
setting[k] = v[1]
|
|
||||||
} else {
|
|
||||||
setting[k] = k
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return setting
|
func Expr(expression string, args ...interface{}) *expr {
|
||||||
|
return &expr{expr: expression, args: args}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue