mirror of https://github.com/go-gorm/gorm.git
Refact code to make it more strong
This commit is contained in:
parent
e250637d12
commit
db7fc2d53a
26
do.go
26
do.go
|
@ -131,7 +131,10 @@ func (s *Do) create() {
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
result := reflect.ValueOf(s.value).Elem()
|
result := reflect.ValueOf(s.value).Elem()
|
||||||
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
primary_key := result.FieldByName(s.model.primaryKey())
|
||||||
|
if primary_key.IsValid() {
|
||||||
|
primary_key.SetInt(id)
|
||||||
|
}
|
||||||
|
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
|
@ -238,7 +241,7 @@ func (s *Do) query() {
|
||||||
if is_slice {
|
if is_slice {
|
||||||
dest = reflect.New(dest_type).Elem()
|
dest = reflect.New(dest_type).Elem()
|
||||||
} else {
|
} else {
|
||||||
dest = reflect.ValueOf(s.value).Elem()
|
dest = dest_out
|
||||||
}
|
}
|
||||||
|
|
||||||
columns, _ := rows.Columns()
|
columns, _ := rows.Columns()
|
||||||
|
@ -279,7 +282,7 @@ func (s *Do) count(value interface{}) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dest int64
|
var dest int64
|
||||||
if s.err(rows.Scan(&dest)) == nil {
|
if s.err(rows.Scan(&dest)) == nil {
|
||||||
dest_out.Set(reflect.ValueOf(dest))
|
dest_out.SetInt(dest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -289,7 +292,13 @@ func (s *Do) count(value interface{}) {
|
||||||
func (s *Do) pluck(column string, value interface{}) {
|
func (s *Do) pluck(column string, value interface{}) {
|
||||||
s.selectStr = column
|
s.selectStr = column
|
||||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
|
||||||
|
if dest_out.Kind() != reflect.Slice {
|
||||||
|
s.err(errors.New("Return results should be a slice"))
|
||||||
|
return
|
||||||
|
}
|
||||||
dest_type := dest_out.Type().Elem()
|
dest_type := dest_out.Type().Elem()
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
|
|
||||||
if !s.hasError() {
|
if !s.hasError() {
|
||||||
|
@ -302,6 +311,7 @@ func (s *Do) pluck(column string, value interface{}) {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
dest := reflect.New(dest_type).Elem().Interface()
|
dest := reflect.New(dest_type).Elem().Interface()
|
||||||
s.err(rows.Scan(&dest))
|
s.err(rows.Scan(&dest))
|
||||||
|
|
||||||
switch dest.(type) {
|
switch dest.(type) {
|
||||||
case []uint8:
|
case []uint8:
|
||||||
if dest_type.String() == "string" {
|
if dest_type.String() == "string" {
|
||||||
|
@ -362,10 +372,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
for _, arg := range args {
|
for _, arg := range args {
|
||||||
switch reflect.TypeOf(arg).Kind() {
|
switch reflect.TypeOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
v := reflect.ValueOf(arg)
|
values := reflect.ValueOf(arg)
|
||||||
var temp_marks []string
|
var temp_marks []string
|
||||||
for i := 0; i < v.Len(); i++ {
|
for i := 0; i < values.Len(); i++ {
|
||||||
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
|
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
|
||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
|
@ -381,6 +391,7 @@ func (s *Do) whereSql() (sql string) {
|
||||||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||||
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
|
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.model.primaryKeyZero() {
|
if !s.model.primaryKeyZero() {
|
||||||
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
|
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
|
||||||
}
|
}
|
||||||
|
@ -453,9 +464,10 @@ func (s *Do) combinedSql() string {
|
||||||
|
|
||||||
func (s *Do) createTable() *Do {
|
func (s *Do) createTable() *Do {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range s.model.fields("null") {
|
for _, field := range s.model.fields("") {
|
||||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"CREATE TABLE \"%v\" (%v)",
|
"CREATE TABLE \"%v\" (%v)",
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
|
|
|
@ -48,7 +48,7 @@ func init() {
|
||||||
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
|
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
|
||||||
}
|
}
|
||||||
db.SetPool(10)
|
db.SetPool(10)
|
||||||
db.DebugMode = true
|
// db.DebugMode = true
|
||||||
|
|
||||||
err = db.Exec("drop table users;").Error
|
err = db.Exec("drop table users;").Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
44
model.go
44
model.go
|
@ -34,15 +34,21 @@ func (m *Model) primaryKeyValue() int64 {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
t := reflect.TypeOf(m.data).Elem()
|
data := reflect.ValueOf(m.data).Elem()
|
||||||
switch t.Kind() {
|
|
||||||
|
switch data.Kind() {
|
||||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||||
return 0
|
return 0
|
||||||
default:
|
default:
|
||||||
result := reflect.ValueOf(m.data).Elem()
|
value := data.FieldByName(m.primaryKey())
|
||||||
value := result.FieldByName(m.primaryKey())
|
|
||||||
if value.IsValid() {
|
if value.IsValid() {
|
||||||
return value.Interface().(int64)
|
switch value.Kind() {
|
||||||
|
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||||
|
return value.Int()
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
@ -83,17 +89,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v, ok := value.Interface().(time.Time); ok {
|
||||||
switch operation {
|
switch operation {
|
||||||
case "create":
|
case "create":
|
||||||
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
|
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
case "update":
|
case "update":
|
||||||
if field.AutoUpdateTime {
|
if field.AutoUpdateTime {
|
||||||
value.Set(reflect.ValueOf(time.Now()))
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
field.Value = value.Interface()
|
field.Value = value.Interface()
|
||||||
|
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
|
@ -134,7 +142,7 @@ func (m *Model) hasColumn(name string) bool {
|
||||||
|
|
||||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||||
if data.Kind() == reflect.Slice {
|
if data.Kind() == reflect.Slice {
|
||||||
return false //FIXME data.Elem().FieldByName(name).IsValid()
|
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||||
} else {
|
} else {
|
||||||
return data.FieldByName(name).IsValid()
|
return data.FieldByName(name).IsValid()
|
||||||
}
|
}
|
||||||
|
@ -146,21 +154,12 @@ func (m *Model) tableName() (str string, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
t := reflect.TypeOf(m.data)
|
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
|
||||||
for {
|
if typ.Kind() == reflect.Slice {
|
||||||
c := false
|
typ = typ.Elem()
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
|
||||||
t = t.Elem()
|
|
||||||
c = true
|
|
||||||
}
|
|
||||||
if !c {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
str = toSnake(t.Name())
|
str = toSnake(typ.Name())
|
||||||
|
|
||||||
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
||||||
for key, value := range pluralMap {
|
for key, value := range pluralMap {
|
||||||
reg := regexp.MustCompile(key + "$")
|
reg := regexp.MustCompile(key + "$")
|
||||||
|
@ -200,7 +199,8 @@ func (m *Model) missingColumns() (results []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
|
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
|
||||||
data := reflect.ValueOf(out).Elem()
|
data := reflect.Indirect(reflect.ValueOf(out))
|
||||||
|
|
||||||
field := data.FieldByName(snakeToUpperCamel(name))
|
field := data.FieldByName(snakeToUpperCamel(name))
|
||||||
if field.IsValid() {
|
if field.IsValid() {
|
||||||
field.Set(reflect.ValueOf(value))
|
field.Set(reflect.ValueOf(value))
|
||||||
|
|
Loading…
Reference in New Issue