Add Tag Support

This commit is contained in:
Jinzhu 2013-11-14 00:03:31 +08:00
parent 5341a5ed43
commit c03999a44e
6 changed files with 187 additions and 106 deletions

View File

@ -11,6 +11,7 @@ Yet Another ORM library for Go, aims for developer friendly
* Auto Migration
* Transaction
* Logger Support
* Bind struct with tag
* Every feature comes with tests
* Convention Over Configuration
* Developer Friendly
@ -22,7 +23,7 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
Id int64 // Id: Database Primary key
Birthday time.Time
Age int64
Name string
Name string `sql:"size:255"` // set this field's length and as not null with tag
CreatedAt time.Time // Time of record is created, will be insert automatically
UpdatedAt time.Time // Time of record is updated, will be updated automatically
DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more
@ -32,20 +33,22 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n
BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key
IgnoreMe int64 `sql:"-"` // Ignore this field with tag
}
type Email struct { // TableName: `emails`
Id int64
UserId int64 // Foreign key for above embedded structs
Email string
Email string `sql:"type:varchar(100);"` // Set column type directly with tag
Subscribed bool
}
type Address struct { // TableName: `addresses`
Id int64
Address1 string
Address2 string
Post string
Address1 string `sql:"not null;unique"` // Set column as unique with tag
Address2 string `sql:"type:varchar(100);unique"`
Post sql.NullString `sql:not null`
// Be careful: "NOT NULL" will only works for NullXXX scanner, because golang will initalize a default value for most type...
}
```
@ -723,8 +726,8 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi
## TODO
* Join, Having, Group, Includes
* Scopes
* Index, Unique, Valiations
* Scopes, Valiations
* AlertColumn, DropColumn, AddIndex, RemoveIndex
# Author

View File

@ -18,7 +18,7 @@ type User struct {
Id int64 // Id: Primary key
Birthday time.Time // Time
Age int64
Name string
Name string `sql:"size:255"`
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more
@ -28,6 +28,7 @@ type User struct {
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
IgnoreMe int64 `sql:"-"`
}
type CreditCard struct {
@ -42,7 +43,7 @@ type CreditCard struct {
type Email struct {
Id int64
UserId int64
Email string
Email string `sql:"type:varchar(100); unique"`
CreatedAt time.Time
UpdatedAt time.Time
}
@ -1279,7 +1280,6 @@ type BigEmail struct {
RegisteredAt time.Time
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}
func (b BigEmail) TableName() string {
@ -1299,6 +1299,10 @@ func TestAutoMigration(t *testing.T) {
if big_email.Email != "jinzhu@example.org" || big_email.UserAgent != "pc" || big_email.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly")
}
if err := db.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}).Error; err == nil {
t.Error("Should not be able to save because of unique tag")
}
}
type NullTime struct {
@ -1324,7 +1328,7 @@ func (nt NullTime) Value() (driver.Value, error) {
type NullValue struct {
Id int64
Name sql.NullString
Name sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
@ -1355,6 +1359,10 @@ func TestSqlNullValue(t *testing.T) {
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value")
}
if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil {
t.Errorf("Can't save because of name can't be null", err)
}
}
func TestTransaction(t *testing.T) {
@ -1401,10 +1409,11 @@ func (s *CreditCard) BeforeSave() (err error) {
func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
db.Save(&email)
db.First(&BigEmail{}, "email = ?", "benchmark@benchmark.org")
db.Model(&email).Update("email", "benchmark@benchmark.org")
db.First(&BigEmail{}, "email = ?", e)
db.Model(&email).Update("email", e)
db.Delete(&email)
}
}
@ -1419,11 +1428,12 @@ func BenchmarkRawSql(b *testing.B) {
var id int64
for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
rows, _ := db.Query(query_sql, email.Email)
rows.Close()
db.Exec(update_sql, "benchmark@benchmark.org", time.Now(), id)
db.Exec(update_sql, e, time.Now(), id)
db.Exec(delete_sql, id)
}
}

View File

@ -57,6 +57,10 @@ func (s *Chain) warn(v ...interface{}) {
go s.print("warn", v...)
}
func (s *Chain) info(v ...interface{}) {
go s.print("info", v...)
}
func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) {
go s.print("sql", time.Now().Sub(t), sql, vars)
}

View File

@ -3,6 +3,11 @@ package gorm
import "database/sql"
var singularTableName bool
var tagIdentifier string
func init() {
tagIdentifier = "sql"
}
type DB struct {
db sql_common
@ -23,6 +28,10 @@ func (s *DB) SetPool(n int) {
}
}
func (s *DB) SetTagIdentifier(str string) {
tagIdentifier = str
}
func (s *DB) SetLogger(l Logger) {
s.logger = l
}

View File

@ -115,6 +115,7 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
tag_value := p.Tag.Get(tagIdentifier)
if is_time {
field.AutoCreateTime = "created_at" == field.DbName
field.AutoUpdateTime = "updated_at" == field.DbName
@ -130,9 +131,9 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
} else if field.IsPrimaryKey {
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0)
field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value)
} else {
field_value := reflect.Indirect(value)
@ -147,7 +148,7 @@ func (m *Model) fields(operation string) (fields []Field) {
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)
if is_scanner {
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
} else {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
@ -161,7 +162,7 @@ func (m *Model) fields(operation string) (fields []Field) {
}
}
default:
field.SqlType = getSqlType(m.do.chain.driver(), value, 0)
field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value)
}
}

View File

@ -5,9 +5,40 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
func parseTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
} else if str != "" {
tags := strings.Split(str, ";")
m := make(map[string]string)
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.Trim(strings.ToUpper(v[0]), " ")
if len(v) == 2 {
m[k] = v[1]
} else {
m[k] = k
}
}
if len(m["SIZE"]) > 0 {
size, _ = strconv.Atoi(m["SIZE"])
}
if len(m["TYPE"]) > 0 {
typ = m["TYPE"]
}
addational_typ = m["NOT NULL"] + " " + m["UNIQUE"]
}
return
}
func formatColumnValue(column interface{}) interface{} {
if v, ok := column.(reflect.Value); ok {
column = v.Interface()
@ -19,8 +50,13 @@ func formatColumnValue(column interface{}) interface{} {
return column
}
func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column)
typ, addational_typ, _ := parseTag(tag)
if len(typ) != 0 {
return typ + addational_typ
}
switch adaptor {
case "sqlite3":
@ -29,67 +65,78 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string {
suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int" + suffix_str
typ = "int" + suffix_str
case int64, uint64:
return "bigint" + suffix_str
typ = "bigint" + suffix_str
}
case "postgres":
switch column.(type) {
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "serial"
typ = "serial"
case int64, uint64:
return "bigserial"
}
typ = "bigserial"
}
default:
panic("unsupported sql adaptor, please submit an issue in github")
}
return typ
}
func getSqlType(adaptor string, column interface{}, size int) string {
func getSqlType(adaptor string, column interface{}, tag string) string {
column = formatColumnValue(column)
typ, addational_typ, size := parseTag(tag)
if typ == "-" {
return ""
}
if len(typ) == 0 {
switch adaptor {
case "sqlite3":
switch column.(type) {
case time.Time:
return "datetime"
typ = "datetime"
case bool, sql.NullBool:
return "bool"
typ = "bool"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
typ = "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
typ = "bigint"
case float32, float64, sql.NullFloat64:
return "real"
typ = "real"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
return "text"
default:
panic("invalid sql type")
}
case "mysql":
switch column.(type) {
case time.Time:
return "timestamp"
typ = "timestamp"
case bool, sql.NullBool:
return "boolean"
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "int"
typ = "int"
case int64, uint64, sql.NullInt64:
return "bigint"
typ = "bigint"
case float32, float64, sql.NullFloat64:
return "double"
typ = "double"
case []byte:
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
typ = fmt.Sprintf("varbinary(%d)", size)
} else {
typ = "longblob"
}
return "longblob"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "longtext"
}
return "longtext"
default:
panic("invalid sql type")
}
@ -97,22 +144,23 @@ func getSqlType(adaptor string, column interface{}, size int) string {
case "postgres":
switch column.(type) {
case time.Time:
return "timestamp with time zone"
typ = "timestamp with time zone"
case bool, sql.NullBool:
return "boolean"
typ = "boolean"
case int, int8, int16, int32, uint, uint8, uint16, uint32:
return "integer"
typ = "integer"
case int64, uint64, sql.NullInt64:
return "bigint"
typ = "bigint"
case float32, float64, sql.NullFloat64:
return "double precision"
typ = "double precision"
case []byte:
return "bytea"
typ = "bytea"
case string, sql.NullString:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
typ = fmt.Sprintf("varchar(%d)", size)
} else {
typ = "text"
}
return "text"
default:
panic("invalid sql type")
}
@ -120,3 +168,9 @@ func getSqlType(adaptor string, column interface{}, size int) string {
panic("unsupported sql adaptor, please submit an issue in github")
}
}
if len(addational_typ) > 0 {
typ = typ + " " + addational_typ
}
return typ
}