mirror of https://github.com/go-gorm/gorm.git
Clean up more code
This commit is contained in:
parent
788da015d1
commit
b22289b249
12
chain.go
12
chain.go
|
@ -132,16 +132,12 @@ func (s *Chain) Select(value interface{}) *Chain {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Save(value interface{}) *Chain {
|
func (s *Chain) Save(value interface{}) *Chain {
|
||||||
do := s.do(value).begin()
|
s.do(value).begin().save().commit_or_rollback()
|
||||||
do.save()
|
|
||||||
do.commit_or_rollback()
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Delete(value interface{}) *Chain {
|
func (s *Chain) Delete(value interface{}) *Chain {
|
||||||
do := s.do(value).begin()
|
s.do(value).begin().delete().commit_or_rollback()
|
||||||
do.delete()
|
|
||||||
do.commit_or_rollback()
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,9 +146,7 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
|
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
|
||||||
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
|
s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
|
||||||
do.update()
|
|
||||||
do.commit_or_rollback()
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
126
do.go
126
do.go
|
@ -68,12 +68,11 @@ func (s *Do) addToVars(value interface{}) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) exec(sqls ...string) (err error) {
|
func (s *Do) exec(sqls ...string) (err error) {
|
||||||
if s.chain.hasError() {
|
if !s.chain.hasError() {
|
||||||
return
|
|
||||||
} else {
|
|
||||||
if len(sqls) > 0 {
|
if len(sqls) > 0 {
|
||||||
s.sql = sqls[0]
|
s.sql = sqls[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, err = s.db.Exec(s.sql, s.sqlVars...)
|
_, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||||
|
@ -81,13 +80,13 @@ func (s *Do) exec(sqls ...string) (err error) {
|
||||||
return s.err(err)
|
return s.err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) save() (value interface{}) {
|
func (s *Do) save() *Do {
|
||||||
if s.model.primaryKeyZero() {
|
if s.model.primaryKeyZero() {
|
||||||
value = s.create()
|
s.create()
|
||||||
} else {
|
} else {
|
||||||
value = s.update()
|
s.update()
|
||||||
}
|
}
|
||||||
return
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareCreateSql() {
|
func (s *Do) prepareCreateSql() {
|
||||||
|
@ -110,24 +109,24 @@ func (s *Do) prepareCreateSql() {
|
||||||
|
|
||||||
func (s *Do) saveBeforeAssociations() {
|
func (s *Do) saveBeforeAssociations() {
|
||||||
for _, field := range s.model.beforeAssociations() {
|
for _, field := range s.model.beforeAssociations() {
|
||||||
var id interface{}
|
|
||||||
do := &Do{chain: s.chain, db: s.db}
|
do := &Do{chain: s.chain, db: s.db}
|
||||||
|
|
||||||
reflect_value := reflect.ValueOf(field.Value)
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
if reflect_value.CanAddr() {
|
if reflect_value.CanAddr() {
|
||||||
id = do.setModel(reflect_value.Addr().Interface()).save()
|
do.setModel(reflect_value.Addr().Interface()).save()
|
||||||
} else {
|
} else {
|
||||||
|
// If can't take address, then clone the value and set it back
|
||||||
dest_value := reflect.New(reflect_value.Type()).Elem()
|
dest_value := reflect.New(reflect_value.Type()).Elem()
|
||||||
m := &Model{data: field.Value, do: s}
|
m := &Model{data: field.Value, do: s}
|
||||||
for _, f := range m.columnsHasValue("other") {
|
for _, f := range m.columnsHasValue("other") {
|
||||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||||
}
|
}
|
||||||
id = do.setModel(dest_value.Addr().Interface()).save()
|
do.setModel(dest_value.Addr().Interface()).save()
|
||||||
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
|
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(field.foreignKey) > 0 {
|
if len(field.foreignKey) > 0 {
|
||||||
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
|
s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,11 +135,12 @@ func (s *Do) saveAfterAssociations() {
|
||||||
for _, field := range s.model.afterAssociations() {
|
for _, field := range s.model.afterAssociations() {
|
||||||
reflect_value := reflect.ValueOf(field.Value)
|
reflect_value := reflect.ValueOf(field.Value)
|
||||||
|
|
||||||
switch reflect.TypeOf(field.Value).Kind() {
|
switch reflect_value.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
for i := 0; i < reflect_value.Len(); i++ {
|
for i := 0; i < reflect_value.Len(); i++ {
|
||||||
value := reflect_value.Index(i).Addr().Interface()
|
|
||||||
do := &Do{chain: s.chain, db: s.db}
|
do := &Do{chain: s.chain, db: s.db}
|
||||||
|
|
||||||
|
value := reflect_value.Index(i).Addr().Interface()
|
||||||
if len(field.foreignKey) > 0 {
|
if len(field.foreignKey) > 0 {
|
||||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||||
}
|
}
|
||||||
|
@ -175,8 +175,9 @@ func (s *Do) create() (i interface{}) {
|
||||||
s.prepareCreateSql()
|
s.prepareCreateSql()
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.chain.hasError() {
|
||||||
now := time.Now()
|
|
||||||
var id interface{}
|
var id interface{}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
if s.chain.driver() == "postgres" {
|
if s.chain.driver() == "postgres" {
|
||||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||||
} else {
|
} else {
|
||||||
|
@ -188,12 +189,9 @@ func (s *Do) create() (i interface{}) {
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.chain.hasError() {
|
||||||
result := reflect.Indirect(reflect.ValueOf(s.value))
|
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
||||||
if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) {
|
|
||||||
fmt.Printf("Can't set primary key for %#v\n", result.Interface())
|
|
||||||
}
|
|
||||||
s.saveAfterAssociations()
|
|
||||||
|
|
||||||
|
s.saveAfterAssociations()
|
||||||
s.model.callMethod("AfterCreate")
|
s.model.callMethod("AfterCreate")
|
||||||
s.model.callMethod("AfterSave")
|
s.model.callMethod("AfterSave")
|
||||||
}
|
}
|
||||||
|
@ -204,19 +202,17 @@ func (s *Do) create() (i interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
|
func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
|
||||||
switch values.(type) {
|
switch vs := values.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
s.updateAttrs = values.(map[string]interface{})
|
s.updateAttrs = vs
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for _, value := range values.([]interface{}) {
|
for _, value := range vs {
|
||||||
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
||||||
}
|
}
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: values, do: s}
|
m := &Model{data: values, do: s}
|
||||||
fields := m.columnsHasValue("other")
|
s.updateAttrs = map[string]interface{}{}
|
||||||
|
for _, field := range m.columnsHasValue("other") {
|
||||||
s.updateAttrs = make(map[string]interface{}, len(fields))
|
|
||||||
for _, field := range fields {
|
|
||||||
s.updateAttrs[field.DbName] = field.Value
|
s.updateAttrs[field.DbName] = field.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -238,8 +234,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
update_attrs := s.model.columnsAndValues("update")
|
for key, value := range s.model.columnsAndValues("update") {
|
||||||
for key, value := range update_attrs {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,13 +247,12 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) update() (i interface{}) {
|
func (s *Do) update() *Do {
|
||||||
update_attrs := s.updateAttrs
|
update_attrs := s.updateAttrs
|
||||||
if len(update_attrs) > 0 {
|
if len(update_attrs) > 0 {
|
||||||
var need_update bool
|
var need_update bool
|
||||||
update_attrs, need_update = s.prepareUpdateAttrs()
|
if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update {
|
||||||
if !need_update {
|
return s
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,29 +270,22 @@ func (s *Do) update() (i interface{}) {
|
||||||
s.model.callMethod("AfterSave")
|
s.model.callMethod("AfterSave")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.model.primaryKeyValue()
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareDeleteSql() {
|
func (s *Do) delete() *Do {
|
||||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Do) delete() {
|
|
||||||
s.model.callMethod("BeforeDelete")
|
s.model.callMethod("BeforeDelete")
|
||||||
|
|
||||||
if !s.chain.hasError() {
|
if !s.chain.hasError() {
|
||||||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||||
delete_sql := "deleted_at=" + s.addToVars(time.Now())
|
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
|
||||||
s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql())
|
|
||||||
s.exec()
|
|
||||||
} else {
|
} else {
|
||||||
s.prepareDeleteSql()
|
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||||
s.exec()
|
|
||||||
}
|
}
|
||||||
|
s.exec()
|
||||||
s.model.callMethod("AfterDelete")
|
s.model.callMethod("AfterDelete")
|
||||||
}
|
}
|
||||||
return
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareQuerySql() {
|
func (s *Do) prepareQuerySql() {
|
||||||
|
@ -379,24 +366,18 @@ func (s *Do) query() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||||
|
|
||||||
if s.err(err) != nil {
|
if s.err(err) != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
var has_record bool
|
||||||
if rows.Err() != nil {
|
|
||||||
s.err(rows.Err())
|
|
||||||
}
|
|
||||||
|
|
||||||
counts := 0
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
counts += 1
|
has_record = true
|
||||||
var dest reflect.Value
|
dest := dest_out
|
||||||
if is_slice {
|
if is_slice {
|
||||||
dest = reflect.New(dest_type).Elem()
|
dest = reflect.New(dest_type).Elem()
|
||||||
} else {
|
|
||||||
dest = dest_out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
|
@ -404,15 +385,10 @@ func (s *Do) query() {
|
||||||
for _, value := range columns {
|
for _, value := range columns {
|
||||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||||
if field.IsValid() {
|
if field.IsValid() {
|
||||||
if field.CanAddr() {
|
|
||||||
values = append(values, field.Addr().Interface())
|
values = append(values, field.Addr().Interface())
|
||||||
} else {
|
} else {
|
||||||
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
|
var ignore interface{}
|
||||||
return
|
values = append(values, &ignore)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var null interface{}
|
|
||||||
values = append(values, &null)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.err(rows.Scan(values...))
|
s.err(rows.Scan(values...))
|
||||||
|
@ -422,7 +398,7 @@ func (s *Do) query() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (counts == 0) && !is_slice {
|
if !has_record && !is_slice {
|
||||||
s.err(errors.New("Record not found!"))
|
s.err(errors.New("Record not found!"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -477,9 +453,8 @@ func (s *Do) primaryCondiation(value interface{}) string {
|
||||||
|
|
||||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
query := clause["query"]
|
query := clause["query"]
|
||||||
switch query.(type) {
|
switch value := query.(type) {
|
||||||
case string:
|
case string:
|
||||||
value := query.(string)
|
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
id, _ := strconv.Atoi(value)
|
id, _ := strconv.Atoi(value)
|
||||||
return s.primaryCondiation(s.addToVars(id))
|
return s.primaryCondiation(s.addToVars(id))
|
||||||
|
@ -489,18 +464,18 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
case int, int64, int32:
|
case int, int64, int32:
|
||||||
return s.primaryCondiation(s.addToVars(query))
|
return s.primaryCondiation(s.addToVars(query))
|
||||||
case sql.NullInt64:
|
case sql.NullInt64:
|
||||||
return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64))
|
return s.primaryCondiation(s.addToVars(value.Int64))
|
||||||
case []int64, []int, []int32, []string:
|
case []int64, []int, []int32, []string:
|
||||||
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
|
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
|
||||||
clause["args"] = []interface{}{query}
|
clause["args"] = []interface{}{value}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range query.(map[string]interface{}) {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: query, do: s}
|
m := &Model{data: value, do: s}
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value)))
|
||||||
|
@ -532,9 +507,8 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
query := clause["query"]
|
query := clause["query"]
|
||||||
var not_equal_sql string
|
var not_equal_sql string
|
||||||
|
|
||||||
switch query.(type) {
|
switch value := query.(type) {
|
||||||
case string:
|
case string:
|
||||||
value := query.(string)
|
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
id, _ := strconv.Atoi(value)
|
id, _ := strconv.Atoi(value)
|
||||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
|
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
|
||||||
|
@ -556,7 +530,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
}
|
}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range query.(map[string]interface{}) {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
|
@ -751,13 +725,13 @@ func (s *Do) initializeWithSearchCondition() {
|
||||||
|
|
||||||
for _, clause := range s.whereClause {
|
for _, clause := range s.whereClause {
|
||||||
query := clause["query"]
|
query := clause["query"]
|
||||||
switch query.(type) {
|
switch value := query.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
for key, value := range query.(map[string]interface{}) {
|
for k, v := range value {
|
||||||
m.setValueByColumn(key, value, s.value)
|
m.setValueByColumn(k, v, s.value)
|
||||||
}
|
}
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for _, obj := range query.([]interface{}) {
|
for _, obj := range value {
|
||||||
switch reflect.ValueOf(obj).Kind() {
|
switch reflect.ValueOf(obj).Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
m := &Model{data: obj, do: s}
|
m := &Model{data: obj, do: s}
|
||||||
|
|
Loading…
Reference in New Issue