Cleanup code

This commit is contained in:
Jinzhu 2013-10-29 07:39:26 +08:00
parent 84b280c0ff
commit bc785a9173
7 changed files with 189 additions and 187 deletions

View File

@ -217,10 +217,11 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
## TODO
* Soft Delete
* Query with map or struct
* SubStruct
* Index, Unique, Valiations
* Auto Migration
* FindOrInitialize / FindOrCreate
* SQL Log
* Auto Migration
* Index, Unique, Valiations
* SQL Query with goroutines
* Only tested with postgres, confirm works with other database adaptors

View File

@ -26,11 +26,12 @@ type Chain struct {
specifiedTableName string
}
func (s *Chain) err(err error) {
func (s *Chain) err(err error) error {
if err != nil {
s.Errors = append(s.Errors, err)
s.Error = err
}
return err
}
func (s *Chain) do(value interface{}) *Do {
@ -136,18 +137,16 @@ func (s *Chain) Update(column string, value interface{}) *Chain {
return s.Updates(map[string]interface{}{column: value}, true)
}
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...bool) *Chain {
do := s.do(s.value)
do.updateAttrs = values
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
do.update()
return s
}
func (s *Chain) Exec(sql string) *Chain {
var err error
_, err = s.db.Exec(sql)
s.err(err)
s.do(nil).exec(sql)
return s
}
@ -164,9 +163,7 @@ func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
}
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
do := s.do(s.value)
do.selectStr = column
do.pluck(value)
s.do(s.value).pluck(column, value)
return s
}

262
do.go
View File

@ -21,10 +21,9 @@ type Do struct {
model *Model
value interface{}
SqlResult sql.Result
Sql string
SqlVars []interface{}
sqlResult sql.Result
sql string
sqlVars []interface{}
whereClause []map[string]interface{}
orClause []map[string]interface{}
@ -32,7 +31,6 @@ type Do struct {
orderStrs []string
offsetStr string
limitStr string
operation string
updateAttrs map[string]interface{}
ignoreProtectedAttrs bool
@ -40,17 +38,21 @@ type Do struct {
func (s *Do) tableName() string {
if s.specifiedTableName == "" {
var err error
s.guessedTableName, err = s.model.tableName()
s.err(err)
return s.guessedTableName
} else {
return s.specifiedTableName
}
}
func (s *Do) err(err error) {
func (s *Do) err(err error) error {
if err != nil {
s.Errors = append(s.Errors, err)
s.chain.err(err)
}
return err
}
func (s *Do) hasError() bool {
@ -58,18 +60,13 @@ func (s *Do) hasError() bool {
}
func (s *Do) setModel(value interface{}) {
s.model = &Model{data: value, driver: s.driver}
s.value = value
s.model = &Model{Data: value, driver: s.driver}
var err error
if s.specifiedTableName == "" {
s.guessedTableName, err = s.model.tableName()
s.err(err)
}
}
func (s *Do) addToVars(value interface{}) string {
s.SqlVars = append(s.SqlVars, value)
return fmt.Sprintf("$%d", len(s.SqlVars))
s.sqlVars = append(s.sqlVars, value)
return fmt.Sprintf("$%d", len(s.sqlVars))
}
func (s *Do) exec(sql ...string) {
@ -79,23 +76,23 @@ func (s *Do) exec(sql ...string) {
var err error
if len(sql) == 0 {
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
} else {
s.SqlResult, err = s.db.Exec(sql[0])
s.sqlResult, err = s.db.Exec(sql[0])
}
s.err(err)
}
func (s *Do) save() *Do {
func (s *Do) save() {
if s.model.primaryKeyZero() {
s.create()
} else {
s.update()
}
return s
return
}
func (s *Do) prepareCreateSql() *Do {
func (s *Do) prepareCreateSql() {
var sqls, columns []string
for key, value := range s.model.columnsAndValues("create") {
@ -103,44 +100,47 @@ func (s *Do) prepareCreateSql() *Do {
sqls = append(sqls, s.addToVars(value))
}
s.Sql = fmt.Sprintf(
s.sql = fmt.Sprintf(
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
s.tableName(),
strings.Join(s.quoteMap(columns), ","),
strings.Join(columns, ","),
strings.Join(sqls, ","),
s.model.returningStr(),
)
return s
return
}
func (s *Do) create() *Do {
func (s *Do) create() {
s.err(s.model.callMethod("BeforeCreate"))
s.err(s.model.callMethod("BeforeSave"))
s.prepareCreateSql()
if len(s.Errors) == 0 {
if !s.hasError() {
var id int64
if s.driver == "postgres" {
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
} else {
var err error
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
s.err(err)
id, err = s.SqlResult.LastInsertId()
id, err = s.sqlResult.LastInsertId()
s.err(err)
}
result := reflect.ValueOf(s.model.Data).Elem()
result.FieldByName(s.model.primaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
if !s.hasError() {
result := reflect.ValueOf(s.value).Elem()
result.FieldByName(s.model.primaryKey()).SetInt(id)
s.err(s.model.callMethod("AfterCreate"))
s.err(s.model.callMethod("AfterSave"))
}
}
return s
return
}
func (s *Do) prepareUpdateSql() *Do {
func (s *Do) prepareUpdateSql() {
update_attrs := s.updateAttrs
if len(update_attrs) == 0 {
update_attrs = s.model.columnsAndValues("update")
@ -148,46 +148,55 @@ func (s *Do) prepareUpdateSql() *Do {
var sqls []string
for key, value := range update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
}
s.Sql = fmt.Sprintf(
s.sql = fmt.Sprintf(
"UPDATE %v SET %v %v",
s.tableName(),
strings.Join(sqls, ", "),
s.combinedSql(),
)
return s
return
}
func (s *Do) update() *Do {
func (s *Do) update() {
s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave"))
if len(s.Errors) == 0 {
s.prepareUpdateSql().exec()
s.prepareUpdateSql()
if !s.hasError() {
s.exec()
if !s.hasError() {
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
}
}
s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave"))
return s
return
}
func (s *Do) prepareDeleteSql() *Do {
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return s
func (s *Do) prepareDeleteSql() {
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
return
}
func (s *Do) delete() *Do {
func (s *Do) delete() {
s.err(s.model.callMethod("BeforeDelete"))
if len(s.Errors) == 0 {
s.prepareDeleteSql().exec()
s.prepareDeleteSql()
if !s.hasError() {
s.exec()
if !s.hasError() {
s.err(s.model.callMethod("AfterDelete"))
}
}
s.err(s.model.callMethod("AfterDelete"))
return s
return
}
func (s *Do) prepareQuerySql() *Do {
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
return s
func (s *Do) prepareQuerySql() {
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
return
}
func (s *Do) query(where ...interface{}) {
@ -207,102 +216,107 @@ func (s *Do) query(where ...interface{}) {
}
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.err(err)
if err != nil {
return
}
defer rows.Close()
if rows.Err() != nil {
s.err(rows.Err())
}
counts := 0
for rows.Next() {
counts += 1
var dest reflect.Value
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = reflect.ValueOf(s.value).Elem()
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
defer rows.Close()
if rows.Err() != nil {
s.err(rows.Err())
}
counts := 0
for rows.Next() {
counts += 1
var dest reflect.Value
if is_slice {
dest = reflect.New(dest_type).Elem()
} else {
dest = reflect.ValueOf(s.value).Elem()
}
columns, _ := rows.Columns()
var values []interface{}
for _, value := range columns {
field := dest.FieldByName(snakeToUpperCamel(value))
if field.IsValid() {
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
}
}
s.err(rows.Scan(values...))
if is_slice {
dest_out.Set(reflect.Append(dest_out, dest))
}
}
s.err(rows.Scan(values...))
if is_slice {
dest_out.Set(reflect.Append(dest_out, dest))
if (counts == 0) && !is_slice {
s.err(errors.New("Record not found!"))
}
}
if (counts == 0) && !is_slice {
s.err(errors.New("Record not found!"))
}
}
func (s *Do) count(value interface{}) {
dest_out := reflect.Indirect(reflect.ValueOf(value))
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.err(err)
for rows.Next() {
var dest int64
s.err(rows.Scan(&dest))
dest_out.Set(reflect.ValueOf(dest))
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
var dest int64
if s.err(rows.Scan(&dest)) == nil {
dest_out.Set(reflect.ValueOf(dest))
}
}
}
return
}
func (s *Do) pluck(value interface{}) *Do {
if s.hasError() {
return s
}
func (s *Do) pluck(column string, value interface{}) {
s.selectStr = column
dest_out := reflect.Indirect(reflect.ValueOf(value))
dest_type := dest_out.Type().Elem()
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.err(err)
if err != nil {
return s
}
defer rows.Close()
for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface()
s.err(rows.Scan(&dest))
switch dest.(type) {
case []uint8:
if dest_type.String() == "string" {
dest = string(dest.([]uint8))
if !s.hasError() {
rows, err := s.db.Query(s.sql, s.sqlVars...)
if s.err(err) != nil {
return
}
defer rows.Close()
for rows.Next() {
dest := reflect.New(dest_type).Elem().Interface()
s.err(rows.Scan(&dest))
switch dest.(type) {
case []uint8:
if dest_type.String() == "string" {
dest = string(dest.([]uint8))
}
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
default:
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
}
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
default:
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
}
}
return s
return
}
func (s *Do) where(querystring interface{}, args ...interface{}) *Do {
func (s *Do) where(querystring interface{}, args ...interface{}) {
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
return s
return
}
func (s *Do) primaryCondiation(value interface{}) string {
return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
}
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
@ -324,17 +338,11 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
switch reflect.TypeOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
v := reflect.ValueOf(arg)
var temp_marks []string
for i := 0; i < v.Len(); i++ {
temp_marks = append(temp_marks, "?")
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
}
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
for i := 0; i < v.Len(); i++ {
str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1)
}
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
}
@ -421,7 +429,7 @@ func (s *Do) createTable() *Do {
for _, field := range s.model.fields("null") {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
s.Sql = fmt.Sprintf(
s.sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)",
s.tableName(),
strings.Join(sqls, ","),

View File

@ -2,6 +2,7 @@ package gorm
import (
"errors"
"fmt"
_ "github.com/lib/pq"
"reflect"
"strconv"
@ -41,15 +42,23 @@ var (
)
func init() {
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
var err error
db, err = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
if err != nil {
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
}
db.SetPool(10)
db.Exec("drop table users;")
err = db.Exec("drop table users;").Error
if err != nil {
fmt.Printf("Got error when try to delete table uses, %+v\n", err)
}
db.Exec("drop table products;")
orm := db.CreateTable(&User{})
if orm.Error != nil {
panic("No error should raise when create table")
panic(fmt.Sprintf("No error should happen when create table, but got %+v", orm.Error))
}
db.CreateTable(&Product{})

28
main.go
View File

@ -17,58 +17,58 @@ func (s *DB) SetPool(n int) {
s.db.SetMaxIdleConns(n)
}
func (s *DB) buildORM() *Chain {
func (s *DB) buildChain() *Chain {
return &Chain{db: s.db, driver: s.driver}
}
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
return s.buildORM().Where(querystring, args...)
return s.buildChain().Where(querystring, args...)
}
func (s *DB) First(out interface{}, where ...interface{}) *Chain {
return s.buildORM().First(out, where...)
return s.buildChain().First(out, where...)
}
func (s *DB) Find(out interface{}, where ...interface{}) *Chain {
return s.buildORM().Find(out, where...)
return s.buildChain().Find(out, where...)
}
func (s *DB) Limit(value interface{}) *Chain {
return s.buildORM().Limit(value)
return s.buildChain().Limit(value)
}
func (s *DB) Offset(value interface{}) *Chain {
return s.buildORM().Offset(value)
return s.buildChain().Offset(value)
}
func (s *DB) Order(value string, reorder ...bool) *Chain {
return s.buildORM().Order(value, reorder...)
return s.buildChain().Order(value, reorder...)
}
func (s *DB) Select(value interface{}) *Chain {
return s.buildORM().Select(value)
return s.buildChain().Select(value)
}
func (s *DB) Save(value interface{}) *Chain {
return s.buildORM().Save(value)
return s.buildChain().Save(value)
}
func (s *DB) Delete(value interface{}) *Chain {
return s.buildORM().Delete(value)
return s.buildChain().Delete(value)
}
func (s *DB) Exec(sql string) *Chain {
return s.buildORM().Exec(sql)
return s.buildChain().Exec(sql)
}
func (s *DB) Model(value interface{}) *Chain {
return s.buildORM().Model(value)
return s.buildChain().Model(value)
}
func (s *DB) Table(name string) *Chain {
return s.buildORM().Table(name)
return s.buildChain().Table(name)
}
func (s *DB) CreateTable(value interface{}) *Chain {
return s.buildORM().CreateTable(value)
return s.buildChain().CreateTable(value)
}

View File

@ -10,7 +10,7 @@ import (
)
type Model struct {
Data interface{}
data interface{}
driver string
}
@ -25,23 +25,23 @@ type Field struct {
}
func (m *Model) primaryKeyZero() bool {
return m.primaryKeyValue() == 0
return m.primaryKeyValue() <= 0
}
func (m *Model) primaryKeyValue() int64 {
if m.Data == nil {
return 0
if m.data == nil {
return -1
}
t := reflect.TypeOf(m.Data).Elem()
t := reflect.TypeOf(m.data).Elem()
switch t.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
return 0
default:
result := reflect.ValueOf(m.Data).Elem()
result := reflect.ValueOf(m.data).Elem()
value := result.FieldByName(m.primaryKey())
if value.IsValid() {
return result.FieldByName(m.primaryKey()).Interface().(int64)
return value.Interface().(int64)
} else {
return 0
}
@ -57,7 +57,7 @@ func (m *Model) primaryKeyDb() string {
}
func (m *Model) fields(operation string) (fields []Field) {
typ := reflect.TypeOf(m.Data).Elem()
typ := reflect.TypeOf(m.data).Elem()
for i := 0; i < typ.NumField(); i++ {
p := typ.Field(i)
@ -68,18 +68,16 @@ func (m *Model) fields(operation string) (fields []Field) {
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name)
value := reflect.ValueOf(m.data).Elem().FieldByName(p.Name)
switch operation {
case "create":
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
value = reflect.ValueOf(time.Now())
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
value.Set(reflect.ValueOf(time.Now()))
}
case "update":
if field.AutoUpdateTime {
value = reflect.ValueOf(time.Now())
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
value.Set(reflect.ValueOf(time.Now()))
}
default:
}
@ -107,12 +105,12 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
}
func (m *Model) tableName() (str string, err error) {
if m.Data == nil {
if m.data == nil {
err = errors.New("Model haven't been set")
return
}
t := reflect.TypeOf(m.Data)
t := reflect.TypeOf(m.data)
for {
c := false
switch t.Kind() {
@ -138,11 +136,11 @@ func (m *Model) tableName() (str string, err error) {
}
func (m *Model) callMethod(method string) error {
if m.Data == nil {
if m.data == nil {
return nil
}
fm := reflect.ValueOf(m.Data).MethodByName(method)
fm := reflect.ValueOf(m.data).MethodByName(method)
if fm.IsValid() {
v := fm.Call([]reflect.Value{})
if len(v) > 0 {
@ -154,13 +152,13 @@ func (m *Model) callMethod(method string) error {
return nil
}
func (model *Model) missingColumns() (results []string) {
return
}
func (model *Model) returningStr() (str string) {
if model.driver == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
}
return
}
func (model *Model) missingColumns() (results []string) {
return
}

View File

@ -7,17 +7,6 @@ import (
"strings"
)
func (s *Do) quote(value string) string {
return "\"" + value + "\""
}
func (s *Do) quoteMap(values []string) (results []string) {
for _, value := range values {
results = append(results, s.quote(value))
}
return
}
func toSnake(s string) string {
buf := bytes.NewBufferString("")
for i, v := range s {