mirror of https://github.com/go-gorm/gorm.git
Cleanup code
This commit is contained in:
parent
84b280c0ff
commit
bc785a9173
|
@ -217,10 +217,11 @@ db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").First(&us
|
||||||
## TODO
|
## TODO
|
||||||
* Soft Delete
|
* Soft Delete
|
||||||
* Query with map or struct
|
* Query with map or struct
|
||||||
|
* SubStruct
|
||||||
|
* Index, Unique, Valiations
|
||||||
|
* Auto Migration
|
||||||
* FindOrInitialize / FindOrCreate
|
* FindOrInitialize / FindOrCreate
|
||||||
* SQL Log
|
* SQL Log
|
||||||
* Auto Migration
|
|
||||||
* Index, Unique, Valiations
|
|
||||||
* SQL Query with goroutines
|
* SQL Query with goroutines
|
||||||
* Only tested with postgres, confirm works with other database adaptors
|
* Only tested with postgres, confirm works with other database adaptors
|
||||||
|
|
||||||
|
|
15
chain.go
15
chain.go
|
@ -26,11 +26,12 @@ type Chain struct {
|
||||||
specifiedTableName string
|
specifiedTableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) err(err error) {
|
func (s *Chain) err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Errors = append(s.Errors, err)
|
s.Errors = append(s.Errors, err)
|
||||||
s.Error = err
|
s.Error = err
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) do(value interface{}) *Do {
|
func (s *Chain) do(value interface{}) *Do {
|
||||||
|
@ -136,18 +137,16 @@ func (s *Chain) Update(column string, value interface{}) *Chain {
|
||||||
return s.Updates(map[string]interface{}{column: value}, true)
|
return s.Updates(map[string]interface{}{column: value}, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...interface{}) *Chain {
|
func (s *Chain) Updates(values map[string]interface{}, ignore_protected_attrs ...bool) *Chain {
|
||||||
do := s.do(s.value)
|
do := s.do(s.value)
|
||||||
do.updateAttrs = values
|
do.updateAttrs = values
|
||||||
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0
|
do.ignoreProtectedAttrs = len(ignore_protected_attrs) > 0 && ignore_protected_attrs[0]
|
||||||
do.update()
|
do.update()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Exec(sql string) *Chain {
|
func (s *Chain) Exec(sql string) *Chain {
|
||||||
var err error
|
s.do(nil).exec(sql)
|
||||||
_, err = s.db.Exec(sql)
|
|
||||||
s.err(err)
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,9 +163,7 @@ func (s *Chain) Find(out interface{}, where ...interface{}) *Chain {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
|
func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) {
|
||||||
do := s.do(s.value)
|
s.do(s.value).pluck(column, value)
|
||||||
do.selectStr = column
|
|
||||||
do.pluck(value)
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
166
do.go
166
do.go
|
@ -21,10 +21,9 @@ type Do struct {
|
||||||
|
|
||||||
model *Model
|
model *Model
|
||||||
value interface{}
|
value interface{}
|
||||||
SqlResult sql.Result
|
sqlResult sql.Result
|
||||||
|
sql string
|
||||||
Sql string
|
sqlVars []interface{}
|
||||||
SqlVars []interface{}
|
|
||||||
|
|
||||||
whereClause []map[string]interface{}
|
whereClause []map[string]interface{}
|
||||||
orClause []map[string]interface{}
|
orClause []map[string]interface{}
|
||||||
|
@ -32,7 +31,6 @@ type Do struct {
|
||||||
orderStrs []string
|
orderStrs []string
|
||||||
offsetStr string
|
offsetStr string
|
||||||
limitStr string
|
limitStr string
|
||||||
operation string
|
|
||||||
|
|
||||||
updateAttrs map[string]interface{}
|
updateAttrs map[string]interface{}
|
||||||
ignoreProtectedAttrs bool
|
ignoreProtectedAttrs bool
|
||||||
|
@ -40,17 +38,21 @@ type Do struct {
|
||||||
|
|
||||||
func (s *Do) tableName() string {
|
func (s *Do) tableName() string {
|
||||||
if s.specifiedTableName == "" {
|
if s.specifiedTableName == "" {
|
||||||
|
var err error
|
||||||
|
s.guessedTableName, err = s.model.tableName()
|
||||||
|
s.err(err)
|
||||||
return s.guessedTableName
|
return s.guessedTableName
|
||||||
} else {
|
} else {
|
||||||
return s.specifiedTableName
|
return s.specifiedTableName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) err(err error) {
|
func (s *Do) err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Errors = append(s.Errors, err)
|
s.Errors = append(s.Errors, err)
|
||||||
s.chain.err(err)
|
s.chain.err(err)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) hasError() bool {
|
func (s *Do) hasError() bool {
|
||||||
|
@ -58,18 +60,13 @@ func (s *Do) hasError() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) setModel(value interface{}) {
|
func (s *Do) setModel(value interface{}) {
|
||||||
|
s.model = &Model{data: value, driver: s.driver}
|
||||||
s.value = value
|
s.value = value
|
||||||
s.model = &Model{Data: value, driver: s.driver}
|
|
||||||
var err error
|
|
||||||
if s.specifiedTableName == "" {
|
|
||||||
s.guessedTableName, err = s.model.tableName()
|
|
||||||
s.err(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) addToVars(value interface{}) string {
|
func (s *Do) addToVars(value interface{}) string {
|
||||||
s.SqlVars = append(s.SqlVars, value)
|
s.sqlVars = append(s.sqlVars, value)
|
||||||
return fmt.Sprintf("$%d", len(s.SqlVars))
|
return fmt.Sprintf("$%d", len(s.sqlVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) exec(sql ...string) {
|
func (s *Do) exec(sql ...string) {
|
||||||
|
@ -79,23 +76,23 @@ func (s *Do) exec(sql ...string) {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if len(sql) == 0 {
|
if len(sql) == 0 {
|
||||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||||
} else {
|
} else {
|
||||||
s.SqlResult, err = s.db.Exec(sql[0])
|
s.sqlResult, err = s.db.Exec(sql[0])
|
||||||
}
|
}
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) save() *Do {
|
func (s *Do) save() {
|
||||||
if s.model.primaryKeyZero() {
|
if s.model.primaryKeyZero() {
|
||||||
s.create()
|
s.create()
|
||||||
} else {
|
} else {
|
||||||
s.update()
|
s.update()
|
||||||
}
|
}
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareCreateSql() *Do {
|
func (s *Do) prepareCreateSql() {
|
||||||
var sqls, columns []string
|
var sqls, columns []string
|
||||||
|
|
||||||
for key, value := range s.model.columnsAndValues("create") {
|
for key, value := range s.model.columnsAndValues("create") {
|
||||||
|
@ -103,44 +100,47 @@ func (s *Do) prepareCreateSql() *Do {
|
||||||
sqls = append(sqls, s.addToVars(value))
|
sqls = append(sqls, s.addToVars(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
"INSERT INTO \"%v\" (%v) VALUES (%v) %v",
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
strings.Join(s.quoteMap(columns), ","),
|
strings.Join(columns, ","),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
s.model.returningStr(),
|
s.model.returningStr(),
|
||||||
)
|
)
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) create() *Do {
|
func (s *Do) create() {
|
||||||
s.err(s.model.callMethod("BeforeCreate"))
|
s.err(s.model.callMethod("BeforeCreate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
|
|
||||||
s.prepareCreateSql()
|
s.prepareCreateSql()
|
||||||
|
|
||||||
if len(s.Errors) == 0 {
|
if !s.hasError() {
|
||||||
var id int64
|
var id int64
|
||||||
if s.driver == "postgres" {
|
if s.driver == "postgres" {
|
||||||
s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
|
s.err(s.db.QueryRow(s.sql, s.sqlVars...).Scan(&id))
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
s.sqlResult, err = s.db.Exec(s.sql, s.sqlVars...)
|
||||||
s.err(err)
|
s.err(err)
|
||||||
id, err = s.SqlResult.LastInsertId()
|
id, err = s.sqlResult.LastInsertId()
|
||||||
s.err(err)
|
s.err(err)
|
||||||
}
|
}
|
||||||
result := reflect.ValueOf(s.model.Data).Elem()
|
|
||||||
|
if !s.hasError() {
|
||||||
|
result := reflect.ValueOf(s.value).Elem()
|
||||||
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
result.FieldByName(s.model.primaryKey()).SetInt(id)
|
||||||
|
|
||||||
s.err(s.model.callMethod("AfterCreate"))
|
s.err(s.model.callMethod("AfterCreate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareUpdateSql() *Do {
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Do) prepareUpdateSql() {
|
||||||
update_attrs := s.updateAttrs
|
update_attrs := s.updateAttrs
|
||||||
if len(update_attrs) == 0 {
|
if len(update_attrs) == 0 {
|
||||||
update_attrs = s.model.columnsAndValues("update")
|
update_attrs = s.model.columnsAndValues("update")
|
||||||
|
@ -148,46 +148,55 @@ func (s *Do) prepareUpdateSql() *Do {
|
||||||
|
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range update_attrs {
|
for key, value := range update_attrs {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"UPDATE %v SET %v %v",
|
"UPDATE %v SET %v %v",
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
strings.Join(sqls, ", "),
|
strings.Join(sqls, ", "),
|
||||||
s.combinedSql(),
|
s.combinedSql(),
|
||||||
)
|
)
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) update() *Do {
|
func (s *Do) update() {
|
||||||
s.err(s.model.callMethod("BeforeUpdate"))
|
s.err(s.model.callMethod("BeforeUpdate"))
|
||||||
s.err(s.model.callMethod("BeforeSave"))
|
s.err(s.model.callMethod("BeforeSave"))
|
||||||
if len(s.Errors) == 0 {
|
|
||||||
s.prepareUpdateSql().exec()
|
s.prepareUpdateSql()
|
||||||
}
|
if !s.hasError() {
|
||||||
|
s.exec()
|
||||||
|
|
||||||
|
if !s.hasError() {
|
||||||
s.err(s.model.callMethod("AfterUpdate"))
|
s.err(s.model.callMethod("AfterUpdate"))
|
||||||
s.err(s.model.callMethod("AfterSave"))
|
s.err(s.model.callMethod("AfterSave"))
|
||||||
return s
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareDeleteSql() *Do {
|
func (s *Do) prepareDeleteSql() {
|
||||||
s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
s.sql = fmt.Sprintf("DELETE FROM %v %v", s.tableName(), s.combinedSql())
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) delete() *Do {
|
func (s *Do) delete() {
|
||||||
s.err(s.model.callMethod("BeforeDelete"))
|
s.err(s.model.callMethod("BeforeDelete"))
|
||||||
if len(s.Errors) == 0 {
|
|
||||||
s.prepareDeleteSql().exec()
|
s.prepareDeleteSql()
|
||||||
}
|
if !s.hasError() {
|
||||||
|
s.exec()
|
||||||
|
if !s.hasError() {
|
||||||
s.err(s.model.callMethod("AfterDelete"))
|
s.err(s.model.callMethod("AfterDelete"))
|
||||||
return s
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) prepareQuerySql() *Do {
|
func (s *Do) prepareQuerySql() {
|
||||||
s.Sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
s.sql = fmt.Sprintf("SELECT %v FROM %v %v", s.selectSql(), s.tableName(), s.combinedSql())
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) query(where ...interface{}) {
|
func (s *Do) query(where ...interface{}) {
|
||||||
|
@ -207,11 +216,9 @@ func (s *Do) query(where ...interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
|
if !s.hasError() {
|
||||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||||
s.err(err)
|
if s.err(err) != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,33 +257,39 @@ func (s *Do) query(where ...interface{}) {
|
||||||
s.err(errors.New("Record not found!"))
|
s.err(errors.New("Record not found!"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Do) count(value interface{}) {
|
func (s *Do) count(value interface{}) {
|
||||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
if !s.hasError() {
|
||||||
s.err(err)
|
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||||
|
if s.err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dest int64
|
var dest int64
|
||||||
s.err(rows.Scan(&dest))
|
if s.err(rows.Scan(&dest)) == nil {
|
||||||
dest_out.Set(reflect.ValueOf(dest))
|
dest_out.Set(reflect.ValueOf(dest))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) pluck(value interface{}) *Do {
|
func (s *Do) pluck(column string, value interface{}) {
|
||||||
if s.hasError() {
|
s.selectStr = column
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||||
dest_type := dest_out.Type().Elem()
|
dest_type := dest_out.Type().Elem()
|
||||||
s.prepareQuerySql()
|
s.prepareQuerySql()
|
||||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
|
||||||
s.err(err)
|
if !s.hasError() {
|
||||||
if err != nil {
|
rows, err := s.db.Query(s.sql, s.sqlVars...)
|
||||||
return s
|
if s.err(err) != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
@ -293,16 +306,17 @@ func (s *Do) pluck(value interface{}) *Do {
|
||||||
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
dest_out.Set(reflect.Append(dest_out, reflect.ValueOf(dest)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) where(querystring interface{}, args ...interface{}) *Do {
|
func (s *Do) where(querystring interface{}, args ...interface{}) {
|
||||||
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args})
|
||||||
return s
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) primaryCondiation(value interface{}) string {
|
func (s *Do) primaryCondiation(value interface{}) string {
|
||||||
return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
|
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
|
@ -324,17 +338,11 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
switch reflect.TypeOf(arg).Kind() {
|
switch reflect.TypeOf(arg).Kind() {
|
||||||
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
case reflect.Slice: // For where("id in (?)", []int64{1,2})
|
||||||
v := reflect.ValueOf(arg)
|
v := reflect.ValueOf(arg)
|
||||||
|
|
||||||
var temp_marks []string
|
var temp_marks []string
|
||||||
for i := 0; i < v.Len(); i++ {
|
for i := 0; i < v.Len(); i++ {
|
||||||
temp_marks = append(temp_marks, "?")
|
temp_marks = append(temp_marks, s.addToVars(v.Index(i).Addr().Interface()))
|
||||||
}
|
}
|
||||||
|
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
|
|
||||||
for i := 0; i < v.Len(); i++ {
|
|
||||||
str = strings.Replace(str, "?", s.addToVars(v.Index(i).Addr().Interface()), 1)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||||
}
|
}
|
||||||
|
@ -421,7 +429,7 @@ func (s *Do) createTable() *Do {
|
||||||
for _, field := range s.model.fields("null") {
|
for _, field := range s.model.fields("null") {
|
||||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||||
}
|
}
|
||||||
s.Sql = fmt.Sprintf(
|
s.sql = fmt.Sprintf(
|
||||||
"CREATE TABLE \"%v\" (%v)",
|
"CREATE TABLE \"%v\" (%v)",
|
||||||
s.tableName(),
|
s.tableName(),
|
||||||
strings.Join(sqls, ","),
|
strings.Join(sqls, ","),
|
||||||
|
|
15
gorm_test.go
15
gorm_test.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -41,15 +42,23 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
var err error
|
||||||
|
db, err = Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
|
||||||
|
}
|
||||||
db.SetPool(10)
|
db.SetPool(10)
|
||||||
|
|
||||||
db.Exec("drop table users;")
|
err = db.Exec("drop table users;").Error
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Got error when try to delete table uses, %+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
db.Exec("drop table products;")
|
db.Exec("drop table products;")
|
||||||
|
|
||||||
orm := db.CreateTable(&User{})
|
orm := db.CreateTable(&User{})
|
||||||
if orm.Error != nil {
|
if orm.Error != nil {
|
||||||
panic("No error should raise when create table")
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", orm.Error))
|
||||||
}
|
}
|
||||||
|
|
||||||
db.CreateTable(&Product{})
|
db.CreateTable(&Product{})
|
||||||
|
|
28
main.go
28
main.go
|
@ -17,58 +17,58 @@ func (s *DB) SetPool(n int) {
|
||||||
s.db.SetMaxIdleConns(n)
|
s.db.SetMaxIdleConns(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) buildORM() *Chain {
|
func (s *DB) buildChain() *Chain {
|
||||||
return &Chain{db: s.db, driver: s.driver}
|
return &Chain{db: s.db, driver: s.driver}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain {
|
||||||
return s.buildORM().Where(querystring, args...)
|
return s.buildChain().Where(querystring, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) First(out interface{}, where ...interface{}) *Chain {
|
func (s *DB) First(out interface{}, where ...interface{}) *Chain {
|
||||||
return s.buildORM().First(out, where...)
|
return s.buildChain().First(out, where...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *Chain {
|
func (s *DB) Find(out interface{}, where ...interface{}) *Chain {
|
||||||
return s.buildORM().Find(out, where...)
|
return s.buildChain().Find(out, where...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Limit(value interface{}) *Chain {
|
func (s *DB) Limit(value interface{}) *Chain {
|
||||||
return s.buildORM().Limit(value)
|
return s.buildChain().Limit(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Offset(value interface{}) *Chain {
|
func (s *DB) Offset(value interface{}) *Chain {
|
||||||
return s.buildORM().Offset(value)
|
return s.buildChain().Offset(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Order(value string, reorder ...bool) *Chain {
|
func (s *DB) Order(value string, reorder ...bool) *Chain {
|
||||||
return s.buildORM().Order(value, reorder...)
|
return s.buildChain().Order(value, reorder...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Select(value interface{}) *Chain {
|
func (s *DB) Select(value interface{}) *Chain {
|
||||||
return s.buildORM().Select(value)
|
return s.buildChain().Select(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Save(value interface{}) *Chain {
|
func (s *DB) Save(value interface{}) *Chain {
|
||||||
return s.buildORM().Save(value)
|
return s.buildChain().Save(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Delete(value interface{}) *Chain {
|
func (s *DB) Delete(value interface{}) *Chain {
|
||||||
return s.buildORM().Delete(value)
|
return s.buildChain().Delete(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Exec(sql string) *Chain {
|
func (s *DB) Exec(sql string) *Chain {
|
||||||
return s.buildORM().Exec(sql)
|
return s.buildChain().Exec(sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Model(value interface{}) *Chain {
|
func (s *DB) Model(value interface{}) *Chain {
|
||||||
return s.buildORM().Model(value)
|
return s.buildChain().Model(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Table(name string) *Chain {
|
func (s *DB) Table(name string) *Chain {
|
||||||
return s.buildORM().Table(name)
|
return s.buildChain().Table(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) CreateTable(value interface{}) *Chain {
|
func (s *DB) CreateTable(value interface{}) *Chain {
|
||||||
return s.buildORM().CreateTable(value)
|
return s.buildChain().CreateTable(value)
|
||||||
}
|
}
|
||||||
|
|
40
model.go
40
model.go
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Data interface{}
|
data interface{}
|
||||||
driver string
|
driver string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,23 +25,23 @@ type Field struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyZero() bool {
|
func (m *Model) primaryKeyZero() bool {
|
||||||
return m.primaryKeyValue() == 0
|
return m.primaryKeyValue() <= 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) primaryKeyValue() int64 {
|
func (m *Model) primaryKeyValue() int64 {
|
||||||
if m.Data == nil {
|
if m.data == nil {
|
||||||
return 0
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
t := reflect.TypeOf(m.Data).Elem()
|
t := reflect.TypeOf(m.data).Elem()
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||||
return 0
|
return 0
|
||||||
default:
|
default:
|
||||||
result := reflect.ValueOf(m.Data).Elem()
|
result := reflect.ValueOf(m.data).Elem()
|
||||||
value := result.FieldByName(m.primaryKey())
|
value := result.FieldByName(m.primaryKey())
|
||||||
if value.IsValid() {
|
if value.IsValid() {
|
||||||
return result.FieldByName(m.primaryKey()).Interface().(int64)
|
return value.Interface().(int64)
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ func (m *Model) primaryKeyDb() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) fields(operation string) (fields []Field) {
|
func (m *Model) fields(operation string) (fields []Field) {
|
||||||
typ := reflect.TypeOf(m.Data).Elem()
|
typ := reflect.TypeOf(m.data).Elem()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
p := typ.Field(i)
|
p := typ.Field(i)
|
||||||
|
@ -68,18 +68,16 @@ func (m *Model) fields(operation string) (fields []Field) {
|
||||||
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
field.IsPrimaryKey = m.primaryKeyDb() == field.DbName
|
||||||
field.AutoCreateTime = "created_at" == field.DbName
|
field.AutoCreateTime = "created_at" == field.DbName
|
||||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||||
value := reflect.ValueOf(m.Data).Elem().FieldByName(p.Name)
|
value := reflect.ValueOf(m.data).Elem().FieldByName(p.Name)
|
||||||
|
|
||||||
switch operation {
|
switch operation {
|
||||||
case "create":
|
case "create":
|
||||||
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
|
if (field.AutoCreateTime || field.AutoUpdateTime) && value.Interface().(time.Time).IsZero() {
|
||||||
value = reflect.ValueOf(time.Now())
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
|
|
||||||
}
|
}
|
||||||
case "update":
|
case "update":
|
||||||
if field.AutoUpdateTime {
|
if field.AutoUpdateTime {
|
||||||
value = reflect.ValueOf(time.Now())
|
value.Set(reflect.ValueOf(time.Now()))
|
||||||
reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Set(value)
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
@ -107,12 +105,12 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) tableName() (str string, err error) {
|
func (m *Model) tableName() (str string, err error) {
|
||||||
if m.Data == nil {
|
if m.data == nil {
|
||||||
err = errors.New("Model haven't been set")
|
err = errors.New("Model haven't been set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
t := reflect.TypeOf(m.Data)
|
t := reflect.TypeOf(m.data)
|
||||||
for {
|
for {
|
||||||
c := false
|
c := false
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
|
@ -138,11 +136,11 @@ func (m *Model) tableName() (str string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) callMethod(method string) error {
|
func (m *Model) callMethod(method string) error {
|
||||||
if m.Data == nil {
|
if m.data == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fm := reflect.ValueOf(m.Data).MethodByName(method)
|
fm := reflect.ValueOf(m.data).MethodByName(method)
|
||||||
if fm.IsValid() {
|
if fm.IsValid() {
|
||||||
v := fm.Call([]reflect.Value{})
|
v := fm.Call([]reflect.Value{})
|
||||||
if len(v) > 0 {
|
if len(v) > 0 {
|
||||||
|
@ -154,13 +152,13 @@ func (m *Model) callMethod(method string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (model *Model) missingColumns() (results []string) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (model *Model) returningStr() (str string) {
|
func (model *Model) returningStr() (str string) {
|
||||||
if model.driver == "postgres" {
|
if model.driver == "postgres" {
|
||||||
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
|
str = fmt.Sprintf("RETURNING \"%v\"", model.primaryKeyDb())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (model *Model) missingColumns() (results []string) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
11
utils.go
11
utils.go
|
@ -7,17 +7,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Do) quote(value string) string {
|
|
||||||
return "\"" + value + "\""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Do) quoteMap(values []string) (results []string) {
|
|
||||||
for _, value := range values {
|
|
||||||
results = append(results, s.quote(value))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func toSnake(s string) string {
|
func toSnake(s string) string {
|
||||||
buf := bytes.NewBufferString("")
|
buf := bytes.NewBufferString("")
|
||||||
for i, v := range s {
|
for i, v := range s {
|
||||||
|
|
Loading…
Reference in New Issue