diff --git a/chain.go b/chain.go index ed2b8012..1efb46d8 100644 --- a/chain.go +++ b/chain.go @@ -71,16 +71,6 @@ func (s *Chain) Model(model interface{}) *Chain { return s } -func (s *Chain) Where(querystring interface{}, args ...interface{}) *Chain { - s.whereClause = append(s.whereClause, map[string]interface{}{"query": querystring, "args": args}) - return s -} - -func (s *Chain) Not(querystring interface{}, args ...interface{}) *Chain { - s.notClause = append(s.notClause, map[string]interface{}{"query": querystring, "args": args}) - return s -} - func (s *Chain) Limit(value interface{}) *Chain { if str, err := getInterfaceAsString(value); err == nil { s.limitStr = str @@ -151,16 +141,6 @@ func (s *Chain) Exec(sql string) *Chain { return s } -func (s *Chain) First(out interface{}, where ...interface{}) *Chain { - s.do(out).where(where...).first() - return s -} - -func (s *Chain) Last(out interface{}, where ...interface{}) *Chain { - s.do(out).where(where...).last() - return s -} - func (s *Chain) Attrs(attrs ...interface{}) *Chain { s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) return s @@ -196,21 +176,11 @@ func (s *Chain) FirstOrCreate(out interface{}, where ...interface{}) *Chain { return s } -func (s *Chain) Find(out interface{}, where ...interface{}) *Chain { - s.do(out).where(where...).query() - return s -} - func (s *Chain) Pluck(column string, value interface{}) (orm *Chain) { s.do(s.value).pluck(column, value) return s } -func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain { - s.orClause = append(s.orClause, map[string]interface{}{"query": querystring, "args": args}) - return s -} - func (s *Chain) Unscoped() *Chain { s.unscoped = true return s diff --git a/do.go b/do.go index a1324adb..bb637e19 100644 --- a/do.go +++ b/do.go @@ -14,7 +14,7 @@ import ( type Do struct { chain *Chain - db sqlCommon + db *DB guessedTableName string specifiedTableName string startedTransaction bool @@ -732,6 +732,14 @@ func (s *Do) commit_or_rollback() { } } +func (s *Do) initialize() { + // TODO initializeWithSearchCondition +} + +func (s *Do) updateAttrs_() { + // TODO return false if no updates +} + func (s *Do) initializeWithSearchCondition() { for _, clause := range s.whereClause { switch value := clause["query"].(type) { diff --git a/errors.go b/errors.go index 2c968dbf..e99a7f23 100644 --- a/errors.go +++ b/errors.go @@ -3,6 +3,8 @@ package gorm import "errors" var ( - RecordNotFound = errors.New("Record Not Found") - InvalidSql = errors.New("Invalid SQL") + RecordNotFound = errors.New("Record Not Found") + InvalidSql = errors.New("Invalid SQL") + NoNewAttrs = errors.New("No new Attributes") + NoValidTransaction = errors.New("No valid transaction") ) diff --git a/main.go b/main.go index 0f0ea4db..ef881a47 100644 --- a/main.go +++ b/main.go @@ -1,25 +1,29 @@ package gorm -import "database/sql" +import ( + "errors" + + "database/sql" +) import "github.com/jinzhu/gorm/dialect" -var singularTableName bool -var tagIdentifier string - -func init() { - tagIdentifier = "sql" -} - type DB struct { - db sqlCommon - dialect dialect.Dialect - logger Logger - logMode bool + db sqlCommon + parent *DB + search *search + data interface{} + Error error + dialect dialect.Dialect + tagIdentifier string + singularTable bool + logger Logger + logMode bool } func Open(driver, source string) (db DB, err error) { db.db, err = sql.Open(driver, source) db.dialect = dialect.New(driver) + db.parent = &db return } @@ -30,117 +34,231 @@ func (s *DB) SetPool(n int) { } func (s *DB) SetTagIdentifier(str string) { - tagIdentifier = str + s.parent.tagIdentifier = str } func (s *DB) SetLogger(l Logger) { - s.logger = l + s.parent.logger = l } func (s *DB) LogMode(b bool) { - s.logMode = b + s.parent.logMode = b } -func (s *DB) SingularTable(result bool) { - singularTableName = result +func (s *DB) SingularTable(b bool) { + s.parent.singularTable = b } -func (s *DB) buildChain() *Chain { - return &Chain{db: s.db, d: s} +func (s *DB) clone() *DB { + db := &DB{db: s.db, parent: s.parent, search: s.parent.search.clone()} + db.search.db = db + return db } -func (s *DB) Where(querystring interface{}, args ...interface{}) *Chain { - return s.buildChain().Where(querystring, args...) +func (s *DB) do(data interface{}) *Do { + s.data = data + return &Do{db: s} } -func (s *DB) Not(querystring interface{}, args ...interface{}) *Chain { - return s.buildChain().Not(querystring, args...) +func (s *DB) err(err error) error { + if err != nil { + s.Error = err + s.warn(err) + } + return err } -func (s *DB) First(out interface{}, where ...interface{}) *Chain { - return s.buildChain().First(out, where...) +func (s *DB) hasError() bool { + return s.Error != nil } -func (s *DB) Last(out interface{}, where ...interface{}) *Chain { - return s.buildChain().Last(out, where...) +func (s *DB) Where(query interface{}, args ...interface{}) *DB { + return s.clone().search.where(query, args...).db } -func (s *DB) Attrs(attrs ...interface{}) *Chain { - return s.buildChain().Attrs(attrs...) +func (s *DB) Or(query interface{}, args ...interface{}) *DB { + return s.clone().search.where(query, args...).db } -func (s *DB) Assign(attrs ...interface{}) *Chain { - return s.buildChain().Assign(attrs...) +func (s *DB) Not(query interface{}, args ...interface{}) *DB { + return s.clone().search.not(query, args...).db } -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *Chain { - return s.buildChain().FirstOrInit(out, where...) +func (s *DB) Limit(value interface{}) *DB { + return s.clone().search.limit(value).db } -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *Chain { - return s.buildChain().FirstOrCreate(out, where...) +func (s *DB) Offset(value interface{}) *DB { + return s.clone().search.offset(value).db } -func (s *DB) Find(out interface{}, where ...interface{}) *Chain { - return s.buildChain().Find(out, where...) +func (s *DB) Order(value string, reorder ...bool) *DB { + return s.clone().search.order(value, reorder...).db } -func (s *DB) Limit(value interface{}) *Chain { - return s.buildChain().Limit(value) +func (s *DB) Select(value interface{}) *DB { + return s.clone().search.selects(value).db } -func (s *DB) Offset(value interface{}) *Chain { - return s.buildChain().Offset(value) +func (s *DB) Unscoped() *DB { + return s.clone().search.unscoped().db } -func (s *DB) Order(value string, reorder ...bool) *Chain { - return s.buildChain().Order(value, reorder...) +func (s *DB) First(out interface{}, where ...interface{}) *DB { + s.clone().search.limit(1).where(where[0], where[1:]).db.do(out).first() + return s } -func (s *DB) Select(value interface{}) *Chain { - return s.buildChain().Select(value) +func (s *DB) Last(out interface{}, where ...interface{}) *DB { + s.clone().search.limit(1).where(where[0], where[1:]).db.do(out).last() + return s } -func (s *DB) Save(value interface{}) *Chain { - return s.buildChain().Save(value) +func (s *DB) Find(out interface{}, where ...interface{}) *DB { + s.clone().search.where(where[0], where[1:]).db.do(out).query() + return s } -func (s *DB) Delete(value interface{}) *Chain { - return s.buildChain().Delete(value) +func (s *DB) Attrs(attrs ...interface{}) *DB { + return s.clone().search.attrs(attrs...).db } -func (s *DB) Unscoped() *Chain { - return s.buildChain().Unscoped() +func (s *DB) Assign(attrs ...interface{}) *DB { + return s.clone().search.assign(attrs...).db } -func (s *DB) Exec(sql string) *Chain { - return s.buildChain().Exec(sql) +func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { + if s.First(out, where...).Error != nil { + s.clone().do(out).where(where).initialize() + } else { + if len(s.search.assignAttrs) > 0 { + s.do(out).updateAttrs(s.assignAttrs) //updated or not + } + } + return s } -func (s *DB) Model(value interface{}) *Chain { - return s.buildChain().Model(value) +func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { + if s.First(out, where...).Error != nil { + s.clone().do(out).where(where...).initialize() + s.Save(out) + } else { + if len(s.search.assignAttrs) > 0 { + s.do(out).updateAttrs(s.search.assignAttrs).update() + } + } + return s } -func (s *DB) Table(name string) *Chain { - return s.buildChain().Table(name) +func (s *DB) Save(value interface{}) *DB { + s.do(value).begin().save().commit_or_rollback() + return s } -func (s *DB) CreateTable(value interface{}) *Chain { - return s.buildChain().CreateTable(value) +func (s *DB) Delete(value interface{}) *DB { + s.do(value).bengin().delete(value).commit_or_rollback() + return s } -func (s *DB) DropTable(value interface{}) *Chain { - return s.buildChain().DropTable(value) +func (s *DB) Exec(sql string) *DB { + s.do(nil).exec(sql) + return s } -func (s *DB) AutoMigrate(value interface{}) *Chain { - return s.buildChain().AutoMigrate(value) +func (s *DB) Model(value interface{}) *DB { + c := s.clone() + c.data = value + return c } -func (s *DB) Debug() *Chain { - return s.buildChain().Debug() +func (s *DB) Related(value interface{}, foreign_keys ...string) *DB { + s.clone().do(value).related(s.value, foreign_keys...) + return s } -func (s *DB) Begin() *Chain { - return s.buildChain().Begin() +func (s *DB) Pluck(column string, value interface{}) *DB { + s.clone().search.selects(column).do(s.value).pluck(column, value) + return s +} + +func (s *DB) Count(value interface{}) *DB { + s.clone().search.selects("count(*)").do(s.value).count(value) + return s +} + +func (s *DB) Table(name string) *DB { + return s.clone().table(name).db +} + +// Debug +func (s *DB) Debug() *DB { + s.logMode = true + return s +} + +// Transactions +func (s *DB) Begin() *DB { + c := s.clone() + if db, ok := c.db.(sqlDb); ok { + tx, err := db.Begin() + c.db = interface{}(tx).(sqlCommon) + c.err(err) + } else { + c.err(errors.New("Can't start a transaction.")) + } + return c +} + +func (s *DB) Commit() *DB { + if db, ok := s.db.(sqlTx); ok { + s.err(db.Commit()) + } else { + s.err(NoValidTransaction) + } + return s +} + +func (s *DB) Rollback() *DB { + if db, ok := s.db.(sqlTx); ok { + s.err(db.Rollback()) + } else { + s.err(NoValidTransaction) + } + return s +} + +// Migrations +func (s *DB) CreateTable(value interface{}) *DB { + s.do(value).createTable() + return s +} + +func (s *DB) DropTable(value interface{}) *DB { + s.do(value).dropTable() + return s +} + +func (s *DB) AutoMigrate(value interface{}) *DB { + s.do(value).autoMigrate() + return s +} + +func (s *DB) UpdateColumn(column string, typ string) *DB { + s.do(s.data).updateColumn(column, typ) + return s +} + +func (s *DB) DropColumn(column string) *DB { + s.do(s.data).dropColumn(column) + return s +} + +func (s *DB) AddIndex(column string, index_name ...string) *DB { + s.do(s.data).addIndex(column, index_name...) + return s +} + +func (s *DB) RemoveIndex(column string) *DB { + s.do(s.data).removeIndex(column) + return s } diff --git a/search.go b/search.go index bebf2c02..8345a033 100644 --- a/search.go +++ b/search.go @@ -3,6 +3,7 @@ package gorm import "strconv" type search struct { + db *DB whereClause []map[string]interface{} orClause []map[string]interface{} notClause []map[string]interface{} @@ -12,6 +13,7 @@ type search struct { selectStr string offsetStr string limitStr string + tableName string unscope bool } @@ -90,6 +92,11 @@ func (s *search) unscoped() *search { return s } +func (s *search) table(name string) *search { + s.tableName = name + return s +} + func getInterfaceAsString(value interface{}) (str string, err error) { switch value := value.(type) { case string: