mirror of https://github.com/go-gorm/gorm.git
Refact code to make it more strong
This commit is contained in:
parent
e250637d12
commit
db7fc2d53a
26
do.go
26
do.go
|
@ -131,7 +131,10 @@ func (s *Do) create() {
|
|||
|
||||
if !s.hasError() {
|
||||
result := reflect.ValueOf(s.value).Elem()
|
||||
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
||||
primary_key := result.FieldByName(s.model.primaryKey())
|
||||
if primary_key.IsValid() {
|
||||
primary_key.SetInt(id)
|
||||
}
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
|
@ -238,7 +241,7 @@ func (s *Do) query() {
|
|||
if is_slice {
|
||||
dest = reflect.New(dest_type).Elem()
|
||||
} else {
|
||||
dest = reflect.ValueOf(s.value).Elem()
|
||||
dest = dest_out
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
|
@ -279,7 +282,7 @@ func (s *Do) count(value interface{}) {
|
|||
for rows.Next() {
|
||||
var dest int64
|
||||
if s.err(rows.Scan(&dest)) == nil {
|
||||
dest_out.Set(reflect.ValueOf(dest))
|
||||
dest_out.SetInt(dest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -289,7 +292,13 @@ func (s *Do) count(value interface{}) {
|
|||
func (s *Do) pluck(column string, value interface{}) {
|
||||
s.selectStr = column
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
if dest_out.Kind() != reflect.Slice {
|
||||
s.err(errors.New("Return results should be a slice"))
|
||||
return
|
||||
}
|
||||
dest_type := dest_out.Type().Elem()
|
||||
|
||||
s.prepareQuerySql()
|
||||
|
||||
if !s.hasError() {
|
||||
|
@ -302,6 +311,7 @@ func (s *Do) pluck(column string, value interface{}) {
|
|||
for rows.Next() {
|
||||
dest := reflect.New(dest_type).Elem().Interface()
|
||||
s.err(rows.Scan(&dest))
|
||||
|
||||
switch dest.(type) {
|
||||
case []uint8:
|
||||
if dest_type.String() == "string" {
|
||||
|
@ -362,10 +372,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
for _, arg := range args {
|
||||
switch reflect.TypeOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
v := reflect.ValueOf(arg)
|
||||
values := reflect.ValueOf(arg)
|
||||
var temp_marks []string
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
|
||||
}
|
||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||
default:
|
||||
|
@ -381,6 +391,7 @@ func (s *Do) whereSql() (sql string) {
|
|||
if !s.unscoped && s.model.hasColumn("DeletedAt") {
|
||||
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
|
||||
}
|
||||
|
||||
if !s.model.primaryKeyZero() {
|
||||
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
|
||||
}
|
||||
|
@ -453,9 +464,10 @@ func (s *Do) combinedSql() string {
|
|||
|
||||
func (s *Do) createTable() *Do {
|
||||
var sqls []string
|
||||
for _, field := range s.model.fields("null") {
|
||||
for _, field := range s.model.fields("") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
|
||||
s.sql = fmt.Sprintf(
|
||||
"CREATE TABLE \"%v\" (%v)",
|
||||
s.tableName(),
|
||||
|
|
|
@ -48,7 +48,7 @@ func init() {
|
|||
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
|
||||
}
|
||||
db.SetPool(10)
|
||||
db.DebugMode = true
|
||||
// db.DebugMode = true
|
||||
|
||||
err = db.Exec("drop table users;").Error
|
||||
if err != nil {
|
||||
|
|
44
model.go
44
model.go
|
@ -34,15 +34,21 @@ func (m *Model) primaryKeyValue() int64 {
|
|||
return -1
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(m.data).Elem()
|
||||
switch t.Kind() {
|
||||
data := reflect.ValueOf(m.data).Elem()
|
||||
|
||||
switch data.Kind() {
|
||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
return 0
|
||||
default:
|
||||
result := reflect.ValueOf(m.data).Elem()
|
||||
value := result.FieldByName(m.primaryKey())
|
||||
value := data.FieldByName(m.primaryKey())
|
||||
|
||||
if value.IsValid() {
|
||||
return value.Interface().(int64)
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Int32:
|
||||
return value.Int()
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
|
@ -83,17 +89,19 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
}
|
||||
|
||||
if v, ok := value.Interface().(time.Time); ok {
|
||||
switch operation {
|
||||
case "create":
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
case "update":
|
||||
if field.AutoUpdateTime {
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
field.Value = value.Interface()
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
|
@ -134,7 +142,7 @@ func (m *Model) hasColumn(name string) bool {
|
|||
|
||||
data := reflect.Indirect(reflect.ValueOf(m.data))
|
||||
if data.Kind() == reflect.Slice {
|
||||
return false //FIXME data.Elem().FieldByName(name).IsValid()
|
||||
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
|
||||
} else {
|
||||
return data.FieldByName(name).IsValid()
|
||||
}
|
||||
|
@ -146,21 +154,12 @@ func (m *Model) tableName() (str string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(m.data)
|
||||
for {
|
||||
c := false
|
||||
switch t.Kind() {
|
||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
t = t.Elem()
|
||||
c = true
|
||||
}
|
||||
if !c {
|
||||
break
|
||||
}
|
||||
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
str = toSnake(t.Name())
|
||||
|
||||
str = toSnake(typ.Name())
|
||||
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
|
||||
for key, value := range pluralMap {
|
||||
reg := regexp.MustCompile(key + "$")
|
||||
|
@ -200,7 +199,8 @@ func (m *Model) missingColumns() (results []string) {
|
|||
}
|
||||
|
||||
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
|
||||
data := reflect.ValueOf(out).Elem()
|
||||
data := reflect.Indirect(reflect.ValueOf(out))
|
||||
|
||||
field := data.FieldByName(snakeToUpperCamel(name))
|
||||
if field.IsValid() {
|
||||
field.Set(reflect.ValueOf(value))
|
||||
|
|
Loading…
Reference in New Issue