diff --git a/README.md b/README.md index 4cede1de..0552bce6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Yet Another ORM library for Go, aims for developer friendly * Auto Migration * Transaction * Logger Support +* Bind struct with tag * Every feature comes with tests * Convention Over Configuration * Developer Friendly @@ -22,30 +23,32 @@ type User struct { // TableName: `users`, gorm will pluralize struct's n Id int64 // Id: Database Primary key Birthday time.Time Age int64 - Name string + Name string `sql:"size:255"` // set this field's length and as not null with tag CreatedAt time.Time // Time of record is created, will be insert automatically UpdatedAt time.Time // Time of record is updated, will be updated automatically DeletedAt time.Time // Time of record is deleted, refer `Soft Delete` for more - Email []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key + Email []Email // Embedded structs + BillingAddress Address // Embedded struct + BillingAddressId sql.NullInt64 // Embedded struct BillingAddress's foreign key + ShippingAddress Address // Embedded struct + ShippingAddressId int64 // Embedded struct ShippingAddress's foreign key + IgnoreMe int64 `sql:"-"` // Ignore this field with tag } type Email struct { // TableName: `emails` Id int64 UserId int64 // Foreign key for above embedded structs - Email string + Email string `sql:"type:varchar(100);"` // Set column type directly with tag Subscribed bool } type Address struct { // TableName: `addresses` Id int64 - Address1 string - Address2 string - Post string + Address1 string `sql:"not null;unique"` // Set column as unique with tag + Address2 string `sql:"type:varchar(100);unique"` + Post sql.NullString `sql:not null` + // Be careful: "NOT NULL" will only works for NullXXX scanner, because golang will initalize a default value for most type... } ``` @@ -723,8 +726,8 @@ db.Where("email = ?", "x@example.org").Attrs(User{FromIp: "111.111.111.111"}).Fi ## TODO * Join, Having, Group, Includes -* Scopes -* Index, Unique, Valiations +* Scopes, Valiations +* AlertColumn, DropColumn, AddIndex, RemoveIndex # Author diff --git a/gorm_test.go b/gorm_test.go index 225f2668..03db697e 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -18,7 +18,7 @@ type User struct { Id int64 // Id: Primary key Birthday time.Time // Time Age int64 - Name string + Name string `sql:"size:255"` CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically DeletedAt time.Time // DeletedAt: Time of record is deleted, refer Soft Delete for more @@ -28,6 +28,7 @@ type User struct { ShippingAddress Address // Embedded struct ShippingAddressId int64 // Embedded struct's foreign key CreditCard CreditCard + IgnoreMe int64 `sql:"-"` } type CreditCard struct { @@ -42,7 +43,7 @@ type CreditCard struct { type Email struct { Id int64 UserId int64 - Email string + Email string `sql:"type:varchar(100); unique"` CreatedAt time.Time UpdatedAt time.Time } @@ -1279,7 +1280,6 @@ type BigEmail struct { RegisteredAt time.Time CreatedAt time.Time UpdatedAt time.Time - DeletedAt time.Time } func (b BigEmail) TableName() string { @@ -1299,6 +1299,10 @@ func TestAutoMigration(t *testing.T) { if big_email.Email != "jinzhu@example.org" || big_email.UserAgent != "pc" || big_email.RegisteredAt.IsZero() { t.Error("Big Emails should be saved and fetched correctly") } + + if err := db.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}).Error; err == nil { + t.Error("Should not be able to save because of unique tag") + } } type NullTime struct { @@ -1324,7 +1328,7 @@ func (nt NullTime) Value() (driver.Value, error) { type NullValue struct { Id int64 - Name sql.NullString + Name sql.NullString `sql:"not null"` Age sql.NullInt64 Male sql.NullBool Height sql.NullFloat64 @@ -1355,6 +1359,10 @@ func TestSqlNullValue(t *testing.T) { if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { t.Errorf("Should be able to fetch null value") } + + if err := db.Save(&NullValue{Name: sql.NullString{"hello-3", false}, Age: sql.NullInt64{18, false}, Male: sql.NullBool{true, true}, Height: sql.NullFloat64{100.11, true}, AddedAt: NullTime{time.Now(), false}}).Error; err == nil { + t.Errorf("Can't save because of name can't be null", err) + } } func TestTransaction(t *testing.T) { @@ -1401,10 +1409,11 @@ func (s *CreditCard) BeforeSave() (err error) { func BenchmarkGorm(b *testing.B) { for x := 0; x < b.N; x++ { - email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} db.Save(&email) - db.First(&BigEmail{}, "email = ?", "benchmark@benchmark.org") - db.Model(&email).Update("email", "benchmark@benchmark.org") + db.First(&BigEmail{}, "email = ?", e) + db.Model(&email).Update("email", e) db.Delete(&email) } } @@ -1419,11 +1428,12 @@ func BenchmarkRawSql(b *testing.B) { var id int64 for x := 0; x < b.N; x++ { - email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} + e := strconv.Itoa(x) + "benchmark@example.org" + email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()} db.QueryRow(insert_sql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) rows, _ := db.Query(query_sql, email.Email) rows.Close() - db.Exec(update_sql, "benchmark@benchmark.org", time.Now(), id) + db.Exec(update_sql, e, time.Now(), id) db.Exec(delete_sql, id) } } diff --git a/logger.go b/logger.go index 226345fb..42d5239c 100644 --- a/logger.go +++ b/logger.go @@ -57,6 +57,10 @@ func (s *Chain) warn(v ...interface{}) { go s.print("warn", v...) } +func (s *Chain) info(v ...interface{}) { + go s.print("info", v...) +} + func (s *Chain) slog(sql string, t time.Time, vars ...interface{}) { go s.print("sql", time.Now().Sub(t), sql, vars) } diff --git a/main.go b/main.go index f5f61401..ff85ddd8 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,11 @@ package gorm import "database/sql" var singularTableName bool +var tagIdentifier string + +func init() { + tagIdentifier = "sql" +} type DB struct { db sql_common @@ -23,6 +28,10 @@ func (s *DB) SetPool(n int) { } } +func (s *DB) SetTagIdentifier(str string) { + tagIdentifier = str +} + func (s *DB) SetLogger(l Logger) { s.logger = l } diff --git a/model.go b/model.go index cc93a4e1..c3e4b06e 100644 --- a/model.go +++ b/model.go @@ -115,6 +115,7 @@ func (m *Model) fields(operation string) (fields []Field) { } } + tag_value := p.Tag.Get(tagIdentifier) if is_time { field.AutoCreateTime = "created_at" == field.DbName field.AutoUpdateTime = "updated_at" == field.DbName @@ -130,9 +131,9 @@ func (m *Model) fields(operation string) (fields []Field) { } } - field.SqlType = getSqlType(m.do.chain.driver(), value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) } else if field.IsPrimaryKey { - field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, 0) + field.SqlType = getPrimaryKeySqlType(m.do.chain.driver(), value, tag_value) } else { field_value := reflect.Indirect(value) @@ -147,7 +148,7 @@ func (m *Model) fields(operation string) (fields []Field) { _, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner) if is_scanner { - field.SqlType = getSqlType(m.do.chain.driver(), value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) } else { if indirect_value.FieldByName(p.Name + "Id").IsValid() { field.foreignKey = p.Name + "Id" @@ -161,7 +162,7 @@ func (m *Model) fields(operation string) (fields []Field) { } } default: - field.SqlType = getSqlType(m.do.chain.driver(), value, 0) + field.SqlType = getSqlType(m.do.chain.driver(), value, tag_value) } } diff --git a/sql_type.go b/sql_type.go index d82cdc7e..56c5be5f 100644 --- a/sql_type.go +++ b/sql_type.go @@ -5,9 +5,40 @@ import ( "database/sql/driver" "fmt" "reflect" + "strconv" + "strings" "time" ) +func parseTag(str string) (typ string, addational_typ string, size int) { + if str == "-" { + typ = str + } else if str != "" { + tags := strings.Split(str, ";") + m := make(map[string]string) + for _, value := range tags { + v := strings.Split(value, ":") + k := strings.Trim(strings.ToUpper(v[0]), " ") + if len(v) == 2 { + m[k] = v[1] + } else { + m[k] = k + } + } + + if len(m["SIZE"]) > 0 { + size, _ = strconv.Atoi(m["SIZE"]) + } + + if len(m["TYPE"]) > 0 { + typ = m["TYPE"] + } + + addational_typ = m["NOT NULL"] + " " + m["UNIQUE"] + } + return +} + func formatColumnValue(column interface{}) interface{} { if v, ok := column.(reflect.Value); ok { column = v.Interface() @@ -19,8 +50,13 @@ func formatColumnValue(column interface{}) interface{} { return column } -func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { +func getPrimaryKeySqlType(adaptor string, column interface{}, tag string) string { column = formatColumnValue(column) + typ, addational_typ, _ := parseTag(tag) + + if len(typ) != 0 { + return typ + addational_typ + } switch adaptor { case "sqlite3": @@ -29,94 +65,112 @@ func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" switch column.(type) { case int, int8, int16, int32, uint, uint8, uint16, uint32: - return "int" + suffix_str + typ = "int" + suffix_str case int64, uint64: - return "bigint" + suffix_str + typ = "bigint" + suffix_str } case "postgres": switch column.(type) { case int, int8, int16, int32, uint, uint8, uint16, uint32: - return "serial" + typ = "serial" case int64, uint64: - return "bigserial" - } - } - panic("unsupported sql adaptor, please submit an issue in github") -} - -func getSqlType(adaptor string, column interface{}, size int) string { - column = formatColumnValue(column) - - switch adaptor { - case "sqlite3": - switch column.(type) { - case time.Time: - return "datetime" - case bool, sql.NullBool: - return "bool" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - return "integer" - case int64, uint64, sql.NullInt64: - return "bigint" - case float32, float64, sql.NullFloat64: - return "real" - case string, sql.NullString: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - default: - panic("invalid sql type") - } - case "mysql": - switch column.(type) { - case time.Time: - return "timestamp" - case bool, sql.NullBool: - return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - return "int" - case int64, uint64, sql.NullInt64: - return "bigint" - case float32, float64, sql.NullFloat64: - return "double" - case []byte: - if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) - } - return "longblob" - case string, sql.NullString: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "longtext" - default: - panic("invalid sql type") - } - - case "postgres": - switch column.(type) { - case time.Time: - return "timestamp with time zone" - case bool, sql.NullBool: - return "boolean" - case int, int8, int16, int32, uint, uint8, uint16, uint32: - return "integer" - case int64, uint64, sql.NullInt64: - return "bigint" - case float32, float64, sql.NullFloat64: - return "double precision" - case []byte: - return "bytea" - case string, sql.NullString: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - default: - panic("invalid sql type") + typ = "bigserial" } default: panic("unsupported sql adaptor, please submit an issue in github") } + return typ +} + +func getSqlType(adaptor string, column interface{}, tag string) string { + column = formatColumnValue(column) + typ, addational_typ, size := parseTag(tag) + + if typ == "-" { + return "" + } + + if len(typ) == 0 { + switch adaptor { + case "sqlite3": + switch column.(type) { + case time.Time: + typ = "datetime" + case bool, sql.NullBool: + typ = "bool" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + typ = "integer" + case int64, uint64, sql.NullInt64: + typ = "bigint" + case float32, float64, sql.NullFloat64: + typ = "real" + case string, sql.NullString: + if size > 0 && size < 65532 { + typ = fmt.Sprintf("varchar(%d)", size) + } else { + typ = "text" + } + default: + panic("invalid sql type") + } + case "mysql": + switch column.(type) { + case time.Time: + typ = "timestamp" + case bool, sql.NullBool: + typ = "boolean" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + typ = "int" + case int64, uint64, sql.NullInt64: + typ = "bigint" + case float32, float64, sql.NullFloat64: + typ = "double" + case []byte: + if size > 0 && size < 65532 { + typ = fmt.Sprintf("varbinary(%d)", size) + } else { + typ = "longblob" + } + case string, sql.NullString: + if size > 0 && size < 65532 { + typ = fmt.Sprintf("varchar(%d)", size) + } else { + typ = "longtext" + } + default: + panic("invalid sql type") + } + + case "postgres": + switch column.(type) { + case time.Time: + typ = "timestamp with time zone" + case bool, sql.NullBool: + typ = "boolean" + case int, int8, int16, int32, uint, uint8, uint16, uint32: + typ = "integer" + case int64, uint64, sql.NullInt64: + typ = "bigint" + case float32, float64, sql.NullFloat64: + typ = "double precision" + case []byte: + typ = "bytea" + case string, sql.NullString: + if size > 0 && size < 65532 { + typ = fmt.Sprintf("varchar(%d)", size) + } else { + typ = "text" + } + default: + panic("invalid sql type") + } + default: + panic("unsupported sql adaptor, please submit an issue in github") + } + } + + if len(addational_typ) > 0 { + typ = typ + " " + addational_typ + } + return typ }