Merge branch 'master' into refactor

This commit is contained in:
Jinzhu 2015-02-17 08:38:32 +08:00
commit 139b9a37e7
14 changed files with 183 additions and 199 deletions

View File

@ -9,10 +9,10 @@ type callback struct {
updates []*func(scope *Scope) updates []*func(scope *Scope)
deletes []*func(scope *Scope) deletes []*func(scope *Scope)
queries []*func(scope *Scope) queries []*func(scope *Scope)
processors []*callback_processor processors []*callbackProcessor
} }
type callback_processor struct { type callbackProcessor struct {
name string name string
before string before string
after string after string
@ -23,8 +23,8 @@ type callback_processor struct {
callback *callback callback *callback
} }
func (c *callback) addProcessor(typ string) *callback_processor { func (c *callback) addProcessor(typ string) *callbackProcessor {
cp := &callback_processor{typ: typ, callback: c} cp := &callbackProcessor{typ: typ, callback: c}
c.processors = append(c.processors, cp) c.processors = append(c.processors, cp)
return cp return cp
} }
@ -33,46 +33,46 @@ func (c *callback) clone() *callback {
return &callback{processors: c.processors} return &callback{processors: c.processors}
} }
func (c *callback) Create() *callback_processor { func (c *callback) Create() *callbackProcessor {
return c.addProcessor("create") return c.addProcessor("create")
} }
func (c *callback) Update() *callback_processor { func (c *callback) Update() *callbackProcessor {
return c.addProcessor("update") return c.addProcessor("update")
} }
func (c *callback) Delete() *callback_processor { func (c *callback) Delete() *callbackProcessor {
return c.addProcessor("delete") return c.addProcessor("delete")
} }
func (c *callback) Query() *callback_processor { func (c *callback) Query() *callbackProcessor {
return c.addProcessor("query") return c.addProcessor("query")
} }
func (cp *callback_processor) Before(name string) *callback_processor { func (cp *callbackProcessor) Before(name string) *callbackProcessor {
cp.before = name cp.before = name
return cp return cp
} }
func (cp *callback_processor) After(name string) *callback_processor { func (cp *callbackProcessor) After(name string) *callbackProcessor {
cp.after = name cp.after = name
return cp return cp
} }
func (cp *callback_processor) Register(name string, fc func(scope *Scope)) { func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
cp.name = name cp.name = name
cp.processor = &fc cp.processor = &fc
cp.callback.sort() cp.callback.sort()
} }
func (cp *callback_processor) Remove(name string) { func (cp *callbackProcessor) Remove(name string) {
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum()) fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name cp.name = name
cp.remove = true cp.remove = true
cp.callback.sort() cp.callback.sort()
} }
func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) { func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum()) fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name cp.name = name
cp.processor = &fc cp.processor = &fc
@ -89,8 +89,8 @@ func getRIndex(strs []string, str string) int {
return -1 return -1
} }
func sortProcessors(cps []*callback_processor) []*func(scope *Scope) { func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callback_processor) var sortCallbackProcessor func(c *callbackProcessor)
var names, sortedNames = []string{}, []string{} var names, sortedNames = []string{}, []string{}
for _, cp := range cps { for _, cp := range cps {
@ -102,7 +102,7 @@ func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
names = append(names, cp.name) names = append(names, cp.name)
} }
sortCallbackProcessor = func(c *callback_processor) { sortCallbackProcessor = func(c *callbackProcessor) {
if getRIndex(sortedNames, c.name) > -1 { if getRIndex(sortedNames, c.name) > -1 {
return return
} }
@ -162,7 +162,7 @@ func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
} }
func (c *callback) sort() { func (c *callback) sort() {
creates, updates, deletes, queries := []*callback_processor{}, []*callback_processor{}, []*callback_processor{}, []*callback_processor{} creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}
for _, processor := range c.processors { for _, processor := range c.processors {
switch processor.typ { switch processor.typ {
@ -183,4 +183,4 @@ func (c *callback) sort() {
c.queries = sortProcessors(queries) c.queries = sortProcessors(queries)
} }
var DefaultCallback = &callback{processors: []*callback_processor{}} var DefaultCallback = &callback{processors: []*callbackProcessor{}}

View File

@ -50,7 +50,7 @@ func Query(scope *Scope) {
columns, _ := rows.Columns() columns, _ := rows.Columns()
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
scope.db.RowsAffected += 1 scope.db.RowsAffected++
anyRecordFound = true anyRecordFound = true
elem := dest elem := dest

View File

@ -17,96 +17,96 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
} }
func create(s *Scope) {} func create(s *Scope) {}
func before_create1(s *Scope) {} func beforeCreate1(s *Scope) {}
func before_create2(s *Scope) {} func beforeCreate2(s *Scope) {}
func after_create1(s *Scope) {} func afterCreate1(s *Scope) {}
func after_create2(s *Scope) {} func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) { func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}} var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Register("before_create1", before_create1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", before_create2) callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create) callback.Create().Register("create", create)
callback.Create().Register("after_create1", after_create1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", after_create2) callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback") t.Errorf("register callback")
} }
} }
func TestRegisterCallbackWithOrder(t *testing.T) { func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &callback{processors: []*callback_processor{}} var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Create().Register("before_create1", before_create1) callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create) callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", after_create1) callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", after_create2) callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"before_create1", "create", "after_create2", "after_create1"}) { if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &callback{processors: []*callback_processor{}} var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Update().Register("create", create) callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", before_create1) callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", after_create1) callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", before_create2) callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", after_create2) callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"before_create2", "before_create1", "create", "after_create2", "after_create1"}) { if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
} }
func TestRegisterCallbackWithComplexOrder(t *testing.T) { func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &callback{processors: []*callback_processor{}} var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", before_create1) callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", after_create1) callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"before_create1", "create", "after_create1"}) { if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
var callback2 = &callback{processors: []*callback_processor{}} var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", before_create1) callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", before_create2) callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", after_create1) callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", after_create2) callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) { if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
} }
func replace_create(s *Scope) {} func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) { func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}} var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", before_create1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", after_create1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replace_create) callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"before_create1", "replace_create", "after_create1"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback") t.Errorf("replace callback")
} }
} }
func TestRemoveCallback(t *testing.T) { func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}} var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", before_create1) callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", after_create1) callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create") callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"before_create1", "after_create1"}) { if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback") t.Errorf("remove callback")
} }
} }

View File

@ -20,7 +20,7 @@ func (s *commonDialect) HasTop() bool {
return false return false
} }
func (d *commonDialect) SqlTag(value reflect.Value, size int) string { func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "BOOLEAN" return "BOOLEAN"
@ -33,9 +33,8 @@ func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
case reflect.String: case reflect.String:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("VARCHAR(%d)", size) return fmt.Sprintf("VARCHAR(%d)", size)
} else {
return "VARCHAR(65532)"
} }
return "VARCHAR(65532)"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if value.Type() == timeType {
return "TIMESTAMP" return "TIMESTAMP"
@ -44,21 +43,20 @@ func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
if _, ok := value.Interface().([]byte); ok { if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size) return fmt.Sprintf("BINARY(%d)", size)
} else {
return "BINARY(65532)"
} }
return "BINARY(65532)"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
} }
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string { func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " NOT NULL PRIMARY KEY" suffix := " NOT NULL PRIMARY KEY"
switch value.Kind() { switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "INTEGER" + suffix_str return "INTEGER" + suffix
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
return "BIGINT" + suffix_str return "BIGINT" + suffix
default: default:
panic("Invalid primary key type") panic("Invalid primary key type")
} }

View File

@ -3,9 +3,9 @@ package gorm
import "errors" import "errors"
var ( var (
RecordNotFound = errors.New("Record Not Found") RecordNotFound = errors.New("record not found")
InvalidSql = errors.New("Invalid SQL") InvalidSql = errors.New("invalid sql")
NoNewAttrs = errors.New("No new Attributes") NoNewAttrs = errors.New("no new attributes")
NoValidTransaction = errors.New("No valid transaction") NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("Can't start transaction") CantStartTransaction = errors.New("can't start transaction")
) )

View File

@ -265,9 +265,8 @@ func (s *DB) Save(value interface{}) *DB {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() { if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db return scope.callCallbacks(s.parent.callback.creates).db
} else {
return scope.callCallbacks(s.parent.callback.updates).db
} }
return scope.callCallbacks(s.parent.callback.updates).db
} }
func (s *DB) Create(value interface{}) *DB { func (s *DB) Create(value interface{}) *DB {

View File

@ -588,10 +588,10 @@ func BenchmarkGorm(b *testing.B) {
func BenchmarkRawSql(b *testing.B) { func BenchmarkRawSql(b *testing.B) {
DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
DB.SetMaxIdleConns(10) DB.SetMaxIdleConns(10)
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
delete_sql := "DELETE FROM orders WHERE id = $1" deleteSql := "DELETE FROM orders WHERE id = $1"
b.N = 2000 b.N = 2000
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
@ -599,13 +599,13 @@ func BenchmarkRawSql(b *testing.B) {
e := strconv.Itoa(x) + "benchmark@example.org" e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert // Insert
DB.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// Query // Query
rows, _ := DB.Query(query_sql, email.Email) rows, _ := DB.Query(querySql, email.Email)
rows.Close() rows.Close()
// Update // Update
DB.Exec(update_sql, "new-"+e, time.Now(), id) DB.Exec(updateSql, "new-"+e, time.Now(), id)
// Delete // Delete
DB.Exec(delete_sql, id) DB.Exec(deleteSql, id)
} }
} }

View File

@ -20,7 +20,7 @@ func (s *mssql) HasTop() bool {
return true return true
} }
func (d *mssql) SqlTag(value reflect.Value, size int) string { func (s *mssql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "bit" return "bit"
@ -33,9 +33,8 @@ func (d *mssql) SqlTag(value reflect.Value, size int) string {
case reflect.String: case reflect.String:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("nvarchar(%d)", size) return fmt.Sprintf("nvarchar(%d)", size)
} else {
return "text"
} }
return "text"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if value.Type() == timeType {
return "datetime2" return "datetime2"
@ -44,21 +43,20 @@ func (d *mssql) SqlTag(value reflect.Value, size int) string {
if _, ok := value.Interface().([]byte); ok { if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
} }
return "text"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
} }
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string { func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " IDENTITY(1,1) PRIMARY KEY" suffix := " IDENTITY(1,1) PRIMARY KEY"
switch value.Kind() { switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str return "int" + suffix
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str return "bigint" + suffix
default: default:
panic("Invalid primary key type") panic("Invalid primary key type")
} }

View File

@ -20,7 +20,7 @@ func (s *mysql) HasTop() bool {
return false return false
} }
func (d *mysql) SqlTag(value reflect.Value, size int) string { func (s *mysql) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
@ -33,9 +33,8 @@ func (d *mysql) SqlTag(value reflect.Value, size int) string {
case reflect.String: case reflect.String:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else {
return "longtext"
} }
return "longtext"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if value.Type() == timeType {
return "timestamp NULL" return "timestamp NULL"
@ -44,21 +43,20 @@ func (d *mysql) SqlTag(value reflect.Value, size int) string {
if _, ok := value.Interface().([]byte); ok { if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size) return fmt.Sprintf("varbinary(%d)", size)
} else {
return "longblob"
} }
return "longblob"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
} }
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string { func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch value.Kind() { switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix_str return "int" + suffix
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
return "bigint" + suffix_str return "bigint" + suffix
default: default:
panic("Invalid primary key type") panic("Invalid primary key type")
} }

View File

@ -24,7 +24,7 @@ func (s *postgres) HasTop() bool {
return false return false
} }
func (d *postgres) SqlTag(value reflect.Value, size int) string { func (s *postgres) SqlTag(value reflect.Value, size int) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"

View File

@ -39,7 +39,7 @@ func (db *DB) NewScope(value interface{}) *Scope {
func (scope *Scope) NeedPtr() *Scope { func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind() reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
err := errors.New(fmt.Sprintf("%v %v\n", fileWithLineNum(), "using unaddressable value")) err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
scope.Err(err) scope.Err(err)
fmt.Printf(err.Error()) fmt.Printf(err.Error())
} }
@ -102,9 +102,8 @@ func (scope *Scope) PrimaryKeyField() *Field {
func (scope *Scope) PrimaryKey() string { func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryKeyField(); field != nil { if field := scope.PrimaryKeyField(); field != nil {
return field.DBName return field.DBName
} else {
return ""
} }
return ""
} }
// PrimaryKeyZero check the primary key is blank or not // PrimaryKeyZero check the primary key is blank or not
@ -116,9 +115,8 @@ func (scope *Scope) PrimaryKeyZero() bool {
func (scope *Scope) PrimaryKeyValue() interface{} { func (scope *Scope) PrimaryKeyValue() interface{} {
if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() { if field := scope.PrimaryKeyField(); field != nil && field.Field.IsValid() {
return field.Field.Interface() return field.Field.Interface()
} else {
return 0
} }
return 0
} }
// HasColumn to check if has column // HasColumn to check if has column
@ -184,7 +182,7 @@ func (scope *Scope) CallMethod(name string) {
case func(s *DB) error: case func(s *DB) error:
scope.Err(f(scope.db.New())) scope.Err(f(scope.db.New()))
default: default:
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name))) scope.Err(fmt.Errorf("unsupported function %v", name))
} }
} }
} }
@ -209,7 +207,8 @@ func (scope *Scope) AddToVars(value interface{}) string {
func (scope *Scope) TableName() string { func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 { if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName return scope.Search.TableName
} else { }
if scope.Value == nil { if scope.Value == nil {
scope.Err(errors.New("can't get table name")) scope.Err(errors.New("can't get table name"))
return "" return ""
@ -243,19 +242,19 @@ func (scope *Scope) TableName() string {
} }
return str return str
}
} }
func (scope *Scope) QuotedTableName() string { func (scope *Scope) QuotedTableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 { if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName return scope.Search.TableName
} else { }
keys := strings.Split(scope.TableName(), ".") keys := strings.Split(scope.TableName(), ".")
for i, v := range keys { for i, v := range keys {
keys[i] = scope.Quote(v) keys[i] = scope.Quote(v)
} }
return strings.Join(keys, ".") return strings.Join(keys, ".")
}
} }
// CombinedConditionSql get combined condition sql // CombinedConditionSql get combined condition sql

View File

@ -214,87 +214,80 @@ func (scope *Scope) whereSql() (sql string) {
return return
} }
func (s *Scope) selectSql() string { func (scope *Scope) selectSql() string {
if len(s.Search.Selects) == 0 { if len(scope.Search.Selects) == 0 {
return "*" return "*"
} }
var selectQueries []string var selectQueries []string
for _, clause := range s.Search.Selects { for _, clause := range scope.Search.Selects {
selectQueries = append(selectQueries, s.buildSelectQuery(clause)) selectQueries = append(selectQueries, scope.buildSelectQuery(clause))
} }
return strings.Join(selectQueries, ", ") return strings.Join(selectQueries, ", ")
} }
func (s *Scope) orderSql() string { func (scope *Scope) orderSql() string {
if len(s.Search.Orders) == 0 { if len(scope.Search.Orders) == 0 {
return "" return ""
} else {
return " ORDER BY " + strings.Join(s.Search.Orders, ",")
} }
return " ORDER BY " + strings.Join(scope.Search.Orders, ",")
} }
func (s *Scope) limitSql() string { func (scope *Scope) limitSql() string {
if !s.Dialect().HasTop() { if !scope.Dialect().HasTop() {
if len(s.Search.Limit) == 0 { if len(scope.Search.Limit) == 0 {
return ""
} else {
return " LIMIT " + s.Search.Limit
}
} else {
return "" return ""
} }
return " LIMIT " + scope.Search.Limit
}
return ""
} }
func (s *Scope) topSql() string { func (scope *Scope) topSql() string {
if s.Dialect().HasTop() && len(s.Search.Offset) == 0 { if scope.Dialect().HasTop() && len(scope.Search.Offset) == 0 {
if len(s.Search.Limit) == 0 { if len(scope.Search.Limit) == 0 {
return ""
} else {
return " TOP(" + s.Search.Limit + ")"
}
} else {
return "" return ""
} }
return " TOP(" + scope.Search.Limit + ")"
}
return ""
} }
func (s *Scope) offsetSql() string { func (scope *Scope) offsetSql() string {
if len(s.Search.Offset) == 0 { if len(scope.Search.Offset) == 0 {
return "" return ""
} else { }
if s.Dialect().HasTop() {
sql := " OFFSET " + s.Search.Offset + " ROW " if scope.Dialect().HasTop() {
if len(s.Search.Limit) > 0 { sql := " OFFSET " + scope.Search.Offset + " ROW "
sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY" if len(scope.Search.Limit) > 0 {
sql += "FETCH NEXT " + scope.Search.Limit + " ROWS ONLY"
} }
return sql return sql
} else {
return " OFFSET " + s.Search.Offset
}
} }
return " OFFSET " + scope.Search.Offset
} }
func (s *Scope) groupSql() string { func (scope *Scope) groupSql() string {
if len(s.Search.Group) == 0 { if len(scope.Search.Group) == 0 {
return "" return ""
} else {
return " GROUP BY " + s.Search.Group
} }
return " GROUP BY " + scope.Search.Group
} }
func (s *Scope) havingSql() string { func (scope *Scope) havingSql() string {
if s.Search.HavingCondition == nil { if scope.Search.HavingCondition == nil {
return "" return ""
} else {
return " HAVING " + s.buildWhereCondition(s.Search.HavingCondition)
} }
return " HAVING " + scope.buildWhereCondition(scope.Search.HavingCondition)
} }
func (s *Scope) joinsSql() string { func (scope *Scope) joinsSql() string {
return s.Search.Joins + " " return scope.Search.Joins + " "
} }
func (scope *Scope) prepareQuerySql() { func (scope *Scope) prepareQuerySql() {
@ -415,9 +408,9 @@ func (scope *Scope) typeName() string {
value := scope.IndirectValue() value := scope.IndirectValue()
if value.Kind() == reflect.Slice { if value.Kind() == reflect.Slice {
return value.Type().Elem().Name() return value.Type().Elem().Name()
} else {
return value.Type().Name()
} }
return value.Type().Name()
} }
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
@ -557,9 +550,9 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
} }
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table string = scope.TableName() var table = scope.TableName()
var keyName string = fmt.Sprintf("%s_%s_foreign", table, field) var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query string = ` var query = `
ALTER TABLE %s ALTER TABLE %s
ADD CONSTRAINT %s ADD CONSTRAINT %s
FOREIGN KEY (%s) FOREIGN KEY (%s)

View File

@ -32,9 +32,8 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
case reflect.String: case reflect.String:
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} else {
return "text"
} }
return "text"
case reflect.Struct: case reflect.Struct:
if value.Type() == timeType { if value.Type() == timeType {
return "datetime" return "datetime"

View File

@ -34,7 +34,7 @@ func FieldValueByName(name string, value interface{}) (i interface{}, err error)
if field := data.FieldByName(name); field.IsValid() { if field := data.FieldByName(name); field.IsValid() {
i = field.Interface() i = field.Interface()
} else { } else {
return nil, errors.New(fmt.Sprintf("struct has no field with name %s", name)) return nil, fmt.Errorf("struct has no field with name %s", name)
} }
} else { } else {
return nil, errors.New("value must be of kind struct") return nil, errors.New("value must be of kind struct")