From db7fc2d53ade0a09d3842b90f59403c997138c47 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Oct 2013 22:00:06 +0800 Subject: [PATCH] Refact code to make it more strong --- do.go | 26 ++++++++++++++++------- gorm_test.go | 2 +- model.go | 58 ++++++++++++++++++++++++++-------------------------- 3 files changed, 49 insertions(+), 37 deletions(-) diff --git a/do.go b/do.go index 8b700090..641abf87 100644 --- a/do.go +++ b/do.go @@ -131,7 +131,10 @@ func (s *Do) create() { if !s.hasError() { result := reflect.ValueOf(s.value).Elem() - result.FieldByName(s.model.primaryKey()).SetInt(id) + primary_key := result.FieldByName(s.model.primaryKey()) + if primary_key.IsValid() { + primary_key.SetInt(id) + } s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterSave")) @@ -238,7 +241,7 @@ func (s *Do) query() { if is_slice { dest = reflect.New(dest_type).Elem() } else { - dest = reflect.ValueOf(s.value).Elem() + dest = dest_out } columns, _ := rows.Columns() @@ -279,7 +282,7 @@ func (s *Do) count(value interface{}) { for rows.Next() { var dest int64 if s.err(rows.Scan(&dest)) == nil { - dest_out.Set(reflect.ValueOf(dest)) + dest_out.SetInt(dest) } } } @@ -289,7 +292,13 @@ func (s *Do) count(value interface{}) { func (s *Do) pluck(column string, value interface{}) { s.selectStr = column dest_out := reflect.Indirect(reflect.ValueOf(value)) + + if dest_out.Kind() != reflect.Slice { + s.err(errors.New("Return results should be a slice")) + return + } dest_type := dest_out.Type().Elem() + s.prepareQuerySql() if !s.hasError() { @@ -302,6 +311,7 @@ func (s *Do) pluck(column string, value interface{}) { for rows.Next() { dest := reflect.New(dest_type).Elem().Interface() s.err(rows.Scan(&dest)) + switch dest.(type) { case []uint8: if dest_type.String() == "string" { @@ -362,10 +372,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { for _, arg := range args { switch reflect.TypeOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - v := reflect.ValueOf(arg) + values := reflect.ValueOf(arg) var temp_marks []string - for i := 0; i < v.Len(); i++ { - temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface())) + for i := 0; i < values.Len(); i++ { + temp_marks = append(temp_marks, s.addToVars(values.Index(i).Addr().Interface())) } str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) default: @@ -381,6 +391,7 @@ func (s *Do) whereSql() (sql string) { if !s.unscoped && s.model.hasColumn("DeletedAt") { primary_condiations = append(primary_condiations, "(deleted_at is null or deleted_at <= '0001-01-02')") } + if !s.model.primaryKeyZero() { primary_condiations = append(primary_condiations, s.primaryCondiation(s.addToVars(s.model.primaryKeyValue()))) } @@ -453,9 +464,10 @@ func (s *Do) combinedSql() string { func (s *Do) createTable() *Do { var sqls []string - for _, field := range s.model.fields("null") { + for _, field := range s.model.fields("") { sqls = append(sqls, field.DbName+" "+field.SqlType) } + s.sql = fmt.Sprintf( "CREATE TABLE \"%v\" (%v)", s.tableName(), diff --git a/gorm_test.go b/gorm_test.go index ee903c4e..9275c7a0 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -48,7 +48,7 @@ func init() { panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) } db.SetPool(10) - db.DebugMode = true + // db.DebugMode = true err = db.Exec("drop table users;").Error if err != nil { diff --git a/model.go b/model.go index 4d200626..facb9671 100644 --- a/model.go +++ b/model.go @@ -34,15 +34,21 @@ func (m *Model) primaryKeyValue() int64 { return -1 } - t := reflect.TypeOf(m.data).Elem() - switch t.Kind() { + data := reflect.ValueOf(m.data).Elem() + + switch data.Kind() { case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: return 0 default: - result := reflect.ValueOf(m.data).Elem() - value := result.FieldByName(m.primaryKey()) + value := data.FieldByName(m.primaryKey()) + if value.IsValid() { - return value.Interface().(int64) + switch value.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + return value.Int() + default: + return 0 + } } else { return 0 } @@ -83,17 +89,19 @@ func (m *Model) fields(operation string) (fields []Field) { } } - switch operation { - case "create": - if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() { - value.Set(reflect.ValueOf(time.Now())) + if v, ok := value.Interface().(time.Time); ok { + switch operation { + case "create": + if (field.AutoCreateTime || field.AutoUpdateTime) && v.IsZero() { + value.Set(reflect.ValueOf(time.Now())) + } + case "update": + if field.AutoUpdateTime { + value.Set(reflect.ValueOf(time.Now())) + } } - case "update": - if field.AutoUpdateTime { - value.Set(reflect.ValueOf(time.Now())) - } - default: } + field.Value = value.Interface() if field.IsPrimaryKey { @@ -134,7 +142,7 @@ func (m *Model) hasColumn(name string) bool { data := reflect.Indirect(reflect.ValueOf(m.data)) if data.Kind() == reflect.Slice { - return false //FIXME data.Elem().FieldByName(name).IsValid() + return reflect.New(data.Type().Elem()).Elem().FieldByName(name).IsValid() } else { return data.FieldByName(name).IsValid() } @@ -146,21 +154,12 @@ func (m *Model) tableName() (str string, err error) { return } - t := reflect.TypeOf(m.data) - for { - c := false - switch t.Kind() { - case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: - t = t.Elem() - c = true - } - if !c { - break - } + typ := reflect.Indirect(reflect.ValueOf(m.data)).Type() + if typ.Kind() == reflect.Slice { + typ = typ.Elem() } - str = toSnake(t.Name()) - + str = toSnake(typ.Name()) pluralMap := map[string]string{"ch": "ches", "ss": "sses", "sh": "shes", "day": "days", "y": "ies", "x": "xes", "s?": "s"} for key, value := range pluralMap { reg := regexp.MustCompile(key + "$") @@ -200,7 +199,8 @@ func (m *Model) missingColumns() (results []string) { } func (m *Model) setValueByColumn(name string, value interface{}, out interface{}) { - data := reflect.ValueOf(out).Elem() + data := reflect.Indirect(reflect.ValueOf(out)) + field := data.FieldByName(snakeToUpperCamel(name)) if field.IsValid() { field.Set(reflect.ValueOf(value))