Add Tag Support

This commit is contained in:
Jinzhu 2013-11-14 00:03:31 +08:00
parent 5341a5ed43
commit c03999a44e
6 changed files with 187 additions and 106 deletions

View File

@ -11,6 +11,7 @@ Yet Another ORM library for Go, aims for developer friendly
* Auto Migration * Auto Migration
* Transaction * Transaction
* Logger Support * Logger Support
* Bind struct with tag
* Every feature comes with tests * Every feature comes with tests
* Convention Over Configuration * Convention Over Configuration
* Developer Friendly * Developer Friendly
@ -22,30 +23,32 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
Id int64 // Id: Database Primary key Id int64 // Id: Database Primary key
Birthday time.Time Birthday time.Time
Age int64 Age int64
Name string Name string `sql:"size:255"` // set this field's length and as not null with tag
CreatedAt time.Time // Time of record is created, will be insert automatically CreatedAt time.Time // Time of record is created, will be insert automatically
UpdatedAt time.Time // Time of record is updated, will be updated automatically UpdatedAt time.Time // Time of record is updated, will be updated automatically
DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more
Email []Email // Embedded structs Email []Email // Embedded structs
BillingAddress Address // Embedded struct BillingAddress Address // Embedded struct
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
IgnoreMe int64 `sql:"-"` // Ignore this field with tag
} }
type Email struct { // TableName: `emails` type Email struct { // TableName: `emails`
Id int64 Id int64
UserId int64 // Foreign key for above embedded structs UserId int64 // Foreign key for above embedded structs
Email string Email string `sql:"type:varchar(100);"` // Set column type directly with tag
Subscribed bool Subscribed bool
} }
type Address struct { // TableName: `addresses` type Address struct { // TableName: `addresses`
Id int64 Id int64
Address1 string Address1 string `sql:"not null;unique"` // Set column as unique with tag
Address2 string Address2 string `sql:"type:varchar(100);unique"`
Post string Post sql.NullString `sql:not null`
// Be careful: "NOT NULL" will only works for NullXXX scanner, because golang will initalize a default value for most type...
} }
``` ```
@ -723,8 +726,8 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi
## TODO ## TODO
* Join, Having, Group, Includes * Join, Having, Group, Includes
* Scopes * Scopes, Valiations
* Index, Unique, Valiations * AlertColumn, DropColumn, AddIndex, RemoveIndex
# Author # Author

View File

@ -18,7 +18,7 @@ type User struct {
Id int64 // Id: Primary key Id int64 // Id: Primary key
Birthday time.Time // Time Birthday time.Time // Time
Age int64 Age int64
Name string Name string `sql:"size:255"`
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more
@ -28,6 +28,7 @@ type User struct {
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard CreditCard CreditCard
IgnoreMe int64 `sql:"-"`
} }
type CreditCard struct { type CreditCard struct {
@ -42,7 +43,7 @@ type CreditCard struct {
type Email struct { type Email struct {
Id int64 Id int64
UserId int64 UserId int64
Email string Email string `sql:"type:varchar(100); unique"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
@ -1279,7 +1280,6 @@ type BigEmail struct {
RegisteredAt time.Time RegisteredAt time.Time
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time
} }
func (b BigEmail) TableName() string { func (b BigEmail) TableName() string {
@ -1299,6 +1299,10 @@ func TestAutoMigration(t *testing.T) {
if big_email.Email != "jinzhu@example.org" || big_email.UserAgent != "pc" || big_email.RegisteredAt.IsZero() { if big_email.Email != "jinzhu@example.org" || big_email.UserAgent != "pc" || big_email.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly") t.Error("Big Emails should be saved and fetched correctly")
} }
if err := db.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}).Error; err == nil {
t.Error("Should not be able to save because of unique tag")
}
} }
type NullTime struct { type NullTime struct {
@ -1324,7 +1328,7 @@ func (nt NullTime) Value() (driver.Value, error) {
type NullValue struct { type NullValue struct {
Id int64 Id int64
Name sql.NullString Name sql.NullString `sql:"not null"`
Age sql.NullInt64 Age sql.NullInt64
Male sql.NullBool Male sql.NullBool
Height sql.NullFloat64 Height sql.NullFloat64
@ -1355,6 +1359,10 @@ func TestSqlNullValue(t *testing.T) {
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value") t.Errorf("Should be able to fetch null value")
} }
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
t.Errorf("Can't save because of name can't be null", err)
}
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
@ -1401,10 +1409,11 @@ func (s *CreditCard) BeforeSave() (err error) {
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
db.Save(&email) db.Save(&email)
db.First(&BigEmail{}, "email = ?", "benchmark@benchmark.org") db.First(&BigEmail{}, "email = ?", e)
db.Model(&email).Update("email", "benchmark@benchmark.org") db.Model(&email).Update("email", e)
db.Delete(&email) db.Delete(&email)
} }
} }
@ -1419,11 +1428,12 @@ func BenchmarkRawSql(b *testing.B) {
var id int64 var id int64
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
rows, _ := db.Query(query_sql, email.Email) rows, _ := db.Query(query_sql, email.Email)
rows.Close() rows.Close()
db.Exec(update_sql, "benchmark@benchmark.org", time.Now(), id) db.Exec(update_sql, e, time.Now(), id)
db.Exec(delete_sql, id) db.Exec(delete_sql, id)
} }
} }

View File

@ -57,6 +57,10 @@ func (s *Chain) warn(v ...interface{}) {
go s.print("warn", v...) go s.print("warn", v...)
} }
func (s *Chain) info(v ...interface{}) {
go s.print("info", v...)
}
func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) { func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) {
go s.print("sql", time.Now().Sub(t), sql, vars) go s.print("sql", time.Now().Sub(t), sql, vars)
} }

View File

@ -3,6 +3,11 @@ package gorm
import "database/sql" import "database/sql"
var singularTableName bool var singularTableName bool
var tagIdentifier string
func init() {
tagIdentifier = "sql"
}
type DB struct { type DB struct {
db sql_common db sql_common
@ -23,6 +28,10 @@ func (s *DB) SetPool(n int) {
} }
} }
func (s *DB) SetTagIdentifier(str string) {
tagIdentifier = str
}
func (s *DB) SetLogger(l Logger) { func (s *DB) SetLogger(l Logger) {
s.logger = l s.logger = l
} }

View File

@ -115,6 +115,7 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
} }
tag_value := p.Tag.Get(tagIdentifier)
if is_time { if is_time {
field.AutoCreateTime = "created_at" == field.DbName field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName
@ -130,9 +131,9 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
} }
field.SqlType = getSqlType(m.do.chain.driver(), value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
} else if field.IsPrimaryKey { } else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0) field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value)
} else { } else {
field_value := reflect.Indirect(value) field_value := reflect.Indirect(value)
@ -147,7 +148,7 @@ func (m *Model) fields(operation string) (fields []Field) {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner { if is_scanner {
field.SqlType = getSqlType(m.do.chain.driver(), value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
} else { } else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() { if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id" field.foreignKey = p.Name + "Id"
@ -161,7 +162,7 @@ func (m *Model) fields(operation string) (fields []Field) {
} }
} }
default: default:
field.SqlType = getSqlType(m.do.chain.driver(), value, 0) field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
} }
} }

View File

@ -5,9 +5,40 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings"
"time" "time"
) )
func parseTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func formatColumnValue(column interface{}) interface{} { func formatColumnValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok { if v, ok := column.(reflect.Value); ok {
column = v.Interface() column = v.Interface()
@ -19,8 +50,13 @@ func formatColumnValue(column interface{}) interface{} {
return column return column
} }
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column) column = formatColumnValue(column)
typ, addational_typ, _ := parseTag(tag)
if len(typ) != 0 {
return typ + addational_typ
}
switch adaptor { switch adaptor {
case "sqlite3": case "sqlite3":
@ -29,94 +65,112 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) { switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32: case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int" + suffix_str typ = "int" + suffix_str
case int64, uint64: case int64, uint64:
return "bigint" + suffix_str typ = "bigint" + suffix_str
} }
case "postgres": case "postgres":
switch column.(type) { switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32: case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "serial" typ = "serial"
case int64, uint64: case int64, uint64:
return "bigserial" typ = "bigserial"
}
}
panic("unsupported sql adaptor, please submit an issue in github")
}
func getSqlType(adaptor string, column interface{}, size int) string {
column = formatColumnValue(column)
switch adaptor {
case "sqlite3":
switch column.(type) {
case time.Time:
return "datetime"
case bool, sql.NullBool:
return "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
default:
panic("invalid sql type")
}
case "mysql":
switch column.(type) {
case time.Time:
return "timestamp"
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "longtext"
default:
panic("invalid sql type")
}
case "postgres":
switch column.(type) {
case time.Time:
return "timestamp with time zone"
case bool, sql.NullBool:
return "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
case float32, float64, sql.NullFloat64:
return "double precision"
case []byte:
return "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
default:
panic("invalid sql type")
} }
default: default:
panic("unsupported sql adaptor, please submit an issue in github") panic("unsupported sql adaptor, please submit an issue in github")
} }
return typ
}
func getSqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column)
typ, addational_typ, size := parseTag(tag)
if typ == "-" {
return ""
}
if len(typ) == 0 {
switch adaptor {
case "sqlite3":
switch column.(type) {
case time.Time:
typ = "datetime"
case bool, sql.NullBool:
typ = "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
}
case "mysql":
switch column.(type) {
case time.Time:
typ = "timestamp"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "int"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double"
case []byte:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varbinary(%d)", size)
} else {
typ = "longblob"
}
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "longtext"
}
default:
panic("invalid sql type")
}
case "postgres":
switch column.(type) {
case time.Time:
typ = "timestamp with time zone"
case bool, sql.NullBool:
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
typ = "integer"
case int64, uint64, sql.NullInt64:
typ = "bigint"
case float32, float64, sql.NullFloat64:
typ = "double precision"
case []byte:
typ = "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
default:
panic("invalid sql type")
}
default:
panic("unsupported sql adaptor, please submit an issue in github")
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
} }