From a7f62c24b0bba7914ae25d2fd5843e446f8a90d8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 26 Oct 2013 17:56:00 +0800 Subject: [PATCH] yay, support primary key --- model.go | 10 +++++++++- orm.go | 2 +- sql_type.go | 25 +++++++++++++++++++++++-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/model.go b/model.go index 62da0ea4..e22a5c64 100644 --- a/model.go +++ b/model.go @@ -23,6 +23,10 @@ func (s *Orm) toModel(value interface{}) *Model { return &Model{Data: value, driver: s.driver} } +func (m *Model) PrimaryKey() string { + return "Id" +} + func (m *Model) Fields() (fields []Field) { typ := reflect.TypeOf(m.Data).Elem() @@ -33,7 +37,11 @@ func (m *Model) Fields() (fields []Field) { field.Name = p.Name field.DbName = toSnake(p.Name) field.Value = reflect.ValueOf(m.Data).Elem().FieldByName(p.Name).Interface() - field.SqlType = getSqlType(m.driver, field.Value, 0) + if m.PrimaryKey() == p.Name { + field.SqlType = getPrimaryKeySqlType(m.driver, field.Value, 0) + } else { + field.SqlType = getSqlType(m.driver, field.Value, 0) + } fields = append(fields, field) } } diff --git a/orm.go b/orm.go index 53d4aa24..ed8494f4 100644 --- a/orm.go +++ b/orm.go @@ -28,7 +28,7 @@ type Orm struct { func (s *Orm) setModel(model interface{}) (err error) { s.Model = s.toModel(model) s.TableName = s.Model.TableName() - s.PrimaryKey = "id" + s.PrimaryKey = s.Model.PrimaryKey() return } diff --git a/sql_type.go b/sql_type.go index 6069967d..2f57ba1d 100644 --- a/sql_type.go +++ b/sql_type.go @@ -5,6 +5,27 @@ import ( "time" ) +func getPrimaryKeySqlType(adaptor string, column interface{}, size int) string { + switch adaptor { + case "mysql": + suffix_str := " NOT NULL AUTO_INCREMENT PRIMARY KEY" + switch column.(type) { + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "int" + suffix_str + case int64, uint64: + return "bigint" + suffix_str + } + case "postgres": + switch column.(type) { + case int, int8, int16, int32, uint, uint8, uint16, uint32: + return "serial" + case int64, uint64: + return "bigserial" + } + } + panic("unsupported sql adaptor, please submit an issue in github") +} + func getSqlType(adaptor string, column interface{}, size int) string { switch adaptor { case "mysql": @@ -55,7 +76,7 @@ func getSqlType(adaptor string, column interface{}, size int) string { default: panic("invalid sql type") } + default: + panic("unsupported sql adaptor, please submit an issue in github") } - - panic("unsupported sql adaptor, please submit an issue in github") }