Refact code to make it more strong

This commit is contained in:
Jinzhu 2013-10-29 22:00:06 +08:00
parent e250637d12
commit db7fc2d53a
3 changed files with 49 additions and 37 deletions

26
do.go
View File

@ -131,7 +131,10 @@ func (s *Do) create() {
if !s.hasError() { if !s.hasError() {
result := reflect.ValueOf(s.value).Elem() result := reflect.ValueOf(s.value).Elem()
result.FieldByName(s.model.primaryKey()).SetInt(id) primary_key := result.FieldByName(s.model.primaryKey())
if primary_key.IsValid() {
primary_key.SetInt(id)
}
s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
@ -238,7 +241,7 @@ func (s *Do) query() {
if is_slice { if is_slice {
dest = reflect.New(dest_type).Elem() dest = reflect.New(dest_type).Elem()
} else { } else {
dest = reflect.ValueOf(s.value).Elem() dest = dest_out
} }
columns, _ := rows.Columns() columns, _ := rows.Columns()
@ -279,7 +282,7 @@ func (s *Do) count(value interface{}) {
for rows.Next() { for rows.Next() {
var dest int64 var dest int64
if s.err(rows.Scan(&dest)) == nil { if s.err(rows.Scan(&dest)) == nil {
dest_out.Set(reflect.ValueOf(dest)) dest_out.SetInt(dest)
} }
} }
} }
@ -289,7 +292,13 @@ func (s *Do) count(value interface{}) {
func (s *Do) pluck(column string, value interface{}) { func (s *Do) pluck(column string, value interface{}) {
s.selectStr = column s.selectStr = column
dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_out := reflect.Indirect(reflect.ValueOf(value))
if dest_out.Kind() != reflect.Slice {
s.err(errors.New("Return results should be a slice"))
return
}
dest_type := dest_out.Type().Elem() dest_type := dest_out.Type().Elem()
s.prepareQuerySql() s.prepareQuerySql()
if !s.hasError() { if !s.hasError() {
@ -302,6 +311,7 @@ func (s *Do) pluck(column string, value interface{}) {
for rows.Next() { for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface() dest := reflect.New(dest_type).Elem().Interface()
s.err(rows.Scan(&dest)) s.err(rows.Scan(&dest))
switch dest.(type) { switch dest.(type) {
case []uint8: case []uint8:
if dest_type.String() == "string" { if dest_type.String() == "string" {
@ -362,10 +372,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
for _, arg := range args { for _, arg := range args {
switch reflect.TypeOf(arg).Kind() { switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2}) case reflect.Slice: // For where("id in (?)", []int64{1,2})
v := reflect.ValueOf(arg) values := reflect.ValueOf(arg)
var temp_marks []string var temp_marks []string
for i := 0; i < v.Len(); i++ { for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface())) temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default: default:
@ -381,6 +391,7 @@ func (s *Do) whereSql() (sql string) {
if !s.unscoped && s.model.hasColumn("DeletedAt") { if !s.unscoped && s.model.hasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')") primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
} }
if !s.model.primaryKeyZero() { if !s.model.primaryKeyZero() {
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue()))) primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
} }
@ -453,9 +464,10 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do { func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("null") { for _, field := range s.model.fields("") {
sqls = append(sqls, field.DbName+" "+field.SqlType) sqls = append(sqls, field.DbName+" "+field.SqlType)
} }
s.sql = fmt.Sprintf( s.sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)", "CREATE TABLE \"%v\" (%v)",
s.tableName(), s.tableName(),

View File

@ -48,7 +48,7 @@ func init() {
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
} }
db.SetPool(10) db.SetPool(10)
db.DebugMode = true // db.DebugMode = true
err = db.Exec("drop table users;").Error err = db.Exec("drop table users;").Error
if err != nil { if err != nil {

View File

@ -34,15 +34,21 @@ func (m *Model) primaryKeyValue() int64 {
return -1 return -1
} }
t := reflect.TypeOf(m.data).Elem() data := reflect.ValueOf(m.data).Elem()
switch t.Kind() {
switch data.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
return 0 return 0
default: default:
result := reflect.ValueOf(m.data).Elem() value := data.FieldByName(m.primaryKey())
value := result.FieldByName(m.primaryKey())
if value.IsValid() { if value.IsValid() {
return value.Interface().(int64) switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int()
default:
return 0
}
} else { } else {
return 0 return 0
} }
@ -83,17 +89,19 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
} }
if v, ok := value.Interface().(time.Time); ok {
switch operation { switch operation {
case "create": case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() { if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
case "update": case "update":
if field.AutoUpdateTime { if field.AutoUpdateTime {
value.Set(reflect.ValueOf(time.Now())) value.Set(reflect.ValueOf(time.Now()))
} }
default:
} }
}
field.Value = value.Interface() field.Value = value.Interface()
if field.IsPrimaryKey { if field.IsPrimaryKey {
@ -134,7 +142,7 @@ func (m *Model) hasColumn(name string) bool {
data := reflect.Indirect(reflect.ValueOf(m.data)) data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice { if data.Kind() == reflect.Slice {
return false //FIXME data.Elem().FieldByName(name).IsValid() return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
} else { } else {
return data.FieldByName(name).IsValid() return data.FieldByName(name).IsValid()
} }
@ -146,21 +154,12 @@ func (m *Model) tableName() (str string, err error) {
return return
} }
t := reflect.TypeOf(m.data) typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
for { if typ.Kind() == reflect.Slice {
c := false typ = typ.Elem()
switch t.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
t = t.Elem()
c = true
}
if !c {
break
}
} }
str = toSnake(t.Name()) str = toSnake(typ.Name())
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
for key, value := range pluralMap { for key, value := range pluralMap {
reg := regexp.MustCompile(key + "$") reg := regexp.MustCompile(key + "$")
@ -200,7 +199,8 @@ func (m *Model) missingColumns() (results []string) {
} }
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
data := reflect.ValueOf(out).Elem() data := reflect.Indirect(reflect.ValueOf(out))
field := data.FieldByName(snakeToUpperCamel(name)) field := data.FieldByName(snakeToUpperCamel(name))
if field.IsValid() { if field.IsValid() {
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))