mirror of https://github.com/go-gorm/gorm.git
Cleanup code
This commit is contained in:
parent
84b280c0ff
commit
bc785a9173
|
@ -217,10 +217,11 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
|
|||
## TODO
|
||||
* Soft Delete
|
||||
* Query with map or struct
|
||||
* SubStruct
|
||||
* Index, Unique, Valiations
|
||||
* Auto Migration
|
||||
* FindOrInitialize / FindOrCreate
|
||||
* SQL Log
|
||||
* Auto Migration
|
||||
* Index, Unique, Valiations
|
||||
* SQL Query with goroutines
|
||||
* Only tested with postgres, confirm works with other database adaptors
|
||||
|
||||
|
|
15
chain.go
15
chain.go
|
@ -26,11 +26,12 @@ type Chain struct {
|
|||
specifiedTableName string
|
||||
}
|
||||
|
||||
func (s *Chain) err(err error) {
|
||||
func (s *Chain) err(err error) error {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
s.Error = err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Chain) do(value interface{}) *Do {
|
||||
|
@ -136,18 +137,16 @@ func (s *Chain) Update(column string, value interface{}) *Chain {
|
|||
return s.Updates(map[string]interface{}{column: value}, true)
|
||||
}
|
||||
|
||||
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
|
||||
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...bool) *Chain {
|
||||
do := s.do(s.value)
|
||||
do.updateAttrs = values
|
||||
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
|
||||
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
|
||||
do.update()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Exec(sql string) *Chain {
|
||||
var err error
|
||||
_, err = s.db.Exec(sql)
|
||||
s.err(err)
|
||||
s.do(nil).exec(sql)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -164,9 +163,7 @@ func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
|
||||
do := s.do(s.value)
|
||||
do.selectStr = column
|
||||
do.pluck(value)
|
||||
s.do(s.value).pluck(column, value)
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
262
do.go
262
do.go
|
@ -21,10 +21,9 @@ type Do struct {
|
|||
|
||||
model *Model
|
||||
value interface{}
|
||||
SqlResult sql.Result
|
||||
|
||||
Sql string
|
||||
SqlVars []interface{}
|
||||
sqlResult sql.Result
|
||||
sql string
|
||||
sqlVars []interface{}
|
||||
|
||||
whereClause []map[string]interface{}
|
||||
orClause []map[string]interface{}
|
||||
|
@ -32,7 +31,6 @@ type Do struct {
|
|||
orderStrs []string
|
||||
offsetStr string
|
||||
limitStr string
|
||||
operation string
|
||||
|
||||
updateAttrs map[string]interface{}
|
||||
ignoreProtectedAttrs bool
|
||||
|
@ -40,17 +38,21 @@ type Do struct {
|
|||
|
||||
func (s *Do) tableName() string {
|
||||
if s.specifiedTableName == "" {
|
||||
var err error
|
||||
s.guessedTableName, err = s.model.tableName()
|
||||
s.err(err)
|
||||
return s.guessedTableName
|
||||
} else {
|
||||
return s.specifiedTableName
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) err(err error) {
|
||||
func (s *Do) err(err error) error {
|
||||
if err != nil {
|
||||
s.Errors = append(s.Errors, err)
|
||||
s.chain.err(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Do) hasError() bool {
|
||||
|
@ -58,18 +60,13 @@ func (s *Do) hasError() bool {
|
|||
}
|
||||
|
||||
func (s *Do) setModel(value interface{}) {
|
||||
s.model = &Model{data: value, driver: s.driver}
|
||||
s.value = value
|
||||
s.model = &Model{Data: value, driver: s.driver}
|
||||
var err error
|
||||
if s.specifiedTableName == "" {
|
||||
s.guessedTableName, err = s.model.tableName()
|
||||
s.err(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
s.SqlVars = append(s.SqlVars, value)
|
||||
return fmt.Sprintf("$%d", len(s.SqlVars))
|
||||
s.sqlVars = append(s.sqlVars, value)
|
||||
return fmt.Sprintf("$%d", len(s.sqlVars))
|
||||
}
|
||||
|
||||
func (s *Do) exec(sql ...string) {
|
||||
|
@ -79,23 +76,23 @@ func (s *Do) exec(sql ...string) {
|
|||
|
||||
var err error
|
||||
if len(sql) == 0 {
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||
} else {
|
||||
s.SqlResult, err = s.db.Exec(sql[0])
|
||||
s.sqlResult, err = s.db.Exec(sql[0])
|
||||
}
|
||||
s.err(err)
|
||||
}
|
||||
|
||||
func (s *Do) save() *Do {
|
||||
func (s *Do) save() {
|
||||
if s.model.primaryKeyZero() {
|
||||
s.create()
|
||||
} else {
|
||||
s.update()
|
||||
}
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) prepareCreateSql() *Do {
|
||||
func (s *Do) prepareCreateSql() {
|
||||
var sqls, columns []string
|
||||
|
||||
for key, value := range s.model.columnsAndValues("create") {
|
||||
|
@ -103,44 +100,47 @@ func (s *Do) prepareCreateSql() *Do {
|
|||
sqls = append(sqls, s.addToVars(value))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
s.sql = fmt.Sprintf(
|
||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||
s.tableName(),
|
||||
strings.Join(s.quoteMap(columns), ","),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(sqls, ","),
|
||||
s.model.returningStr(),
|
||||
)
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) create() *Do {
|
||||
func (s *Do) create() {
|
||||
s.err(s.model.callMethod("BeforeCreate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
|
||||
s.prepareCreateSql()
|
||||
|
||||
if len(s.Errors) == 0 {
|
||||
if !s.hasError() {
|
||||
var id int64
|
||||
if s.driver == "postgres" {
|
||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
||||
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||
} else {
|
||||
var err error
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||
s.err(err)
|
||||
id, err = s.SqlResult.LastInsertId()
|
||||
id, err = s.sqlResult.LastInsertId()
|
||||
s.err(err)
|
||||
}
|
||||
result := reflect.ValueOf(s.model.Data).Elem()
|
||||
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
if !s.hasError() {
|
||||
result := reflect.ValueOf(s.value).Elem()
|
||||
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
||||
|
||||
s.err(s.model.callMethod("AfterCreate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) prepareUpdateSql() *Do {
|
||||
func (s *Do) prepareUpdateSql() {
|
||||
update_attrs := s.updateAttrs
|
||||
if len(update_attrs) == 0 {
|
||||
update_attrs = s.model.columnsAndValues("update")
|
||||
|
@ -148,46 +148,55 @@ func (s *Do) prepareUpdateSql() *Do {
|
|||
|
||||
var sqls []string
|
||||
for key, value := range update_attrs {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
s.sql = fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
s.tableName(),
|
||||
strings.Join(sqls, ", "),
|
||||
s.combinedSql(),
|
||||
)
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) update() *Do {
|
||||
func (s *Do) update() {
|
||||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareUpdateSql().exec()
|
||||
|
||||
s.prepareUpdateSql()
|
||||
if !s.hasError() {
|
||||
s.exec()
|
||||
|
||||
if !s.hasError() {
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
}
|
||||
}
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) prepareDeleteSql() *Do {
|
||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
return s
|
||||
func (s *Do) prepareDeleteSql() {
|
||||
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) delete() *Do {
|
||||
func (s *Do) delete() {
|
||||
s.err(s.model.callMethod("BeforeDelete"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareDeleteSql().exec()
|
||||
|
||||
s.prepareDeleteSql()
|
||||
if !s.hasError() {
|
||||
s.exec()
|
||||
if !s.hasError() {
|
||||
s.err(s.model.callMethod("AfterDelete"))
|
||||
}
|
||||
}
|
||||
s.err(s.model.callMethod("AfterDelete"))
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) prepareQuerySql() *Do {
|
||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||
return s
|
||||
func (s *Do) prepareQuerySql() {
|
||||
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) query(where ...interface{}) {
|
||||
|
@ -207,102 +216,107 @@ func (s *Do) query(where ...interface{}) {
|
|||
}
|
||||
|
||||
s.prepareQuerySql()
|
||||
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Err() != nil {
|
||||
s.err(rows.Err())
|
||||
}
|
||||
|
||||
counts := 0
|
||||
for rows.Next() {
|
||||
counts += 1
|
||||
var dest reflect.Value
|
||||
if is_slice {
|
||||
dest = reflect.New(dest_type).Elem()
|
||||
} else {
|
||||
dest = reflect.ValueOf(s.value).Elem()
|
||||
if !s.hasError() {
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
if s.err(err) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
var values []interface{}
|
||||
for _, value := range columns {
|
||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Err() != nil {
|
||||
s.err(rows.Err())
|
||||
}
|
||||
|
||||
counts := 0
|
||||
for rows.Next() {
|
||||
counts += 1
|
||||
var dest reflect.Value
|
||||
if is_slice {
|
||||
dest = reflect.New(dest_type).Elem()
|
||||
} else {
|
||||
dest = reflect.ValueOf(s.value).Elem()
|
||||
}
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
var values []interface{}
|
||||
for _, value := range columns {
|
||||
field := dest.FieldByName(snakeToUpperCamel(value))
|
||||
if field.IsValid() {
|
||||
values = append(values, dest.FieldByName(snakeToUpperCamel(value)).Addr().Interface())
|
||||
}
|
||||
}
|
||||
s.err(rows.Scan(values...))
|
||||
|
||||
if is_slice {
|
||||
dest_out.Set(reflect.Append(dest_out, dest))
|
||||
}
|
||||
}
|
||||
s.err(rows.Scan(values...))
|
||||
|
||||
if is_slice {
|
||||
dest_out.Set(reflect.Append(dest_out, dest))
|
||||
if (counts == 0) && !is_slice {
|
||||
s.err(errors.New("Record not found!"))
|
||||
}
|
||||
}
|
||||
|
||||
if (counts == 0) && !is_slice {
|
||||
s.err(errors.New("Record not found!"))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Do) count(value interface{}) {
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
s.prepareQuerySql()
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
for rows.Next() {
|
||||
var dest int64
|
||||
s.err(rows.Scan(&dest))
|
||||
dest_out.Set(reflect.ValueOf(dest))
|
||||
if !s.hasError() {
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
if s.err(err) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var dest int64
|
||||
if s.err(rows.Scan(&dest)) == nil {
|
||||
dest_out.Set(reflect.ValueOf(dest))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) pluck(value interface{}) *Do {
|
||||
if s.hasError() {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) pluck(column string, value interface{}) {
|
||||
s.selectStr = column
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
dest_type := dest_out.Type().Elem()
|
||||
s.prepareQuerySql()
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
dest := reflect.New(dest_type).Elem().Interface()
|
||||
s.err(rows.Scan(&dest))
|
||||
switch dest.(type) {
|
||||
case []uint8:
|
||||
if dest_type.String() == "string" {
|
||||
dest = string(dest.([]uint8))
|
||||
if !s.hasError() {
|
||||
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||
if s.err(err) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
dest := reflect.New(dest_type).Elem().Interface()
|
||||
s.err(rows.Scan(&dest))
|
||||
switch dest.(type) {
|
||||
case []uint8:
|
||||
if dest_type.String() == "string" {
|
||||
dest = string(dest.([]uint8))
|
||||
}
|
||||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||
default:
|
||||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||
}
|
||||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||
default:
|
||||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||
}
|
||||
}
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) where(querystring interface{}, args ...interface{}) *Do {
|
||||
func (s *Do) where(querystring interface{}, args ...interface{}) {
|
||||
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
||||
return s
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Do) primaryCondiation(value interface{}) string {
|
||||
return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
|
||||
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
|
||||
}
|
||||
|
||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
|
@ -324,17 +338,11 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
switch reflect.TypeOf(arg).Kind() {
|
||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||
v := reflect.ValueOf(arg)
|
||||
|
||||
var temp_marks []string
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
temp_marks = append(temp_marks, "?")
|
||||
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
|
||||
}
|
||||
|
||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1)
|
||||
}
|
||||
default:
|
||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||
}
|
||||
|
@ -421,7 +429,7 @@ func (s *Do) createTable() *Do {
|
|||
for _, field := range s.model.fields("null") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
s.Sql = fmt.Sprintf(
|
||||
s.sql = fmt.Sprintf(
|
||||
"CREATE TABLE \"%v\" (%v)",
|
||||
s.tableName(),
|
||||
strings.Join(sqls, ","),
|
||||
|
|
15
gorm_test.go
15
gorm_test.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "github.com/lib/pq"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
@ -41,15 +42,23 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
var err error
|
||||
db, err = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
|
||||
}
|
||||
db.SetPool(10)
|
||||
|
||||
db.Exec("drop table users;")
|
||||
err = db.Exec("drop table users;").Error
|
||||
if err != nil {
|
||||
fmt.Printf("Got error when try to delete table uses, %+v\n", err)
|
||||
}
|
||||
|
||||
db.Exec("drop table products;")
|
||||
|
||||
orm := db.CreateTable(&User{})
|
||||
if orm.Error != nil {
|
||||
panic("No error should raise when create table")
|
||||
panic(fmt.Sprintf("No error should happen when create table, but got %+v", orm.Error))
|
||||
}
|
||||
|
||||
db.CreateTable(&Product{})
|
||||
|
|
28
main.go
28
main.go
|
@ -17,58 +17,58 @@ func (s *DB) SetPool(n int) {
|
|||
s.db.SetMaxIdleConns(n)
|
||||
}
|
||||
|
||||
func (s *DB) buildORM() *Chain {
|
||||
func (s *DB) buildChain() *Chain {
|
||||
return &Chain{db: s.db, driver: s.driver}
|
||||
}
|
||||
|
||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||
return s.buildORM().Where(querystring, args...)
|
||||
return s.buildChain().Where(querystring, args...)
|
||||
}
|
||||
|
||||
func (s *DB) First(out interface{}, where ...interface{}) *Chain {
|
||||
return s.buildORM().First(out, where...)
|
||||
return s.buildChain().First(out, where...)
|
||||
}
|
||||
|
||||
func (s *DB) Find(out interface{}, where ...interface{}) *Chain {
|
||||
return s.buildORM().Find(out, where...)
|
||||
return s.buildChain().Find(out, where...)
|
||||
}
|
||||
|
||||
func (s *DB) Limit(value interface{}) *Chain {
|
||||
return s.buildORM().Limit(value)
|
||||
return s.buildChain().Limit(value)
|
||||
}
|
||||
|
||||
func (s *DB) Offset(value interface{}) *Chain {
|
||||
return s.buildORM().Offset(value)
|
||||
return s.buildChain().Offset(value)
|
||||
}
|
||||
|
||||
func (s *DB) Order(value string, reorder ...bool) *Chain {
|
||||
return s.buildORM().Order(value, reorder...)
|
||||
return s.buildChain().Order(value, reorder...)
|
||||
}
|
||||
|
||||
func (s *DB) Select(value interface{}) *Chain {
|
||||
return s.buildORM().Select(value)
|
||||
return s.buildChain().Select(value)
|
||||
}
|
||||
|
||||
func (s *DB) Save(value interface{}) *Chain {
|
||||
return s.buildORM().Save(value)
|
||||
return s.buildChain().Save(value)
|
||||
}
|
||||
|
||||
func (s *DB) Delete(value interface{}) *Chain {
|
||||
return s.buildORM().Delete(value)
|
||||
return s.buildChain().Delete(value)
|
||||
}
|
||||
|
||||
func (s *DB) Exec(sql string) *Chain {
|
||||
return s.buildORM().Exec(sql)
|
||||
return s.buildChain().Exec(sql)
|
||||
}
|
||||
|
||||
func (s *DB) Model(value interface{}) *Chain {
|
||||
return s.buildORM().Model(value)
|
||||
return s.buildChain().Model(value)
|
||||
}
|
||||
|
||||
func (s *DB) Table(name string) *Chain {
|
||||
return s.buildORM().Table(name)
|
||||
return s.buildChain().Table(name)
|
||||
}
|
||||
|
||||
func (s *DB) CreateTable(value interface{}) *Chain {
|
||||
return s.buildORM().CreateTable(value)
|
||||
return s.buildChain().CreateTable(value)
|
||||
}
|
||||
|
|
40
model.go
40
model.go
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
type Model struct {
|
||||
Data interface{}
|
||||
data interface{}
|
||||
driver string
|
||||
}
|
||||
|
||||
|
@ -25,23 +25,23 @@ type Field struct {
|
|||
}
|
||||
|
||||
func (m *Model) primaryKeyZero() bool {
|
||||
return m.primaryKeyValue() == 0
|
||||
return m.primaryKeyValue() <= 0
|
||||
}
|
||||
|
||||
func (m *Model) primaryKeyValue() int64 {
|
||||
if m.Data == nil {
|
||||
return 0
|
||||
if m.data == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(m.Data).Elem()
|
||||
t := reflect.TypeOf(m.data).Elem()
|
||||
switch t.Kind() {
|
||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
return 0
|
||||
default:
|
||||
result := reflect.ValueOf(m.Data).Elem()
|
||||
result := reflect.ValueOf(m.data).Elem()
|
||||
value := result.FieldByName(m.primaryKey())
|
||||
if value.IsValid() {
|
||||
return result.FieldByName(m.primaryKey()).Interface().(int64)
|
||||
return value.Interface().(int64)
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func (m *Model) primaryKeyDb() string {
|
|||
}
|
||||
|
||||
func (m *Model) fields(operation string) (fields []Field) {
|
||||
typ := reflect.TypeOf(m.Data).Elem()
|
||||
typ := reflect.TypeOf(m.data).Elem()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
p := typ.Field(i)
|
||||
|
@ -68,18 +68,16 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
||||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||
value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name)
|
||||
value := reflect.ValueOf(m.data).Elem().FieldByName(p.Name)
|
||||
|
||||
switch operation {
|
||||
case "create":
|
||||
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
|
||||
value = reflect.ValueOf(time.Now())
|
||||
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
case "update":
|
||||
if field.AutoUpdateTime {
|
||||
value = reflect.ValueOf(time.Now())
|
||||
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
|
||||
value.Set(reflect.ValueOf(time.Now()))
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
@ -107,12 +105,12 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
|||
}
|
||||
|
||||
func (m *Model) tableName() (str string, err error) {
|
||||
if m.Data == nil {
|
||||
if m.data == nil {
|
||||
err = errors.New("Model haven't been set")
|
||||
return
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(m.Data)
|
||||
t := reflect.TypeOf(m.data)
|
||||
for {
|
||||
c := false
|
||||
switch t.Kind() {
|
||||
|
@ -138,11 +136,11 @@ func (m *Model) tableName() (str string, err error) {
|
|||
}
|
||||
|
||||
func (m *Model) callMethod(method string) error {
|
||||
if m.Data == nil {
|
||||
if m.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
||||
fm := reflect.ValueOf(m.data).MethodByName(method)
|
||||
if fm.IsValid() {
|
||||
v := fm.Call([]reflect.Value{})
|
||||
if len(v) > 0 {
|
||||
|
@ -154,13 +152,13 @@ func (m *Model) callMethod(method string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (model *Model) missingColumns() (results []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (model *Model) returningStr() (str string) {
|
||||
if model.driver == "postgres" {
|
||||
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (model *Model) missingColumns() (results []string) {
|
||||
return
|
||||
}
|
||||
|
|
11
utils.go
11
utils.go
|
@ -7,17 +7,6 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
func (s *Do) quote(value string) string {
|
||||
return "\"" + value + "\""
|
||||
}
|
||||
|
||||
func (s *Do) quoteMap(values []string) (results []string) {
|
||||
for _, value := range values {
|
||||
results = append(results, s.quote(value))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func toSnake(s string) string {
|
||||
buf := bytes.NewBufferString("")
|
||||
for i, v := range s {
|
||||
|
|
Loading…
Reference in New Issue