Refact code to make it more strong

This commit is contained in:
Jinzhu 2013-10-29 22:00:06 +08:00
parent e250637d12
commit db7fc2d53a
3 changed files with 49 additions and 37 deletions

26
do.go
View File

@ -131,7 +131,10 @@ func (s *Do) create() {
if !s.hasError() {
result := reflect.ValueOf(s.value).Elem()
result.FieldByName(s.model.primaryKey()).SetInt(id)
primary_key := result.FieldByName(s.model.primaryKey())
if primary_key.IsValid() {
primary_key.SetInt(id)
}
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
@ -238,7 +241,7 @@ func (s *Do) query() {
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = reflect.ValueOf(s.value).Elem()
dest = dest_out
}
columns, _ := rows.Columns()
@ -279,7 +282,7 @@ func (s *Do) count(value interface{}) {
for rows.Next() {
var dest int64
if s.err(rows.Scan(&dest)) == nil {
dest_out.Set(reflect.ValueOf(dest))
dest_out.SetInt(dest)
}
}
}
@ -289,7 +292,13 @@ func (s *Do) count(value interface{}) {
func (s *Do) pluck(column string, value interface{}) {
s.selectStr = column
dest_out := reflect.Indirect(reflect.ValueOf(value))
if dest_out.Kind() != reflect.Slice {
s.err(errors.New("Return results should be a slice"))
return
}
dest_type := dest_out.Type().Elem()
s.prepareQuerySql()
if !s.hasError() {
@ -302,6 +311,7 @@ func (s *Do) pluck(column string, value interface{}) {
for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface()
s.err(rows.Scan(&dest))
switch dest.(type) {
case []uint8:
if dest_type.String() == "string" {
@ -362,10 +372,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
for _, arg := range args {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
v := reflect.ValueOf(arg)
values := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < v.Len(); i++ {
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
for i := 0; i < values.Len(); i++ {
temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface()))
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default:
@ -381,6 +391,7 @@ func (s *Do) whereSql() (sql string) {
if !s.unscoped && s.model.hasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')")
}
if !s.model.primaryKeyZero() {
primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue())))
}
@ -453,9 +464,10 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do {
var sqls []string
for _, field := range s.model.fields("null") {
for _, field := range s.model.fields("") {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
s.sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)",
s.tableName(),

View File

@ -48,7 +48,7 @@ func init() {
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
}
db.SetPool(10)
db.DebugMode = true
// db.DebugMode = true
err = db.Exec("drop table users;").Error
if err != nil {

View File

@ -34,15 +34,21 @@ func (m *Model) primaryKeyValue() int64 {
return -1
}
t := reflect.TypeOf(m.data).Elem()
switch t.Kind() {
data := reflect.ValueOf(m.data).Elem()
switch data.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
return 0
default:
result := reflect.ValueOf(m.data).Elem()
value := result.FieldByName(m.primaryKey())
value := data.FieldByName(m.primaryKey())
if value.IsValid() {
return value.Interface().(int64)
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int()
default:
return 0
}
} else {
return 0
}
@ -83,17 +89,19 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
if v, ok := value.Interface().(time.Time); ok {
switch operation {
case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() {
value.Set(reflect.ValueOf(time.Now()))
}
case "update":
if field.AutoUpdateTime {
value.Set(reflect.ValueOf(time.Now()))
}
default:
}
}
field.Value = value.Interface()
if field.IsPrimaryKey {
@ -134,7 +142,7 @@ func (m *Model) hasColumn(name string) bool {
data := reflect.Indirect(reflect.ValueOf(m.data))
if data.Kind() == reflect.Slice {
return false //FIXME data.Elem().FieldByName(name).IsValid()
return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid()
} else {
return data.FieldByName(name).IsValid()
}
@ -146,21 +154,12 @@ func (m *Model) tableName() (str string, err error) {
return
}
t := reflect.TypeOf(m.data)
for {
c := false
switch t.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
t = t.Elem()
c = true
}
if !c {
break
}
typ := reflect.Indirect(reflect.ValueOf(m.data)).Type()
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
str = toSnake(t.Name())
str = toSnake(typ.Name())
pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"}
for key, value := range pluralMap {
reg := regexp.MustCompile(key + "$")
@ -200,7 +199,8 @@ func (m *Model) missingColumns() (results []string) {
}
func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) {
data := reflect.ValueOf(out).Elem()
data := reflect.Indirect(reflect.ValueOf(out))
field := data.FieldByName(snakeToUpperCamel(name))
if field.IsValid() {
field.Set(reflect.ValueOf(value))