mirror of https://github.com/go-gorm/gorm.git
Add Tag Support
This commit is contained in:
parent
5341a5ed43
commit
c03999a44e
27
README.md
27
README.md
|
@ -11,6 +11,7 @@ Yet Another ORM library for Go, aims for developer friendly
|
|||
* Auto Migration
|
||||
* Transaction
|
||||
* Logger Support
|
||||
* Bind struct with tag
|
||||
* Every feature comes with tests
|
||||
* Convention Over Configuration
|
||||
* Developer Friendly
|
||||
|
@ -22,30 +23,32 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
|
|||
Id int64 // Id: Database Primary key
|
||||
Birthday time.Time
|
||||
Age int64
|
||||
Name string
|
||||
Name string `sql:"size:255"` // set this field's length and as not null with tag
|
||||
CreatedAt time.Time // Time of record is created, will be insert automatically
|
||||
UpdatedAt time.Time // Time of record is updated, will be updated automatically
|
||||
DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more
|
||||
|
||||
Email []Email // Embedded structs
|
||||
BillingAddress Address // Embedded struct
|
||||
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
|
||||
Email []Email // Embedded structs
|
||||
BillingAddress Address // Embedded struct
|
||||
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
|
||||
IgnoreMe int64 `sql:"-"` // Ignore this field with tag
|
||||
}
|
||||
|
||||
type Email struct { // TableName: `emails`
|
||||
Id int64
|
||||
UserId int64 // Foreign key for above embedded structs
|
||||
Email string
|
||||
Email string `sql:"type:varchar(100);"` // Set column type directly with tag
|
||||
Subscribed bool
|
||||
}
|
||||
|
||||
type Address struct { // TableName: `addresses`
|
||||
Id int64
|
||||
Address1 string
|
||||
Address2 string
|
||||
Post string
|
||||
Address1 string `sql:"not null;unique"` // Set column as unique with tag
|
||||
Address2 string `sql:"type:varchar(100);unique"`
|
||||
Post sql.NullString `sql:not null`
|
||||
// Be careful: "NOT NULL" will only works for NullXXX scanner, because golang will initalize a default value for most type...
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -723,8 +726,8 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi
|
|||
|
||||
## TODO
|
||||
* Join, Having, Group, Includes
|
||||
* Scopes
|
||||
* Index, Unique, Valiations
|
||||
* Scopes, Valiations
|
||||
* AlertColumn, DropColumn, AddIndex, RemoveIndex
|
||||
|
||||
# Author
|
||||
|
||||
|
|
28
gorm_test.go
28
gorm_test.go
|
@ -18,7 +18,7 @@ type User struct {
|
|||
Id int64 // Id: Primary key
|
||||
Birthday time.Time // Time
|
||||
Age int64
|
||||
Name string
|
||||
Name string `sql:"size:255"`
|
||||
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||
DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more
|
||||
|
@ -28,6 +28,7 @@ type User struct {
|
|||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct's foreign key
|
||||
CreditCard CreditCard
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
}
|
||||
|
||||
type CreditCard struct {
|
||||
|
@ -42,7 +43,7 @@ type CreditCard struct {
|
|||
type Email struct {
|
||||
Id int64
|
||||
UserId int64
|
||||
Email string
|
||||
Email string `sql:"type:varchar(100); unique"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
@ -1279,7 +1280,6 @@ type BigEmail struct {
|
|||
RegisteredAt time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
}
|
||||
|
||||
func (b BigEmail) TableName() string {
|
||||
|
@ -1299,6 +1299,10 @@ func TestAutoMigration(t *testing.T) {
|
|||
if big_email.Email != "jinzhu@example.org" || big_email.UserAgent != "pc" || big_email.RegisteredAt.IsZero() {
|
||||
t.Error("Big Emails should be saved and fetched correctly")
|
||||
}
|
||||
|
||||
if err := db.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}).Error; err == nil {
|
||||
t.Error("Should not be able to save because of unique tag")
|
||||
}
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
|
@ -1324,7 +1328,7 @@ func (nt NullTime) Value() (driver.Value, error) {
|
|||
|
||||
type NullValue struct {
|
||||
Id int64
|
||||
Name sql.NullString
|
||||
Name sql.NullString `sql:"not null"`
|
||||
Age sql.NullInt64
|
||||
Male sql.NullBool
|
||||
Height sql.NullFloat64
|
||||
|
@ -1355,6 +1359,10 @@ func TestSqlNullValue(t *testing.T) {
|
|||
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||
t.Errorf("Should be able to fetch null value")
|
||||
}
|
||||
|
||||
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
|
||||
t.Errorf("Can't save because of name can't be null", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
|
@ -1401,10 +1409,11 @@ func (s *CreditCard) BeforeSave() (err error) {
|
|||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
for x := 0; x < b.N; x++ {
|
||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
db.Save(&email)
|
||||
db.First(&BigEmail{}, "email = ?", "benchmark@benchmark.org")
|
||||
db.Model(&email).Update("email", "benchmark@benchmark.org")
|
||||
db.First(&BigEmail{}, "email = ?", e)
|
||||
db.Model(&email).Update("email", e)
|
||||
db.Delete(&email)
|
||||
}
|
||||
}
|
||||
|
@ -1419,11 +1428,12 @@ func BenchmarkRawSql(b *testing.B) {
|
|||
|
||||
var id int64
|
||||
for x := 0; x < b.N; x++ {
|
||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||
rows, _ := db.Query(query_sql, email.Email)
|
||||
rows.Close()
|
||||
db.Exec(update_sql, "benchmark@benchmark.org", time.Now(), id)
|
||||
db.Exec(update_sql, e, time.Now(), id)
|
||||
db.Exec(delete_sql, id)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,6 +57,10 @@ func (s *Chain) warn(v ...interface{}) {
|
|||
go s.print("warn", v...)
|
||||
}
|
||||
|
||||
func (s *Chain) info(v ...interface{}) {
|
||||
go s.print("info", v...)
|
||||
}
|
||||
|
||||
func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) {
|
||||
go s.print("sql", time.Now().Sub(t), sql, vars)
|
||||
}
|
||||
|
|
9
main.go
9
main.go
|
@ -3,6 +3,11 @@ package gorm
|
|||
import "database/sql"
|
||||
|
||||
var singularTableName bool
|
||||
var tagIdentifier string
|
||||
|
||||
func init() {
|
||||
tagIdentifier = "sql"
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
db sql_common
|
||||
|
@ -23,6 +28,10 @@ func (s *DB) SetPool(n int) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *DB) SetTagIdentifier(str string) {
|
||||
tagIdentifier = str
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l Logger) {
|
||||
s.logger = l
|
||||
}
|
||||
|
|
9
model.go
9
model.go
|
@ -115,6 +115,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
}
|
||||
|
||||
tag_value := p.Tag.Get(tagIdentifier)
|
||||
if is_time {
|
||||
field.AutoCreateTime = "created_at" == field.DbName
|
||||
field.AutoUpdateTime = "updated_at" == field.DbName
|
||||
|
@ -130,9 +131,9 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
}
|
||||
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
|
||||
} else if field.IsPrimaryKey {
|
||||
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
|
||||
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value)
|
||||
} else {
|
||||
field_value := reflect.Indirect(value)
|
||||
|
||||
|
@ -147,7 +148,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
|
||||
|
||||
if is_scanner {
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
|
||||
} else {
|
||||
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
|
||||
field.foreignKey = p.Name + "Id"
|
||||
|
@ -161,7 +162,7 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||
}
|
||||
}
|
||||
default:
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
|
||||
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
216
sql_type.go
216
sql_type.go
|
@ -5,9 +5,40 @@ import (
|
|||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseTag(str string) (typ string, addational_typ string, size int) {
|
||||
if str == "-" {
|
||||
typ = str
|
||||
} else if str != "" {
|
||||
tags := strings.Split(str, ";")
|
||||
m := make(map[string]string)
|
||||
for _, value := range tags {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.Trim(strings.ToUpper(v[0]), " ")
|
||||
if len(v) == 2 {
|
||||
m[k] = v[1]
|
||||
} else {
|
||||
m[k] = k
|
||||
}
|
||||
}
|
||||
|
||||
if len(m["SIZE"]) > 0 {
|
||||
size, _ = strconv.Atoi(m["SIZE"])
|
||||
}
|
||||
|
||||
if len(m["TYPE"]) > 0 {
|
||||
typ = m["TYPE"]
|
||||
}
|
||||
|
||||
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func formatColumnValue(column interface{}) interface{} {
|
||||
if v, ok := column.(reflect.Value); ok {
|
||||
column = v.Interface()
|
||||
|
@ -19,8 +50,13 @@ func formatColumnValue(column interface{}) interface{} {
|
|||
return column
|
||||
}
|
||||
|
||||
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
||||
func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string {
|
||||
column = formatColumnValue(column)
|
||||
typ, addational_typ, _ := parseTag(tag)
|
||||
|
||||
if len(typ) != 0 {
|
||||
return typ + addational_typ
|
||||
}
|
||||
|
||||
switch adaptor {
|
||||
case "sqlite3":
|
||||
|
@ -29,94 +65,112 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
|
|||
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "int" + suffix_str
|
||||
typ = "int" + suffix_str
|
||||
case int64, uint64:
|
||||
return "bigint" + suffix_str
|
||||
typ = "bigint" + suffix_str
|
||||
}
|
||||
case "postgres":
|
||||
switch column.(type) {
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "serial"
|
||||
typ = "serial"
|
||||
case int64, uint64:
|
||||
return "bigserial"
|
||||
}
|
||||
}
|
||||
panic("unsupported sql adaptor, please submit an issue in github")
|
||||
}
|
||||
|
||||
func getSqlType(adaptor string, column interface{}, size int) string {
|
||||
column = formatColumnValue(column)
|
||||
|
||||
switch adaptor {
|
||||
case "sqlite3":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "datetime"
|
||||
case bool, sql.NullBool:
|
||||
return "bool"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "real"
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
}
|
||||
case "mysql":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp"
|
||||
case bool, sql.NullBool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "int"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "double"
|
||||
case []byte:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
}
|
||||
return "longblob"
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "longtext"
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
}
|
||||
|
||||
case "postgres":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
return "timestamp with time zone"
|
||||
case bool, sql.NullBool:
|
||||
return "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
return "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
return "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
return "double precision"
|
||||
case []byte:
|
||||
return "bytea"
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
typ = "bigserial"
|
||||
}
|
||||
default:
|
||||
panic("unsupported sql adaptor, please submit an issue in github")
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
func getSqlType(adaptor string, column interface{}, tag string) string {
|
||||
column = formatColumnValue(column)
|
||||
typ, addational_typ, size := parseTag(tag)
|
||||
|
||||
if typ == "-" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(typ) == 0 {
|
||||
switch adaptor {
|
||||
case "sqlite3":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
typ = "datetime"
|
||||
case bool, sql.NullBool:
|
||||
typ = "bool"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
typ = "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
typ = "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
typ = "real"
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
typ = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
typ = "text"
|
||||
}
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
}
|
||||
case "mysql":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
typ = "timestamp"
|
||||
case bool, sql.NullBool:
|
||||
typ = "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
typ = "int"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
typ = "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
typ = "double"
|
||||
case []byte:
|
||||
if size > 0 && size < 65532 {
|
||||
typ = fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
typ = "longblob"
|
||||
}
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
typ = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
typ = "longtext"
|
||||
}
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
}
|
||||
|
||||
case "postgres":
|
||||
switch column.(type) {
|
||||
case time.Time:
|
||||
typ = "timestamp with time zone"
|
||||
case bool, sql.NullBool:
|
||||
typ = "boolean"
|
||||
case int, int8, int16, int32, uint, uint8, uint16, uint32:
|
||||
typ = "integer"
|
||||
case int64, uint64, sql.NullInt64:
|
||||
typ = "bigint"
|
||||
case float32, float64, sql.NullFloat64:
|
||||
typ = "double precision"
|
||||
case []byte:
|
||||
typ = "bytea"
|
||||
case string, sql.NullString:
|
||||
if size > 0 && size < 65532 {
|
||||
typ = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
typ = "text"
|
||||
}
|
||||
default:
|
||||
panic("invalid sql type")
|
||||
}
|
||||
default:
|
||||
panic("unsupported sql adaptor, please submit an issue in github")
|
||||
}
|
||||
}
|
||||
|
||||
if len(addational_typ) > 0 {
|
||||
typ = typ + " " + addational_typ
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue