forked from mirror/gorm
Clean up more code
This commit is contained in:
parent
788da015d1
commit
b22289b249
12
chain.go
12
chain.go
|
@ -132,16 +132,12 @@ func (s *Chain) Select(value interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) Save(value interface{}) *Chain {
|
||||
do := s.do(value).begin()
|
||||
do.save()
|
||||
do.commit_or_rollback()
|
||||
s.do(value).begin().save().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Delete(value interface{}) *Chain {
|
||||
do := s.do(value).begin()
|
||||
do.delete()
|
||||
do.commit_or_rollback()
|
||||
s.do(value).begin().delete().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -150,9 +146,7 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
|
||||
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
|
||||
do.update()
|
||||
do.commit_or_rollback()
|
||||
s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...).update().commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
126
do.go
126
do.go
|
@ -68,12 +68,11 @@ func (s *Do) addToVars(value interface{}) string {
|
|||
}
|
||||
|
||||
func (s *Do) exec(sqls ...string) (err error) {
|
||||
if s.chain.hasError() {
|
||||
return
|
||||
} else {
|
||||
if !s.chain.hasError() {
|
||||
if len(sqls) > 0 {
|
||||
s.sql = sqls[0]
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
_, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
|
@ -81,13 +80,13 @@ func (s *Do) exec(sqls ...string) (err error) {
|
|||
return s.err(err)
|
||||
}
|
||||
|
||||
func (s *Do) save() (value interface{}) {
|
||||
func (s *Do) save() *Do {
|
||||
if s.model.primaryKeyZero() {
|
||||
value = s.create()
|
||||
s.create()
|
||||
} else {
|
||||
value = s.update()
|
||||
s.update()
|
||||
}
|
||||
return
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareCreateSql() {
|
||||
|
@ -110,24 +109,24 @@ func (s *Do) prepareCreateSql() {
|
|||
|
||||
func (s *Do) saveBeforeAssociations() {
|
||||
for _, field := range s.model.beforeAssociations() {
|
||||
var id interface{}
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
if reflect_value.CanAddr() {
|
||||
id = do.setModel(reflect_value.Addr().Interface()).save()
|
||||
do.setModel(reflect_value.Addr().Interface()).save()
|
||||
} else {
|
||||
// If can't take address, then clone the value and set it back
|
||||
dest_value := reflect.New(reflect_value.Type()).Elem()
|
||||
m := &Model{data: field.Value, do: s}
|
||||
for _, f := range m.columnsHasValue("other") {
|
||||
dest_value.FieldByName(f.Name).Set(reflect.ValueOf(f.Value))
|
||||
}
|
||||
id = do.setModel(dest_value.Addr().Interface()).save()
|
||||
do.setModel(dest_value.Addr().Interface()).save()
|
||||
m.setValueByColumn(field.Name, dest_value.Interface(), s.value)
|
||||
}
|
||||
|
||||
if len(field.foreignKey) > 0 {
|
||||
s.model.setValueByColumn(field.foreignKey, id, s.model.data)
|
||||
s.model.setValueByColumn(field.foreignKey, do.model.primaryKeyValue(), s.model.data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -136,11 +135,12 @@ func (s *Do) saveAfterAssociations() {
|
|||
for _, field := range s.model.afterAssociations() {
|
||||
reflect_value := reflect.ValueOf(field.Value)
|
||||
|
||||
switch reflect.TypeOf(field.Value).Kind() {
|
||||
switch reflect_value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflect_value.Len(); i++ {
|
||||
value := reflect_value.Index(i).Addr().Interface()
|
||||
do := &Do{chain: s.chain, db: s.db}
|
||||
|
||||
value := reflect_value.Index(i).Addr().Interface()
|
||||
if len(field.foreignKey) > 0 {
|
||||
s.model.setValueByColumn(field.foreignKey, s.model.primaryKeyValue(), value)
|
||||
}
|
||||
|
@ -175,8 +175,9 @@ func (s *Do) create() (i interface{}) {
|
|||
s.prepareCreateSql()
|
||||
|
||||
if !s.chain.hasError() {
|
||||
now := time.Now()
|
||||
var id interface{}
|
||||
|
||||
now := time.Now()
|
||||
if s.chain.driver() == "postgres" {
|
||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
} else {
|
||||
|
@ -188,12 +189,9 @@ func (s *Do) create() (i interface{}) {
|
|||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if !s.chain.hasError() {
|
||||
result := reflect.Indirect(reflect.ValueOf(s.value))
|
||||
if !setFieldValue(result.FieldByName(s.model.primaryKey()), id) {
|
||||
fmt.Printf("Can't set primary key for %#v\n", result.Interface())
|
||||
}
|
||||
s.saveAfterAssociations()
|
||||
s.model.setValueByColumn(s.model.primaryKey(), id, s.value)
|
||||
|
||||
s.saveAfterAssociations()
|
||||
s.model.callMethod("AfterCreate")
|
||||
s.model.callMethod("AfterSave")
|
||||
}
|
||||
|
@ -204,19 +202,17 @@ func (s *Do) create() (i interface{}) {
|
|||
}
|
||||
|
||||
func (s *Do) setUpdateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do {
|
||||
switch values.(type) {
|
||||
switch vs := values.(type) {
|
||||
case map[string]interface{}:
|
||||
s.updateAttrs = values.(map[string]interface{})
|
||||
s.updateAttrs = vs
|
||||
case []interface{}:
|
||||
for _, value := range values.([]interface{}) {
|
||||
for _, value := range vs {
|
||||
s.setUpdateAttrs(value, ignore_protected_attrs...)
|
||||
}
|
||||
case interface{}:
|
||||
m := &Model{data: values, do: s}
|
||||
fields := m.columnsHasValue("other")
|
||||
|
||||
s.updateAttrs = make(map[string]interface{}, len(fields))
|
||||
for _, field := range fields {
|
||||
s.updateAttrs = map[string]interface{}{}
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
s.updateAttrs[field.DbName] = field.Value
|
||||
}
|
||||
}
|
||||
|
@ -238,8 +234,7 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
|||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||
}
|
||||
|
||||
update_attrs := s.model.columnsAndValues("update")
|
||||
for key, value := range update_attrs {
|
||||
for key, value := range s.model.columnsAndValues("update") {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||
}
|
||||
|
||||
|
@ -252,13 +247,12 @@ func (s *Do) prepareUpdateSql(results map[string]interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Do) update() (i interface{}) {
|
||||
func (s *Do) update() *Do {
|
||||
update_attrs := s.updateAttrs
|
||||
if len(update_attrs) > 0 {
|
||||
var need_update bool
|
||||
update_attrs, need_update = s.prepareUpdateAttrs()
|
||||
if !need_update {
|
||||
return
|
||||
if update_attrs, need_update = s.prepareUpdateAttrs(); !need_update {
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -276,29 +270,22 @@ func (s *Do) update() (i interface{}) {
|
|||
s.model.callMethod("AfterSave")
|
||||
}
|
||||
|
||||
return s.model.primaryKeyValue()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareDeleteSql() {
|
||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) delete() {
|
||||
func (s *Do) delete() *Do {
|
||||
s.model.callMethod("BeforeDelete")
|
||||
|
||||
if !s.chain.hasError() {
|
||||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||
delete_sql := "deleted_at=" + s.addToVars(time.Now())
|
||||
s.sql = fmt.Sprintf("UPDATE %v SET %v %v", s.tableName(), delete_sql, s.combinedSql())
|
||||
s.exec()
|
||||
s.sql = fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", s.tableName(), s.addToVars(time.Now()), s.combinedSql())
|
||||
} else {
|
||||
s.prepareDeleteSql()
|
||||
s.exec()
|
||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
}
|
||||
s.exec()
|
||||
s.model.callMethod("AfterDelete")
|
||||
}
|
||||
return
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() {
|
||||
|
@ -379,24 +366,18 @@ func (s *Do) query() {
|
|||
now := time.Now()
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
s.chain.slog(s.sql, now, s.sqlVars...)
|
||||
|
||||
if s.err(err) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Err() != nil {
|
||||
s.err(rows.Err())
|
||||
}
|
||||
|
||||
counts := 0
|
||||
var has_record bool
|
||||
for rows.Next() {
|
||||
counts += 1
|
||||
var dest reflect.Value
|
||||
has_record = true
|
||||
dest := dest_out
|
||||
if is_slice {
|
||||
dest = reflect.New(dest_type).Elem()
|
||||
} else {
|
||||
dest = dest_out
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
|
@ -404,15 +385,10 @@ func (s *Do) query() {
|
|||
for _, value := range columns {
|
||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
if field.CanAddr() {
|
||||
values = append(values, field.Addr().Interface())
|
||||
} else {
|
||||
s.err(errors.New(fmt.Sprintf("Can't take address of %v, should be ptr", dest)))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
var null interface{}
|
||||
values = append(values, &null)
|
||||
var ignore interface{}
|
||||
values = append(values, &ignore)
|
||||
}
|
||||
}
|
||||
s.err(rows.Scan(values...))
|
||||
|
@ -422,7 +398,7 @@ func (s *Do) query() {
|
|||
}
|
||||
}
|
||||
|
||||
if (counts == 0) && !is_slice {
|
||||
if !has_record && !is_slice {
|
||||
s.err(errors.New("Record not found!"))
|
||||
}
|
||||
}
|
||||
|
@ -477,9 +453,8 @@ func (s *Do) primaryCondiation(value interface{}) string {
|
|||
|
||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
query := clause["query"]
|
||||
switch query.(type) {
|
||||
switch value := query.(type) {
|
||||
case string:
|
||||
value := query.(string)
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return s.primaryCondiation(s.addToVars(id))
|
||||
|
@ -489,18 +464,18 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
case int, int64, int32:
|
||||
return s.primaryCondiation(s.addToVars(query))
|
||||
case sql.NullInt64:
|
||||
return s.primaryCondiation(s.addToVars(query.(sql.NullInt64).Int64))
|
||||
return s.primaryCondiation(s.addToVars(value.Int64))
|
||||
case []int64, []int, []int32, []string:
|
||||
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
|
||||
clause["args"] = []interface{}{query}
|
||||
clause["args"] = []interface{}{value}
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range query.(map[string]interface{}) {
|
||||
for key, value := range value {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: query, do: s}
|
||||
m := &Model{data: value, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.DbName, s.addToVars(field.Value)))
|
||||
|
@ -532,9 +507,8 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
query := clause["query"]
|
||||
var not_equal_sql string
|
||||
|
||||
switch query.(type) {
|
||||
switch value := query.(type) {
|
||||
case string:
|
||||
value := query.(string)
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
|
||||
|
@ -556,7 +530,7 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
}
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range query.(map[string]interface{}) {
|
||||
for key, value := range value {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
|
@ -751,13 +725,13 @@ func (s *Do) initializeWithSearchCondition() {
|
|||
|
||||
for _, clause := range s.whereClause {
|
||||
query := clause["query"]
|
||||
switch query.(type) {
|
||||
switch value := query.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range query.(map[string]interface{}) {
|
||||
m.setValueByColumn(key, value, s.value)
|
||||
for k, v := range value {
|
||||
m.setValueByColumn(k, v, s.value)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, obj := range query.([]interface{}) {
|
||||
for _, obj := range value {
|
||||
switch reflect.ValueOf(obj).Kind() {
|
||||
case reflect.Struct:
|
||||
m := &Model{data: obj, do: s}
|
||||
|
|
Loading…
Reference in New Issue