From 11758c647f7140f99f5721cc8025a142d2ba196a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 26 Oct 2013 09:49:40 +0800 Subject: [PATCH] Add some utils --- orm.go | 22 +++++++++++++++++++++- orm_test.go | 25 ++++++++++++++++++++----- sql.go | 40 ++++++++++++++++++++++++++++++++++++++++ utils.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 6 deletions(-) create mode 100644 sql.go diff --git a/orm.go b/orm.go index c12a9968..e1b926f3 100644 --- a/orm.go +++ b/orm.go @@ -2,6 +2,7 @@ package gorm import ( "errors" + "strconv" "database/sql" @@ -11,6 +12,8 @@ type Orm struct { TableName string PrimaryKey string Error error + Sql string + SqlVars []interface{} db *sql.DB whereClause []interface{} @@ -18,6 +21,13 @@ type Orm struct { orderStr string offsetInt int limitInt int + operation string +} + +func (s *Orm) setModel(model interface{}) (err error) { + s.TableName = "user" + s.PrimaryKey = "id" + return } func (s *Orm) Where(querystring interface{}, args ...interface{}) *Orm { @@ -74,10 +84,14 @@ func (s *Orm) Select(value interface{}) *Orm { } func (s *Orm) Save(value interface{}) *Orm { + s.explain(value, "Save") + s.Exec() return s } func (s *Orm) Delete(value interface{}) *Orm { + s.explain(value, "Delete") + s.Exec() return s } @@ -89,11 +103,17 @@ func (s *Orm) Updates(values map[string]string) *Orm { return s } -func (s *Orm) Exec(sql string) *Orm { +func (s *Orm) Exec(sql ...string) *Orm { + if len(sql) == 0 { + s.db.Exec(s.Sql, s.SqlVars...) + } else { + s.db.Exec(sql[0]) + } return s } func (s *Orm) First(out interface{}) *Orm { + s.setModel(out) return s } diff --git a/orm_test.go b/orm_test.go index 6bf69931..8dc1492c 100644 --- a/orm_test.go +++ b/orm_test.go @@ -2,6 +2,7 @@ package gorm import ( "fmt" + "time" "testing" ) @@ -10,14 +11,28 @@ type User struct { Name string } -func TestWhere(t *testing.T) { - db, err := Open("postgres", "user=gorm dbname=gorm") +func getDB() DB { + db, _ := Open("postgres", "user=gorm dbname=gorm") + return db +} - if err != err { - t.Errorf("Error should be nil") +func TestSaveAndFirst(t *testing.T) { + db := getDB() + u := &User{Name: "jinzhu"} + fmt.Println(db.Save(u).Sql) + + fmt.Println(time.Now().String()) + + user := &User{} + db.First(&user) + if user.Name != "jinzhu" { + t.Errorf("User should be saved and fetched correctly") } - orm := db.Where("id = $1", 1, 3, 4, []int64{1, 2, 3}).Where("name = $1", "jinzhu") +} +func TestWhere(t *testing.T) { + db := getDB() + orm := db.Where("id = $1", 1, 3, 4, []int64{1, 2, 3}).Where("name = $1", "jinzhu") user := &User{} orm.First(user) fmt.Println(user) diff --git a/sql.go b/sql.go new file mode 100644 index 00000000..c61b0910 --- /dev/null +++ b/sql.go @@ -0,0 +1,40 @@ +package gorm + +import ( + "strings" + + "fmt" +) + +func (s *Orm) explain(value interface{}, operation string) { + s.setModel(value) + switch operation { + case "Save": + s.saveSql(value) + case "Delete": + s.deleteSql(value) + } + return +} + +func (s *Orm) saveSql(value interface{}) { + columns, values := modelValues(value) + s.Sql = fmt.Sprintf( + "INSERT INTO %v (%v) VALUES (%v)", + s.TableName, + strings.Join(columns, ","), + valuesToBinVar(values), + ) + s.SqlVars = values + return +} + +func (s *Orm) deleteSql(value interface{}) { + s.Sql = fmt.Sprintf("DELETE FROM %v WHERE %v", s.TableName, s.whereSql) + return +} + +func (s *Orm) whereSql() (sql string) { + sql = "1=1" + return +} diff --git a/utils.go b/utils.go index 4506bcfe..cc687a6a 100644 --- a/utils.go +++ b/utils.go @@ -1 +1,29 @@ package gorm + +import ( + "fmt" + "reflect" + "strings" +) + +func modelValues(m interface{}) (columns []string, values []interface{}) { + typ := reflect.TypeOf(m).Elem() + + for i := 0; i < typ.NumField(); i++ { + p := typ.Field(i) + if !p.Anonymous { + columns = append(columns, strings.ToLower(p.Name)) + value := reflect.ValueOf(m).Elem().FieldByName(p.Name) + values = append(values, value.Interface()) + } + } + return +} + +func valuesToBinVar(values []interface{}) string { + var sqls []string + for index, _ := range values { + sqls = append(sqls, fmt.Sprintf("$%d", index+1)) + } + return strings.Join(sqls, ",") +}