From bc785a91732e9fd89b2051a1fba64f288a276304 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Oct 2013 07:39:26 +0800 Subject: [PATCH] Cleanup code --- README.md | 5 +- chain.go | 15 ++- do.go | 262 ++++++++++++++++++++++++++------------------------- gorm_test.go | 15 ++- main.go | 28 +++--- model.go | 40 ++++---- utils.go | 11 --- 7 files changed, 189 insertions(+), 187 deletions(-) diff --git a/README.md b/README.md index 1f9ab8f4..01ee3707 100644 --- a/README.md +++ b/README.md @@ -217,10 +217,11 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us ## TODO * Soft Delete * Query with map or struct +* SubStruct +* Index, Unique, Valiations +* Auto Migration * FindOrInitialize / FindOrCreate * SQL Log -* Auto Migration -* Index, Unique, Valiations * SQL Query with goroutines * Only tested with postgres, confirm works with other database adaptors diff --git a/chain.go b/chain.go index 59bb33e3..9e184334 100644 --- a/chain.go +++ b/chain.go @@ -26,11 +26,12 @@ type Chain struct { specifiedTableName string } -func (s *Chain) err(err error) { +func (s *Chain) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) s.Error = err } + return err } func (s *Chain) do(value interface{}) *Do { @@ -136,18 +137,16 @@ func (s *Chain) Update(column string, value interface{}) *Chain { return s.Updates(map[string]interface{}{column: value}, true) } -func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain { +func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...bool) *Chain { do := s.do(s.value) do.updateAttrs = values - do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 + do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0] do.update() return s } func (s *Chain) Exec(sql string) *Chain { - var err error - _, err = s.db.Exec(sql) - s.err(err) + s.do(nil).exec(sql) return s } @@ -164,9 +163,7 @@ func (s *Chain) Find(out interface{}, where ...interface{}) *Chain { } func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) { - do := s.do(s.value) - do.selectStr = column - do.pluck(value) + s.do(s.value).pluck(column, value) return s } diff --git a/do.go b/do.go index 21b37bed..9dcc2f3c 100644 --- a/do.go +++ b/do.go @@ -21,10 +21,9 @@ type Do struct { model *Model value interface{} - SqlResult sql.Result - - Sql string - SqlVars []interface{} + sqlResult sql.Result + sql string + sqlVars []interface{} whereClause []map[string]interface{} orClause []map[string]interface{} @@ -32,7 +31,6 @@ type Do struct { orderStrs []string offsetStr string limitStr string - operation string updateAttrs map[string]interface{} ignoreProtectedAttrs bool @@ -40,17 +38,21 @@ type Do struct { func (s *Do) tableName() string { if s.specifiedTableName == "" { + var err error + s.guessedTableName, err = s.model.tableName() + s.err(err) return s.guessedTableName } else { return s.specifiedTableName } } -func (s *Do) err(err error) { +func (s *Do) err(err error) error { if err != nil { s.Errors = append(s.Errors, err) s.chain.err(err) } + return err } func (s *Do) hasError() bool { @@ -58,18 +60,13 @@ func (s *Do) hasError() bool { } func (s *Do) setModel(value interface{}) { + s.model = &Model{data: value, driver: s.driver} s.value = value - s.model = &Model{Data: value, driver: s.driver} - var err error - if s.specifiedTableName == "" { - s.guessedTableName, err = s.model.tableName() - s.err(err) - } } func (s *Do) addToVars(value interface{}) string { - s.SqlVars = append(s.SqlVars, value) - return fmt.Sprintf("$%d", len(s.SqlVars)) + s.sqlVars = append(s.sqlVars, value) + return fmt.Sprintf("$%d", len(s.sqlVars)) } func (s *Do) exec(sql ...string) { @@ -79,23 +76,23 @@ func (s *Do) exec(sql ...string) { var err error if len(sql) == 0 { - s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) } else { - s.SqlResult, err = s.db.Exec(sql[0]) + s.sqlResult, err = s.db.Exec(sql[0]) } s.err(err) } -func (s *Do) save() *Do { +func (s *Do) save() { if s.model.primaryKeyZero() { s.create() } else { s.update() } - return s + return } -func (s *Do) prepareCreateSql() *Do { +func (s *Do) prepareCreateSql() { var sqls, columns []string for key, value := range s.model.columnsAndValues("create") { @@ -103,44 +100,47 @@ func (s *Do) prepareCreateSql() *Do { sqls = append(sqls, s.addToVars(value)) } - s.Sql = fmt.Sprintf( + s.sql = fmt.Sprintf( "INSERT INTO \"%v\" (%v) VALUES (%v) %v", s.tableName(), - strings.Join(s.quoteMap(columns), ","), + strings.Join(columns, ","), strings.Join(sqls, ","), s.model.returningStr(), ) - return s + return } -func (s *Do) create() *Do { +func (s *Do) create() { s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) s.prepareCreateSql() - if len(s.Errors) == 0 { + if !s.hasError() { var id int64 if s.driver == "postgres" { - s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) + s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id)) } else { var err error - s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) + s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...) s.err(err) - id, err = s.SqlResult.LastInsertId() + id, err = s.sqlResult.LastInsertId() s.err(err) } - result := reflect.ValueOf(s.model.Data).Elem() - result.FieldByName(s.model.primaryKey()).SetInt(id) - s.err(s.model.callMethod("AfterCreate")) - s.err(s.model.callMethod("AfterSave")) + if !s.hasError() { + result := reflect.ValueOf(s.value).Elem() + result.FieldByName(s.model.primaryKey()).SetInt(id) + + s.err(s.model.callMethod("AfterCreate")) + s.err(s.model.callMethod("AfterSave")) + } } - return s + return } -func (s *Do) prepareUpdateSql() *Do { +func (s *Do) prepareUpdateSql() { update_attrs := s.updateAttrs if len(update_attrs) == 0 { update_attrs = s.model.columnsAndValues("update") @@ -148,46 +148,55 @@ func (s *Do) prepareUpdateSql() *Do { var sqls []string for key, value := range update_attrs { - sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) } - s.Sql = fmt.Sprintf( + s.sql = fmt.Sprintf( "UPDATE %v SET %v %v", s.tableName(), strings.Join(sqls, ", "), s.combinedSql(), ) - return s + return } -func (s *Do) update() *Do { +func (s *Do) update() { s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) - if len(s.Errors) == 0 { - s.prepareUpdateSql().exec() + + s.prepareUpdateSql() + if !s.hasError() { + s.exec() + + if !s.hasError() { + s.err(s.model.callMethod("AfterUpdate")) + s.err(s.model.callMethod("AfterSave")) + } } - s.err(s.model.callMethod("AfterUpdate")) - s.err(s.model.callMethod("AfterSave")) - return s + return } -func (s *Do) prepareDeleteSql() *Do { - s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) - return s +func (s *Do) prepareDeleteSql() { + s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql()) + return } -func (s *Do) delete() *Do { +func (s *Do) delete() { s.err(s.model.callMethod("BeforeDelete")) - if len(s.Errors) == 0 { - s.prepareDeleteSql().exec() + + s.prepareDeleteSql() + if !s.hasError() { + s.exec() + if !s.hasError() { + s.err(s.model.callMethod("AfterDelete")) + } } - s.err(s.model.callMethod("AfterDelete")) - return s + return } -func (s *Do) prepareQuerySql() *Do { - s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql()) - return s +func (s *Do) prepareQuerySql() { + s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql()) + return } func (s *Do) query(where ...interface{}) { @@ -207,102 +216,107 @@ func (s *Do) query(where ...interface{}) { } s.prepareQuerySql() - - rows, err := s.db.Query(s.Sql, s.SqlVars...) - s.err(err) - - if err != nil { - return - } - - defer rows.Close() - - if rows.Err() != nil { - s.err(rows.Err()) - } - - counts := 0 - for rows.Next() { - counts += 1 - var dest reflect.Value - if is_slice { - dest = reflect.New(dest_type).Elem() - } else { - dest = reflect.ValueOf(s.value).Elem() + if !s.hasError() { + rows, err := s.db.Query(s.sql, s.sqlVars...) + if s.err(err) != nil { + return } - columns, _ := rows.Columns() - var values []interface{} - for _, value := range columns { - field := dest.FieldByName(snakeToUpperCamel(value)) - if field.IsValid() { - values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) + defer rows.Close() + + if rows.Err() != nil { + s.err(rows.Err()) + } + + counts := 0 + for rows.Next() { + counts += 1 + var dest reflect.Value + if is_slice { + dest = reflect.New(dest_type).Elem() + } else { + dest = reflect.ValueOf(s.value).Elem() + } + + columns, _ := rows.Columns() + var values []interface{} + for _, value := range columns { + field := dest.FieldByName(snakeToUpperCamel(value)) + if field.IsValid() { + values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface()) + } + } + s.err(rows.Scan(values...)) + + if is_slice { + dest_out.Set(reflect.Append(dest_out, dest)) } } - s.err(rows.Scan(values...)) - if is_slice { - dest_out.Set(reflect.Append(dest_out, dest)) + if (counts == 0) && !is_slice { + s.err(errors.New("Record not found!")) } } - - if (counts == 0) && !is_slice { - s.err(errors.New("Record not found!")) - } } func (s *Do) count(value interface{}) { dest_out := reflect.Indirect(reflect.ValueOf(value)) s.prepareQuerySql() - rows, err := s.db.Query(s.Sql, s.SqlVars...) - s.err(err) - for rows.Next() { - var dest int64 - s.err(rows.Scan(&dest)) - dest_out.Set(reflect.ValueOf(dest)) + if !s.hasError() { + rows, err := s.db.Query(s.sql, s.sqlVars...) + if s.err(err) != nil { + return + } + + defer rows.Close() + for rows.Next() { + var dest int64 + if s.err(rows.Scan(&dest)) == nil { + dest_out.Set(reflect.ValueOf(dest)) + } + } } return } -func (s *Do) pluck(value interface{}) *Do { - if s.hasError() { - return s - } - +func (s *Do) pluck(column string, value interface{}) { + s.selectStr = column dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_type := dest_out.Type().Elem() s.prepareQuerySql() - rows, err := s.db.Query(s.Sql, s.SqlVars...) - s.err(err) - if err != nil { - return s - } - defer rows.Close() - for rows.Next() { - dest := reflect.New(dest_type).Elem().Interface() - s.err(rows.Scan(&dest)) - switch dest.(type) { - case []uint8: - if dest_type.String() == "string" { - dest = string(dest.([]uint8)) + if !s.hasError() { + rows, err := s.db.Query(s.sql, s.sqlVars...) + if s.err(err) != nil { + return + } + + defer rows.Close() + for rows.Next() { + dest := reflect.New(dest_type).Elem().Interface() + s.err(rows.Scan(&dest)) + switch dest.(type) { + case []uint8: + if dest_type.String() == "string" { + dest = string(dest.([]uint8)) + } + dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) + default: + dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) } - dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) - default: - dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest))) } } - return s + return } -func (s *Do) where(querystring interface{}, args ...interface{}) *Do { +func (s *Do) where(querystring interface{}, args ...interface{}) { s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) - return s + return } func (s *Do) primaryCondiation(value interface{}) string { - return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value) + return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value) } func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { @@ -324,17 +338,11 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { switch reflect.TypeOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) v := reflect.ValueOf(arg) - var temp_marks []string for i := 0; i < v.Len(); i++ { - temp_marks = append(temp_marks, "?") + temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface())) } - str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) - - for i := 0; i < v.Len(); i++ { - str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1) - } default: str = strings.Replace(str, "?", s.addToVars(arg), 1) } @@ -421,7 +429,7 @@ func (s *Do) createTable() *Do { for _, field := range s.model.fields("null") { sqls = append(sqls, field.DbName+" "+field.SqlType) } - s.Sql = fmt.Sprintf( + s.sql = fmt.Sprintf( "CREATE TABLE \"%v\" (%v)", s.tableName(), strings.Join(sqls, ","), diff --git a/gorm_test.go b/gorm_test.go index b9b655b4..2d9e1e49 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -2,6 +2,7 @@ package gorm import ( "errors" + "fmt" _ "github.com/lib/pq" "reflect" "strconv" @@ -41,15 +42,23 @@ var ( ) func init() { - db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") + var err error + db, err = Open("postgres", "user=gorm dbname=gorm sslmode=disable") + if err != nil { + panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err)) + } db.SetPool(10) - db.Exec("drop table users;") + err = db.Exec("drop table users;").Error + if err != nil { + fmt.Printf("Got error when try to delete table uses, %+v\n", err) + } + db.Exec("drop table products;") orm := db.CreateTable(&User{}) if orm.Error != nil { - panic("No error should raise when create table") + panic(fmt.Sprintf("No error should happen when create table, but got %+v", orm.Error)) } db.CreateTable(&Product{}) diff --git a/main.go b/main.go index f8c54a03..c93efe7c 100644 --- a/main.go +++ b/main.go @@ -17,58 +17,58 @@ func (s *DB) SetPool(n int) { s.db.SetMaxIdleConns(n) } -func (s *DB) buildORM() *Chain { +func (s *DB) buildChain() *Chain { return &Chain{db: s.db, driver: s.driver} } func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { - return s.buildORM().Where(querystring, args...) + return s.buildChain().Where(querystring, args...) } func (s *DB) First(out interface{}, where ...interface{}) *Chain { - return s.buildORM().First(out, where...) + return s.buildChain().First(out, where...) } func (s *DB) Find(out interface{}, where ...interface{}) *Chain { - return s.buildORM().Find(out, where...) + return s.buildChain().Find(out, where...) } func (s *DB) Limit(value interface{}) *Chain { - return s.buildORM().Limit(value) + return s.buildChain().Limit(value) } func (s *DB) Offset(value interface{}) *Chain { - return s.buildORM().Offset(value) + return s.buildChain().Offset(value) } func (s *DB) Order(value string, reorder ...bool) *Chain { - return s.buildORM().Order(value, reorder...) + return s.buildChain().Order(value, reorder...) } func (s *DB) Select(value interface{}) *Chain { - return s.buildORM().Select(value) + return s.buildChain().Select(value) } func (s *DB) Save(value interface{}) *Chain { - return s.buildORM().Save(value) + return s.buildChain().Save(value) } func (s *DB) Delete(value interface{}) *Chain { - return s.buildORM().Delete(value) + return s.buildChain().Delete(value) } func (s *DB) Exec(sql string) *Chain { - return s.buildORM().Exec(sql) + return s.buildChain().Exec(sql) } func (s *DB) Model(value interface{}) *Chain { - return s.buildORM().Model(value) + return s.buildChain().Model(value) } func (s *DB) Table(name string) *Chain { - return s.buildORM().Table(name) + return s.buildChain().Table(name) } func (s *DB) CreateTable(value interface{}) *Chain { - return s.buildORM().CreateTable(value) + return s.buildChain().CreateTable(value) } diff --git a/model.go b/model.go index 244a001d..18db1385 100644 --- a/model.go +++ b/model.go @@ -10,7 +10,7 @@ import ( ) type Model struct { - Data interface{} + data interface{} driver string } @@ -25,23 +25,23 @@ type Field struct { } func (m *Model) primaryKeyZero() bool { - return m.primaryKeyValue() == 0 + return m.primaryKeyValue() <= 0 } func (m *Model) primaryKeyValue() int64 { - if m.Data == nil { - return 0 + if m.data == nil { + return -1 } - t := reflect.TypeOf(m.Data).Elem() + t := reflect.TypeOf(m.data).Elem() switch t.Kind() { case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: return 0 default: - result := reflect.ValueOf(m.Data).Elem() + result := reflect.ValueOf(m.data).Elem() value := result.FieldByName(m.primaryKey()) if value.IsValid() { - return result.FieldByName(m.primaryKey()).Interface().(int64) + return value.Interface().(int64) } else { return 0 } @@ -57,7 +57,7 @@ func (m *Model) primaryKeyDb() string { } func (m *Model) fields(operation string) (fields []Field) { - typ := reflect.TypeOf(m.Data).Elem() + typ := reflect.TypeOf(m.data).Elem() for i := 0; i < typ.NumField(); i++ { p := typ.Field(i) @@ -68,18 +68,16 @@ func (m *Model) fields(operation string) (fields []Field) { field.IsPrimaryKey = m.primaryKeyDb() == field.DbName field.AutoCreateTime = "created_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName - value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name) + value := reflect.ValueOf(m.data).Elem().FieldByName(p.Name) switch operation { case "create": if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() { - value = reflect.ValueOf(time.Now()) - reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value) + value.Set(reflect.ValueOf(time.Now())) } case "update": if field.AutoUpdateTime { - value = reflect.ValueOf(time.Now()) - reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value) + value.Set(reflect.ValueOf(time.Now())) } default: } @@ -107,12 +105,12 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} { } func (m *Model) tableName() (str string, err error) { - if m.Data == nil { + if m.data == nil { err = errors.New("Model haven't been set") return } - t := reflect.TypeOf(m.Data) + t := reflect.TypeOf(m.data) for { c := false switch t.Kind() { @@ -138,11 +136,11 @@ func (m *Model) tableName() (str string, err error) { } func (m *Model) callMethod(method string) error { - if m.Data == nil { + if m.data == nil { return nil } - fm := reflect.ValueOf(m.Data).MethodByName(method) + fm := reflect.ValueOf(m.data).MethodByName(method) if fm.IsValid() { v := fm.Call([]reflect.Value{}) if len(v) > 0 { @@ -154,13 +152,13 @@ func (m *Model) callMethod(method string) error { return nil } -func (model *Model) missingColumns() (results []string) { - return -} - func (model *Model) returningStr() (str string) { if model.driver == "postgres" { str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb()) } return } + +func (model *Model) missingColumns() (results []string) { + return +} diff --git a/utils.go b/utils.go index 40798a0c..41791b7a 100644 --- a/utils.go +++ b/utils.go @@ -7,17 +7,6 @@ import ( "strings" ) -func (s *Do) quote(value string) string { - return "\"" + value + "\"" -} - -func (s *Do) quoteMap(values []string) (results []string) { - for _, value := range values { - results = append(results, s.quote(value)) - } - return -} - func toSnake(s string) string { buf := bytes.NewBufferString("") for i, v := range s {