mirror of https://github.com/go-gorm/gorm.git
clean up code
This commit is contained in:
parent
6b31a9a41f
commit
5341a5ed43
55
do.go
55
do.go
|
@ -452,8 +452,7 @@ func (s *Do) primaryCondiation(value interface{}) string {
|
|||
}
|
||||
|
||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
query := clause["query"]
|
||||
switch value := query.(type) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
|
@ -462,7 +461,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
str = "(" + value + ")"
|
||||
}
|
||||
case int, int64, int32:
|
||||
return s.primaryCondiation(s.addToVars(query))
|
||||
return s.primaryCondiation(s.addToVars(value))
|
||||
case sql.NullInt64:
|
||||
return s.primaryCondiation(s.addToVars(value.Int64))
|
||||
case []int64, []int, []int32, []string:
|
||||
|
@ -504,10 +503,9 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
|
||||
func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||
query := clause["query"]
|
||||
var not_equal_sql string
|
||||
|
||||
switch value := query.(type) {
|
||||
switch value := clause["query"].(type) {
|
||||
case string:
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
|
@ -520,11 +518,11 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
|
||||
}
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), query)
|
||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
|
||||
case []int64, []int, []int32, []string:
|
||||
if reflect.ValueOf(query).Len() > 0 {
|
||||
if reflect.ValueOf(value).Len() > 0 {
|
||||
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
|
||||
clause["args"] = []interface{}{query}
|
||||
clause["args"] = []interface{}{value}
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
|
@ -535,7 +533,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: query, do: s}
|
||||
m := &Model{data: value, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.DbName, s.addToVars(field.Value)))
|
||||
|
@ -586,24 +584,23 @@ func (s *Do) whereSql() (sql string) {
|
|||
and_conditions = append(and_conditions, s.buildNotCondition(clause))
|
||||
}
|
||||
|
||||
and_sql := strings.Join(and_conditions, " AND ")
|
||||
or_sql := strings.Join(or_conditions, " OR ")
|
||||
combined_conditions := and_sql
|
||||
if len(combined_conditions) > 0 {
|
||||
combined_sql := strings.Join(and_conditions, " AND ")
|
||||
if len(combined_sql) > 0 {
|
||||
if len(or_sql) > 0 {
|
||||
combined_conditions = combined_conditions + " OR " + or_sql
|
||||
combined_sql = combined_sql + " OR " + or_sql
|
||||
}
|
||||
} else {
|
||||
combined_conditions = or_sql
|
||||
combined_sql = or_sql
|
||||
}
|
||||
|
||||
if len(primary_condiations) > 0 {
|
||||
sql = "WHERE " + strings.Join(primary_condiations, " AND ")
|
||||
if len(combined_conditions) > 0 {
|
||||
sql = sql + " AND (" + combined_conditions + ")"
|
||||
if len(combined_sql) > 0 {
|
||||
sql = sql + " AND (" + combined_sql + ")"
|
||||
}
|
||||
} else if len(combined_conditions) > 0 {
|
||||
sql = "WHERE " + combined_conditions
|
||||
} else if len(combined_sql) > 0 {
|
||||
sql = "WHERE " + combined_sql
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -646,7 +643,7 @@ func (s *Do) combinedSql() string {
|
|||
|
||||
func (s *Do) createTable() *Do {
|
||||
var sqls []string
|
||||
for _, field := range s.model.fields("other") {
|
||||
for _, field := range s.model.fields("migration") {
|
||||
if len(field.SqlType) > 0 {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
|
@ -681,7 +678,7 @@ func (s *Do) autoMigrate() *Do {
|
|||
if len(table_name) == 0 {
|
||||
s.createTable()
|
||||
} else {
|
||||
for _, field := range s.model.fields("other") {
|
||||
for _, field := range s.model.fields("migration") {
|
||||
var column_name, data_type string
|
||||
sql := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %v", s.addToVars(s.tableName()))
|
||||
s.db.QueryRow(fmt.Sprintf(sql+" and column_name = %v", s.addToVars(field.DbName)), s.sqlVars...).Scan(&column_name, &data_type)
|
||||
|
@ -699,8 +696,7 @@ func (s *Do) autoMigrate() *Do {
|
|||
|
||||
func (s *Do) begin() *Do {
|
||||
if db, ok := s.db.(sql_db); ok {
|
||||
tx, err := db.Begin()
|
||||
if err == nil {
|
||||
if tx, err := db.Begin(); err == nil {
|
||||
s.db = interface{}(tx).(sql_common)
|
||||
s.startedTransaction = true
|
||||
}
|
||||
|
@ -721,14 +717,11 @@ func (s *Do) commit_or_rollback() {
|
|||
}
|
||||
|
||||
func (s *Do) initializeWithSearchCondition() {
|
||||
m := Model{data: s.value, do: s}
|
||||
|
||||
for _, clause := range s.whereClause {
|
||||
query := clause["query"]
|
||||
switch value := query.(type) {
|
||||
switch value := clause["query"].(type) {
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
m.setValueByColumn(k, v, s.value)
|
||||
s.model.setValueByColumn(k, v, s.value)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, obj := range value {
|
||||
|
@ -736,18 +729,18 @@ func (s *Do) initializeWithSearchCondition() {
|
|||
case reflect.Struct:
|
||||
m := &Model{data: obj, do: s}
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
s.model.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
}
|
||||
case reflect.Map:
|
||||
for key, value := range obj.(map[string]interface{}) {
|
||||
m.setValueByColumn(key, value, s.value)
|
||||
s.model.setValueByColumn(key, value, s.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case interface{}:
|
||||
m := &Model{data: query, do: s}
|
||||
m := &Model{data: value, do: s}
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
m.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
s.model.setValueByColumn(field.DbName, field.Value, s.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
69
model.go
69
model.go
|
@ -87,8 +87,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.Name = p.Name
|
||||
field.DbName = toSnake(p.Name)
|
||||
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
||||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||
value := indirect_value.FieldByName(p.Name)
|
||||
time_value, is_time := value.Interface().(time.Time)
|
||||
|
||||
|
@ -98,9 +96,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
case reflect.String:
|
||||
field.IsBlank = value.String() == ""
|
||||
case reflect.Slice:
|
||||
if value.Len() == 0 {
|
||||
field.IsBlank = true
|
||||
}
|
||||
field.IsBlank = value.Len() == 0
|
||||
case reflect.Struct:
|
||||
if is_time {
|
||||
field.IsBlank = time_value.IsZero()
|
||||
|
@ -111,7 +107,6 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||
} else {
|
||||
m := &Model{data: value.Interface(), do: m.do}
|
||||
|
||||
fields := m.columnsHasValue("other")
|
||||
if len(fields) == 0 {
|
||||
field.IsBlank = true
|
||||
|
@ -121,20 +116,20 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
|
||||
if is_time {
|
||||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||
|
||||
switch operation {
|
||||
case "create":
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && time_value.IsZero() {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
case "update":
|
||||
if field.AutoCreateTime && time_value.IsZero() {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
|
||||
if field.AutoUpdateTime {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
} else if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
|
||||
|
@ -176,7 +171,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
|
||||
if len(m._cache_fields) == 0 {
|
||||
m._cache_fields = make(map[string][]Field)
|
||||
m._cache_fields = map[string][]Field{}
|
||||
}
|
||||
m._cache_fields[operation] = fields
|
||||
return
|
||||
|
@ -191,14 +186,12 @@ func (m *Model) columnsHasValue(operation string) (fields []Field) {
|
|||
return
|
||||
}
|
||||
|
||||
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[string]interface{}, bool) {
|
||||
func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (results map[string]interface{}, any_updated bool) {
|
||||
if m.data == nil {
|
||||
return values, true
|
||||
}
|
||||
|
||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
results := map[string]interface{}{}
|
||||
|
||||
for key, value := range values {
|
||||
field := data.FieldByName(snakeToUpperCamel(key))
|
||||
if field.IsValid() {
|
||||
|
@ -206,33 +199,31 @@ func (m *Model) updatedColumnsAndValues(values map[string]interface{}) (map[stri
|
|||
switch field.Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
if field.Int() != reflect.ValueOf(value).Int() {
|
||||
results[key] = value
|
||||
any_updated = true
|
||||
}
|
||||
field.SetInt(reflect.ValueOf(value).Int())
|
||||
default:
|
||||
results[key] = value
|
||||
any_updated = true
|
||||
field.Set(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if values["updated_at"] != nil && len(results) > 0 {
|
||||
if values["updated_at"] != nil && any_updated {
|
||||
setFieldValue(data.FieldByName("UpdatedAt"), time.Now())
|
||||
}
|
||||
result := len(results) > 0
|
||||
return map[string]interface{}{}, result
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||
if m.data == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
results := map[string]interface{}{}
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
||||
results[field.DbName] = field.Value
|
||||
|
||||
if m.data != nil {
|
||||
for _, field := range m.fields(operation) {
|
||||
if !field.IsPrimaryKey && (len(field.SqlType) > 0) {
|
||||
results[field.DbName] = field.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
|
@ -252,17 +243,15 @@ func (m *Model) hasColumn(name string) bool {
|
|||
}
|
||||
|
||||
func (m *Model) ColumnAndValue(name string) (has_column bool, is_slice bool, value interface{}) {
|
||||
if m.data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
if data.Kind() == reflect.Slice {
|
||||
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||
is_slice = true
|
||||
} else {
|
||||
if has_column = data.FieldByName(name).IsValid(); has_column {
|
||||
value = data.FieldByName(name).Interface()
|
||||
if m.data != nil {
|
||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
if data.Kind() == reflect.Slice {
|
||||
has_column = reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||
is_slice = true
|
||||
} else {
|
||||
if has_column = data.FieldByName(name).IsValid(); has_column {
|
||||
value = data.FieldByName(name).Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -285,8 +274,7 @@ func (m *Model) tableName() (str string) {
|
|||
|
||||
fm := reflect.Indirect(reflect.ValueOf(m.data)).MethodByName("TableName")
|
||||
if fm.IsValid() {
|
||||
v := fm.Call([]reflect.Value{})
|
||||
if len(v) > 0 {
|
||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
||||
if result, ok := v[0].Interface().(string); ok {
|
||||
return result
|
||||
}
|
||||
|
@ -315,8 +303,7 @@ func (m *Model) callMethod(method string) {
|
|||
|
||||
fm := reflect.ValueOf(m.data).MethodByName(method)
|
||||
if fm.IsValid() {
|
||||
v := fm.Call([]reflect.Value{})
|
||||
if len(v) > 0 {
|
||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
||||
if verr, ok := v[0].Interface().(error); ok {
|
||||
m.do.err(verr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue