forked from mirror/gorm
Merge branch 'adamar-master'
This commit is contained in:
commit
b8976ab20f
36
callback.go
36
callback.go
|
@ -9,10 +9,10 @@ type callback struct {
|
|||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
processors []*callback_processor
|
||||
processors []*callbackProcessor
|
||||
}
|
||||
|
||||
type callback_processor struct {
|
||||
type callbackProcessor struct {
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
|
@ -23,8 +23,8 @@ type callback_processor struct {
|
|||
callback *callback
|
||||
}
|
||||
|
||||
func (c *callback) addProcessor(typ string) *callback_processor {
|
||||
cp := &callback_processor{typ: typ, callback: c}
|
||||
func (c *callback) addProcessor(typ string) *callbackProcessor {
|
||||
cp := &callbackProcessor{typ: typ, callback: c}
|
||||
c.processors = append(c.processors, cp)
|
||||
return cp
|
||||
}
|
||||
|
@ -33,46 +33,46 @@ func (c *callback) clone() *callback {
|
|||
return &callback{processors: c.processors}
|
||||
}
|
||||
|
||||
func (c *callback) Create() *callback_processor {
|
||||
func (c *callback) Create() *callbackProcessor {
|
||||
return c.addProcessor("create")
|
||||
}
|
||||
|
||||
func (c *callback) Update() *callback_processor {
|
||||
func (c *callback) Update() *callbackProcessor {
|
||||
return c.addProcessor("update")
|
||||
}
|
||||
|
||||
func (c *callback) Delete() *callback_processor {
|
||||
func (c *callback) Delete() *callbackProcessor {
|
||||
return c.addProcessor("delete")
|
||||
}
|
||||
|
||||
func (c *callback) Query() *callback_processor {
|
||||
func (c *callback) Query() *callbackProcessor {
|
||||
return c.addProcessor("query")
|
||||
}
|
||||
|
||||
func (cp *callback_processor) Before(name string) *callback_processor {
|
||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
|
||||
cp.before = name
|
||||
return cp
|
||||
}
|
||||
|
||||
func (cp *callback_processor) After(name string) *callback_processor {
|
||||
func (cp *callbackProcessor) After(name string) *callbackProcessor {
|
||||
cp.after = name
|
||||
return cp
|
||||
}
|
||||
|
||||
func (cp *callback_processor) Register(name string, fc func(scope *Scope)) {
|
||||
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
|
||||
cp.name = name
|
||||
cp.processor = &fc
|
||||
cp.callback.sort()
|
||||
}
|
||||
|
||||
func (cp *callback_processor) Remove(name string) {
|
||||
func (cp *callbackProcessor) Remove(name string) {
|
||||
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
|
||||
cp.name = name
|
||||
cp.remove = true
|
||||
cp.callback.sort()
|
||||
}
|
||||
|
||||
func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) {
|
||||
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
|
||||
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
|
||||
cp.name = name
|
||||
cp.processor = &fc
|
||||
|
@ -89,8 +89,8 @@ func getRIndex(strs []string, str string) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
|
||||
var sortCallbackProcessor func(c *callback_processor)
|
||||
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
||||
var sortCallbackProcessor func(c *callbackProcessor)
|
||||
var names, sortedNames = []string{}, []string{}
|
||||
|
||||
for _, cp := range cps {
|
||||
|
@ -102,7 +102,7 @@ func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
|
|||
names = append(names, cp.name)
|
||||
}
|
||||
|
||||
sortCallbackProcessor = func(c *callback_processor) {
|
||||
sortCallbackProcessor = func(c *callbackProcessor) {
|
||||
if getRIndex(sortedNames, c.name) > -1 {
|
||||
return
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
|
|||
}
|
||||
|
||||
func (c *callback) sort() {
|
||||
creates, updates, deletes, queries := []*callback_processor{}, []*callback_processor{}, []*callback_processor{}, []*callback_processor{}
|
||||
creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}
|
||||
|
||||
for _, processor := range c.processors {
|
||||
switch processor.typ {
|
||||
|
@ -183,4 +183,4 @@ func (c *callback) sort() {
|
|||
c.queries = sortProcessors(queries)
|
||||
}
|
||||
|
||||
var DefaultCallback = &callback{processors: []*callback_processor{}}
|
||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
||||
|
|
|
@ -50,7 +50,7 @@ func Query(scope *Scope) {
|
|||
columns, _ := rows.Columns()
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected += 1
|
||||
scope.db.RowsAffected++
|
||||
|
||||
anyRecordFound = true
|
||||
elem := dest
|
||||
|
|
|
@ -16,97 +16,97 @@ func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
|||
return reflect.DeepEqual(names, fnames)
|
||||
}
|
||||
|
||||
func create(s *Scope) {}
|
||||
func before_create1(s *Scope) {}
|
||||
func before_create2(s *Scope) {}
|
||||
func after_create1(s *Scope) {}
|
||||
func after_create2(s *Scope) {}
|
||||
func create(s *Scope) {}
|
||||
func beforeCreate1(s *Scope) {}
|
||||
func beforeCreate2(s *Scope) {}
|
||||
func afterCreate1(s *Scope) {}
|
||||
func afterCreate2(s *Scope) {}
|
||||
|
||||
func TestRegisterCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callback_processor{}}
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback.Create().Register("before_create1", before_create1)
|
||||
callback.Create().Register("before_create2", before_create2)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("before_create2", beforeCreate2)
|
||||
callback.Create().Register("create", create)
|
||||
callback.Create().Register("after_create1", after_create1)
|
||||
callback.Create().Register("after_create2", after_create2)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) {
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||
t.Errorf("register callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||
var callback1 = &callback{processors: []*callback_processor{}}
|
||||
callback1.Create().Register("before_create1", before_create1)
|
||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
||||
callback1.Create().Register("before_create1", beforeCreate1)
|
||||
callback1.Create().Register("create", create)
|
||||
callback1.Create().Register("after_create1", after_create1)
|
||||
callback1.Create().Before("after_create1").Register("after_create2", after_create2)
|
||||
if !equalFuncs(callback1.creates, []string{"before_create1", "create", "after_create2", "after_create1"}) {
|
||||
callback1.Create().Register("after_create1", afterCreate1)
|
||||
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
||||
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &callback{processors: []*callback_processor{}}
|
||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback2.Update().Register("create", create)
|
||||
callback2.Update().Before("create").Register("before_create1", before_create1)
|
||||
callback2.Update().After("after_create2").Register("after_create1", after_create1)
|
||||
callback2.Update().Before("before_create1").Register("before_create2", before_create2)
|
||||
callback2.Update().Register("after_create2", after_create2)
|
||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
||||
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
||||
callback2.Update().Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback2.updates, []string{"before_create2", "before_create1", "create", "after_create2", "after_create1"}) {
|
||||
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||
var callback1 = &callback{processors: []*callback_processor{}}
|
||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback1.Query().Register("before_create1", before_create1)
|
||||
callback1.Query().Register("after_create1", after_create1)
|
||||
callback1.Query().Register("before_create1", beforeCreate1)
|
||||
callback1.Query().Register("after_create1", afterCreate1)
|
||||
|
||||
if !equalFuncs(callback1.queries, []string{"before_create1", "create", "after_create1"}) {
|
||||
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &callback{processors: []*callback_processor{}}
|
||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback2.Delete().Before("create").Register("before_create1", before_create1)
|
||||
callback2.Delete().After("before_create1").Register("before_create2", before_create2)
|
||||
callback2.Delete().Register("after_create1", after_create1)
|
||||
callback2.Delete().After("after_create1").Register("after_create2", after_create2)
|
||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||
callback2.Delete().Register("after_create1", afterCreate1)
|
||||
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback2.deletes, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) {
|
||||
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
}
|
||||
|
||||
func replace_create(s *Scope) {}
|
||||
func replaceCreate(s *Scope) {}
|
||||
|
||||
func TestReplaceCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callback_processor{}}
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", before_create1)
|
||||
callback.Create().Register("after_create1", after_create1)
|
||||
callback.Create().Replace("create", replace_create)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Replace("create", replaceCreate)
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"before_create1", "replace_create", "after_create1"}) {
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
||||
t.Errorf("replace callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveCallback(t *testing.T) {
|
||||
var callback = &callback{processors: []*callback_processor{}}
|
||||
var callback = &callback{processors: []*callbackProcessor{}}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", before_create1)
|
||||
callback.Create().Register("after_create1", after_create1)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Remove("create")
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"before_create1", "after_create1"}) {
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
||||
t.Errorf("remove callback")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ func (s *commonDialect) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
|
@ -33,9 +33,8 @@ func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
|
|||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("VARCHAR(%d)", size)
|
||||
} else {
|
||||
return "VARCHAR(65532)"
|
||||
}
|
||||
return "VARCHAR(65532)"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "TIMESTAMP"
|
||||
|
@ -44,21 +43,20 @@ func (d *commonDialect) SqlTag(value reflect.Value, size int) string {
|
|||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("BINARY(%d)", size)
|
||||
} else {
|
||||
return "BINARY(65532)"
|
||||
}
|
||||
return "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix_str := " NOT NULL PRIMARY KEY"
|
||||
suffix := " NOT NULL PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "INTEGER" + suffix_str
|
||||
return "INTEGER" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "BIGINT" + suffix_str
|
||||
return "BIGINT" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
|
|
10
errors.go
10
errors.go
|
@ -3,9 +3,9 @@ package gorm
|
|||
import "errors"
|
||||
|
||||
var (
|
||||
RecordNotFound = errors.New("Record Not Found")
|
||||
InvalidSql = errors.New("Invalid SQL")
|
||||
NoNewAttrs = errors.New("No new Attributes")
|
||||
NoValidTransaction = errors.New("No valid transaction")
|
||||
CantStartTransaction = errors.New("Can't start transaction")
|
||||
RecordNotFound = errors.New("record not found")
|
||||
InvalidSql = errors.New("invalid sql")
|
||||
NoNewAttrs = errors.New("no new attributes")
|
||||
NoValidTransaction = errors.New("no valid transaction")
|
||||
CantStartTransaction = errors.New("can't start transaction")
|
||||
)
|
||||
|
|
3
main.go
3
main.go
|
@ -265,9 +265,8 @@ func (s *DB) Save(value interface{}) *DB {
|
|||
scope := s.clone().NewScope(value)
|
||||
if scope.PrimaryKeyZero() {
|
||||
return scope.callCallbacks(s.parent.callback.creates).db
|
||||
} else {
|
||||
return scope.callCallbacks(s.parent.callback.updates).db
|
||||
}
|
||||
return scope.callCallbacks(s.parent.callback.updates).db
|
||||
}
|
||||
|
||||
func (s *DB) Create(value interface{}) *DB {
|
||||
|
|
16
main_test.go
16
main_test.go
|
@ -588,10 +588,10 @@ func BenchmarkGorm(b *testing.B) {
|
|||
func BenchmarkRawSql(b *testing.B) {
|
||||
DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
|
||||
DB.SetMaxIdleConns(10)
|
||||
insert_sql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||
query_sql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||
update_sql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||
delete_sql := "DELETE FROM orders WHERE id = $1"
|
||||
insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||
querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||
updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||
deleteSql := "DELETE FROM orders WHERE id = $1"
|
||||
|
||||
b.N = 2000
|
||||
for x := 0; x < b.N; x++ {
|
||||
|
@ -599,13 +599,13 @@ func BenchmarkRawSql(b *testing.B) {
|
|||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
// Insert
|
||||
DB.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||
DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||
// Query
|
||||
rows, _ := DB.Query(query_sql, email.Email)
|
||||
rows, _ := DB.Query(querySql, email.Email)
|
||||
rows.Close()
|
||||
// Update
|
||||
DB.Exec(update_sql, "new-"+e, time.Now(), id)
|
||||
DB.Exec(updateSql, "new-"+e, time.Now(), id)
|
||||
// Delete
|
||||
DB.Exec(delete_sql, id)
|
||||
DB.Exec(deleteSql, id)
|
||||
}
|
||||
}
|
||||
|
|
14
mssql.go
14
mssql.go
|
@ -20,7 +20,7 @@ func (s *mssql) HasTop() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (d *mssql) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bit"
|
||||
|
@ -33,9 +33,8 @@ func (d *mssql) SqlTag(value reflect.Value, size int) string {
|
|||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("nvarchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime2"
|
||||
|
@ -44,21 +43,20 @@ func (d *mssql) SqlTag(value reflect.Value, size int) string {
|
|||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix_str := " IDENTITY(1,1) PRIMARY KEY"
|
||||
suffix := " IDENTITY(1,1) PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix_str
|
||||
return "int" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix_str
|
||||
return "bigint" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
|
|
14
mysql.go
14
mysql.go
|
@ -20,7 +20,7 @@ func (s *mysql) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
|
@ -33,9 +33,8 @@ func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
|||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "longtext"
|
||||
}
|
||||
return "longtext"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "timestamp NULL"
|
||||
|
@ -44,21 +43,20 @@ func (d *mysql) SqlTag(value reflect.Value, size int) string {
|
|||
if _, ok := value.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
return "longblob"
|
||||
}
|
||||
return "longblob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix_str
|
||||
return "int" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix_str
|
||||
return "bigint" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func (s *postgres) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *postgres) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
|
|
91
scope.go
91
scope.go
|
@ -41,7 +41,7 @@ func (db *DB) NewScope(value interface{}) *Scope {
|
|||
func (scope *Scope) NeedPtr() *Scope {
|
||||
reflectKind := reflect.ValueOf(scope.Value).Kind()
|
||||
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
|
||||
err := errors.New(fmt.Sprintf("%v %v\n", fileWithLineNum(), "using unaddressable value"))
|
||||
err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
|
||||
scope.Err(err)
|
||||
fmt.Printf(err.Error())
|
||||
}
|
||||
|
@ -125,9 +125,8 @@ func (scope *Scope) PrimaryKeyField() *Field {
|
|||
func (scope *Scope) PrimaryKey() string {
|
||||
if field := scope.PrimaryKeyField(); field != nil {
|
||||
return field.DBName
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// PrimaryKeyZero check the primary key is blank or not
|
||||
|
@ -139,9 +138,8 @@ func (scope *Scope) PrimaryKeyZero() bool {
|
|||
func (scope *Scope) PrimaryKeyValue() interface{} {
|
||||
if field := scope.PrimaryKeyField(); field != nil {
|
||||
return field.Field.Interface()
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// HasColumn to check if has column
|
||||
|
@ -207,7 +205,7 @@ func (scope *Scope) CallMethod(name string) {
|
|||
case func(s *DB) error:
|
||||
scope.Err(f(scope.db.New()))
|
||||
default:
|
||||
scope.Err(errors.New(fmt.Sprintf("unsupported function %v", name)))
|
||||
scope.Err(fmt.Errorf("unsupported function %v", name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -234,53 +232,54 @@ var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${
|
|||
func (scope *Scope) TableName() string {
|
||||
if scope.Search != nil && len(scope.Search.TableName) > 0 {
|
||||
return scope.Search.TableName
|
||||
} else {
|
||||
if scope.Value == nil {
|
||||
scope.Err(errors.New("can't get table name"))
|
||||
return ""
|
||||
}
|
||||
|
||||
data := scope.IndirectValue()
|
||||
if data.Kind() == reflect.Slice {
|
||||
elem := data.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
data = reflect.New(elem).Elem()
|
||||
}
|
||||
|
||||
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
||||
if result, ok := v[0].Interface().(string); ok {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
str := ToSnake(data.Type().Name())
|
||||
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
for index, reg := range pluralMapKeys {
|
||||
if reg.MatchString(str) {
|
||||
return reg.ReplaceAllString(str, pluralMapValues[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
if scope.Value == nil {
|
||||
scope.Err(errors.New("can't get table name"))
|
||||
return ""
|
||||
}
|
||||
|
||||
data := scope.IndirectValue()
|
||||
if data.Kind() == reflect.Slice {
|
||||
elem := data.Type().Elem()
|
||||
if elem.Kind() == reflect.Ptr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
data = reflect.New(elem).Elem()
|
||||
}
|
||||
|
||||
if fm := data.MethodByName("TableName"); fm.IsValid() {
|
||||
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
|
||||
if result, ok := v[0].Interface().(string); ok {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
str := ToSnake(data.Type().Name())
|
||||
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
for index, reg := range pluralMapKeys {
|
||||
if reg.MatchString(str) {
|
||||
return reg.ReplaceAllString(str, pluralMapValues[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
func (scope *Scope) QuotedTableName() string {
|
||||
if scope.Search != nil && len(scope.Search.TableName) > 0 {
|
||||
return scope.Search.TableName
|
||||
} else {
|
||||
keys := strings.Split(scope.TableName(), ".")
|
||||
for i, v := range keys {
|
||||
keys[i] = scope.Quote(v)
|
||||
}
|
||||
return strings.Join(keys, ".")
|
||||
}
|
||||
|
||||
keys := strings.Split(scope.TableName(), ".")
|
||||
for i, v := range keys {
|
||||
keys[i] = scope.Quote(v)
|
||||
}
|
||||
return strings.Join(keys, ".")
|
||||
|
||||
}
|
||||
|
||||
// CombinedConditionSql get combined condition sql
|
||||
|
|
|
@ -215,87 +215,80 @@ func (scope *Scope) whereSql() (sql string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Scope) selectSql() string {
|
||||
if len(s.Search.Selects) == 0 {
|
||||
func (scope *Scope) selectSql() string {
|
||||
if len(scope.Search.Selects) == 0 {
|
||||
return "*"
|
||||
}
|
||||
|
||||
var selectQueries []string
|
||||
|
||||
for _, clause := range s.Search.Selects {
|
||||
selectQueries = append(selectQueries, s.buildSelectQuery(clause))
|
||||
for _, clause := range scope.Search.Selects {
|
||||
selectQueries = append(selectQueries, scope.buildSelectQuery(clause))
|
||||
}
|
||||
|
||||
return strings.Join(selectQueries, ", ")
|
||||
|
||||
}
|
||||
|
||||
func (s *Scope) orderSql() string {
|
||||
if len(s.Search.Orders) == 0 {
|
||||
func (scope *Scope) orderSql() string {
|
||||
if len(scope.Search.Orders) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
return " ORDER BY " + strings.Join(s.Search.Orders, ",")
|
||||
}
|
||||
return " ORDER BY " + strings.Join(scope.Search.Orders, ",")
|
||||
}
|
||||
|
||||
func (s *Scope) limitSql() string {
|
||||
if !s.Dialect().HasTop() {
|
||||
if len(s.Search.Limit) == 0 {
|
||||
func (scope *Scope) limitSql() string {
|
||||
if !scope.Dialect().HasTop() {
|
||||
if len(scope.Search.Limit) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
return " LIMIT " + s.Search.Limit
|
||||
}
|
||||
} else {
|
||||
return ""
|
||||
return " LIMIT " + scope.Search.Limit
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Scope) topSql() string {
|
||||
if s.Dialect().HasTop() && len(s.Search.Offset) == 0 {
|
||||
if len(s.Search.Limit) == 0 {
|
||||
func (scope *Scope) topSql() string {
|
||||
if scope.Dialect().HasTop() && len(scope.Search.Offset) == 0 {
|
||||
if len(scope.Search.Limit) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
return " TOP(" + s.Search.Limit + ")"
|
||||
}
|
||||
} else {
|
||||
return ""
|
||||
return " TOP(" + scope.Search.Limit + ")"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Scope) offsetSql() string {
|
||||
if len(s.Search.Offset) == 0 {
|
||||
func (scope *Scope) offsetSql() string {
|
||||
if len(scope.Search.Offset) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
if s.Dialect().HasTop() {
|
||||
sql := " OFFSET " + s.Search.Offset + " ROW "
|
||||
if len(s.Search.Limit) > 0 {
|
||||
sql += "FETCH NEXT " + s.Search.Limit + " ROWS ONLY"
|
||||
}
|
||||
return sql
|
||||
} else {
|
||||
return " OFFSET " + s.Search.Offset
|
||||
}
|
||||
|
||||
if scope.Dialect().HasTop() {
|
||||
sql := " OFFSET " + scope.Search.Offset + " ROW "
|
||||
if len(scope.Search.Limit) > 0 {
|
||||
sql += "FETCH NEXT " + scope.Search.Limit + " ROWS ONLY"
|
||||
}
|
||||
return sql
|
||||
}
|
||||
return " OFFSET " + scope.Search.Offset
|
||||
}
|
||||
|
||||
func (s *Scope) groupSql() string {
|
||||
if len(s.Search.Group) == 0 {
|
||||
func (scope *Scope) groupSql() string {
|
||||
if len(scope.Search.Group) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
return " GROUP BY " + s.Search.Group
|
||||
}
|
||||
return " GROUP BY " + scope.Search.Group
|
||||
}
|
||||
|
||||
func (s *Scope) havingSql() string {
|
||||
if s.Search.HavingCondition == nil {
|
||||
func (scope *Scope) havingSql() string {
|
||||
if scope.Search.HavingCondition == nil {
|
||||
return ""
|
||||
} else {
|
||||
return " HAVING " + s.buildWhereCondition(s.Search.HavingCondition)
|
||||
}
|
||||
return " HAVING " + scope.buildWhereCondition(scope.Search.HavingCondition)
|
||||
}
|
||||
|
||||
func (s *Scope) joinsSql() string {
|
||||
return s.Search.Joins + " "
|
||||
func (scope *Scope) joinsSql() string {
|
||||
return scope.Search.Joins + " "
|
||||
}
|
||||
|
||||
func (scope *Scope) prepareQuerySql() {
|
||||
|
@ -479,9 +472,9 @@ func (scope *Scope) typeName() string {
|
|||
value := scope.IndirectValue()
|
||||
if value.Kind() == reflect.Slice {
|
||||
return value.Type().Elem().Name()
|
||||
} else {
|
||||
return value.Type().Name()
|
||||
}
|
||||
|
||||
return value.Type().Name()
|
||||
}
|
||||
|
||||
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
|
@ -634,9 +627,9 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
|||
}
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var table string = scope.TableName()
|
||||
var keyName string = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||
var query string = `
|
||||
var table = scope.TableName()
|
||||
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||
var query = `
|
||||
ALTER TABLE %s
|
||||
ADD CONSTRAINT %s
|
||||
FOREIGN KEY (%s)
|
||||
|
|
|
@ -32,9 +32,8 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
|||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if value.Type() == timeType {
|
||||
return "datetime"
|
||||
|
|
2
utils.go
2
utils.go
|
@ -34,7 +34,7 @@ func FieldValueByName(name string, value interface{}) (i interface{}, err error)
|
|||
if field := data.FieldByName(name); field.IsValid() {
|
||||
i = field.Interface()
|
||||
} else {
|
||||
return nil, errors.New(fmt.Sprintf("struct has no field with name %s", name))
|
||||
return nil, fmt.Errorf("struct has no field with name %s", name)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("value must be of kind struct")
|
||||
|
|
Loading…
Reference in New Issue