diff --git a/.gitignore b/.gitignore index 912d58f7..c14d6005 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -TODO +TODO* documents coverage.txt _book diff --git a/association.go b/association.go new file mode 100644 index 00000000..17f8f4a5 --- /dev/null +++ b/association.go @@ -0,0 +1,5 @@ +package gorm + +// Association Mode contains some helper methods to handle relationship things easily. +type Association struct { +} diff --git a/chainable_api.go b/chainable_api.go new file mode 100644 index 00000000..d8f2116c --- /dev/null +++ b/chainable_api.go @@ -0,0 +1,138 @@ +package gorm + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Table specify the table you would like to run db operations +func (db *DB) Table(name string) (tx *DB) { + tx = db.getInstance() + return +} + +// Select specify fields that you want when querying, creating, updating +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Not add NOT condition +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Or add OR conditions +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Group specify the group method on the find +func (db *DB) Group(column string) (tx *DB) { + tx = db.getInstance() + return +} + +// Having specify HAVING conditions for GROUP BY +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Order specify order when retrieve records from database +// db.Order("name DESC") +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Limit specify the number of records to be retrieved +func (db *DB) Limit(limit int64) (tx *DB) { + tx = db.getInstance() + return +} + +// Offset specify the number of records to skip before starting to return the records +func (db *DB) Offset(offset int64) (tx *DB) { + tx = db.getInstance() + return +} + +// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Refer https://jinzhu.github.io/gorm/crud.html#scopes +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + for _, f := range funcs { + db = f(db) + } + return db +} + +//Preloads only preloads relations, don`t touch out +func (db *DB) Preloads(out interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(column string, conditions ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} diff --git a/clause/clause.go b/clause/clause.go new file mode 100644 index 00000000..4495a9d5 --- /dev/null +++ b/clause/clause.go @@ -0,0 +1,53 @@ +package clause + +// Builder builder interface +type BuilderInterface interface { + Write(sql ...string) error + WriteQuoted(field interface{}) error + AddVar(vars ...interface{}) string + Quote(field interface{}) string +} + +// Interface clause interface +type Interface interface { + Name() string + Build(builder BuilderInterface) +} + +// NegationBuilder negation condition builder +type NegationBuilder interface { + NegationBuild(builder BuilderInterface) +} + +// Where where clause +type Where struct { +} + +// Select select attrs when querying, updating, creating +type Select struct { + Omit bool +} + +// Join join clause +type Join struct { +} + +// GroupBy group by clause +type GroupBy struct { +} + +// Having having clause +type Having struct { +} + +// Order order clause +type Order struct { +} + +// Limit limit clause +type Limit struct { +} + +// Offset offset clause +type Offset struct { +} diff --git a/clause/expr.go b/clause/expr.go new file mode 100644 index 00000000..94edb702 --- /dev/null +++ b/clause/expr.go @@ -0,0 +1,19 @@ +package clause + +type ExprInterface interface { +} + +type Expr struct { +} + +type Average struct { +} + +type Minimum struct { +} + +type Maximum struct { +} + +type Sum struct { +} diff --git a/clause/operators.go b/clause/operators.go new file mode 100644 index 00000000..331abea7 --- /dev/null +++ b/clause/operators.go @@ -0,0 +1,195 @@ +package clause + +import "strings" + +type AddConditions []Interface + +func (cs AddConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" AND ") + } + c.Build(builder) + } +} + +type ORConditions []Interface + +func (cs ORConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" OR ") + } + c.Build(builder) + } +} + +type NotConditions []Interface + +func (cs NotConditions) Build(builder BuilderInterface) { + for idx, c := range cs { + if idx > 0 { + builder.Write(" AND ") + } + + if negationBuilder, ok := c.(NegationBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.Write(" NOT ") + c.Build(builder) + } + } +} + +// Raw raw sql for where +type Raw struct { + SQL string + Values []interface{} +} + +func (raw Raw) Build(builder BuilderInterface) { + sql := raw.SQL + for _, v := range raw.Values { + sql = strings.Replace(sql, " ? ", " "+builder.AddVar(v)+" ", 1) + } + builder.Write(sql) +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder BuilderInterface) { + builder.WriteQuoted(in.Column) + + if len(in.Values) == 0 { + builder.Write(" IN (NULL)") + } else { + builder.Write(" IN (", builder.AddVar(in.Values...), ")") + } +} + +func (in IN) NegationBuild(builder BuilderInterface) { + if len(in.Values) != 0 { + builder.WriteQuoted(in.Column) + builder.Write(" NOT IN (", builder.AddVar(in.Values...), ")") + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder BuilderInterface) { + builder.WriteQuoted(eq.Column) + + if eq.Value == nil { + builder.Write(" IS NULL") + } else { + builder.Write(" = ", builder.AddVar(eq.Value)) + } +} + +func (eq Eq) NegationBuild(builder BuilderInterface) { + Neq{eq.Column, eq.Value}.Build(builder) +} + +// Neq not equal to for where +type Neq struct { + Column interface{} + Value interface{} +} + +func (neq Neq) Build(builder BuilderInterface) { + builder.WriteQuoted(neq.Column) + + if neq.Value == nil { + builder.Write(" IS NOT NULL") + } else { + builder.Write(" <> ", builder.AddVar(neq.Value)) + } +} + +func (neq Neq) NegationBuild(builder BuilderInterface) { + Eq{neq.Column, neq.Value}.Build(builder) +} + +// Gt greater than for where +type Gt struct { + Column interface{} + Value interface{} +} + +func (gt Gt) Build(builder BuilderInterface) { + builder.WriteQuoted(gt.Column) + builder.Write(" > ", builder.AddVar(gt.Value)) +} + +func (gt Gt) NegationBuild(builder BuilderInterface) { + Lte{gt.Column, gt.Value}.Build(builder) +} + +// Gte greater than or equal to for where +type Gte struct { + Column interface{} + Value interface{} +} + +func (gte Gte) Build(builder BuilderInterface) { + builder.WriteQuoted(gte.Column) + builder.Write(" >= ", builder.AddVar(gte.Value)) +} + +func (gte Gte) NegationBuild(builder BuilderInterface) { + Lt{gte.Column, gte.Value}.Build(builder) +} + +// Lt less than for where +type Lt struct { + Column interface{} + Value interface{} +} + +func (lt Lt) Build(builder BuilderInterface) { + builder.WriteQuoted(lt.Column) + builder.Write(" < ", builder.AddVar(lt.Value)) +} + +func (lt Lt) NegationBuild(builder BuilderInterface) { + Gte{lt.Column, lt.Value}.Build(builder) +} + +// Lte less than or equal to for where +type Lte struct { + Column interface{} + Value interface{} +} + +func (lte Lte) Build(builder BuilderInterface) { + builder.WriteQuoted(lte.Column) + builder.Write(" <= ", builder.AddVar(lte.Value)) +} + +func (lte Lte) NegationBuild(builder BuilderInterface) { + Gt{lte.Column, lte.Value}.Build(builder) +} + +// Like whether string matches regular expression +type Like struct { + Column interface{} + Value interface{} +} + +func (like Like) Build(builder BuilderInterface) { + builder.WriteQuoted(like.Column) + builder.Write(" LIKE ", builder.AddVar(like.Value)) +} + +func (like Like) NegationBuild(builder BuilderInterface) { + builder.WriteQuoted(like.Column) + builder.Write(" NOT LIKE ", builder.AddVar(like.Value)) +} diff --git a/finisher_api.go b/finisher_api.go new file mode 100644 index 00000000..687843e3 --- /dev/null +++ b/finisher_api.go @@ -0,0 +1,154 @@ +package gorm + +import ( + "database/sql" +) + +func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// First find first record that match given conditions, order by primary key +func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (db *DB) Take(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Last find last record that match given conditions, order by primary key +func (db *DB) Last(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Find find records that match given conditions +func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Scan scan value to a struct + +func (db *DB) Row() *sql.Row { + // TODO + return nil +} + +func (db *DB) Rows() (*sql.Rows, error) { + // TODO + return nil, nil +} + +func (db *DB) Scan(dest interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { + return nil +} + +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { + panicked := true + tx := db.Begin(opts...) + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + panicked = false + return +} + +func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Commit() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Rollback() (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Association(column string) *Association { + return nil +} diff --git a/gorm.go b/gorm.go index 274f4c62..1b6d88df 100644 --- a/gorm.go +++ b/gorm.go @@ -1,8 +1,10 @@ package gorm import ( + "context" "time" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/logger" ) @@ -38,9 +40,69 @@ type Model struct { // Dialector GORM database dialector type Dialector interface { Migrator() Migrator + BindVar(stmt Statement, v interface{}) string +} + +// Result +type Result struct { + Error error + RowsAffected int64 + Statement *Statement } // DB GORM DB definition type DB struct { *Config + Dialector + Result + Context context.Context +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + tx := db.getInstance() + tx.Context = ctx + return tx +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + if db.Statement != nil { + return db.Statement.Settings.Load(key) + } + return nil, false +} + +func (db *DB) Close() *DB { + // TODO + return db +} + +func (db *DB) getInstance() *DB { + // db.Result.Statement == nil means root DB + if db.Result.Statement == nil { + return &DB{ + Config: db.Config, + Dialector: db.Dialector, + Context: context.Background(), + Result: Result{ + Statement: &Statement{DB: db, Clauses: map[string][]clause.Interface{}}, + }, + } + } + + return db +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + tx = db.getInstance() + return } diff --git a/model/model.go b/model/model.go new file mode 100644 index 00000000..316f3ab5 --- /dev/null +++ b/model/model.go @@ -0,0 +1,37 @@ +package model + +import ( + "reflect" +) + +type Model struct { + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + Relationships Relationships +} + +type Field struct { + Name string + DBName string + DataType reflect.Type + DBDataType string + Tag reflect.StructTag + TagSettings map[string]string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + Nullable bool + Unique bool + Precision int + Size int + HasDefaultValue bool + DefaultValue string + StructField reflect.StructField + Model *Model +} diff --git a/model/relationship.go b/model/relationship.go new file mode 100644 index 00000000..60b0751e --- /dev/null +++ b/model/relationship.go @@ -0,0 +1,37 @@ +package model + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOneRel RelationshipType = "has_one" // HasOneRel has one relationship + HasManyRel RelationshipType = "has_many" // HasManyRel has many relationship + BelongsToRel RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2ManyRel RelationshipType = "many_to_many" // Many2ManyRel many to many relationship +) + +type Relationships struct { + HasOne map[string]*Relationship + BelongsTo map[string]*Relationship + HasMany map[string]*Relationship + Many2Many map[string]*Relationship +} + +type Relationship struct { + Type RelationshipType + ForeignKeys []*RelationField // self + AssociationForeignKeys []*RelationField // association + JoinTable *JoinTable +} + +type RelationField struct { + *Field + PolymorphicField *Field + PolymorphicValue string +} + +type JoinTable struct { + Table string + ForeignKeys []*RelationField + AssociationForeignKeys []*RelationField +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..21e95e11 --- /dev/null +++ b/statement.go @@ -0,0 +1,68 @@ +package gorm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/jinzhu/gorm/clause" +) + +// Statement statement +type Statement struct { + Dest interface{} + Table interface{} + Clauses map[string][]clause.Interface + Settings sync.Map + Context context.Context + DB *DB + StatementBuilder +} + +// StatementBuilder statement builder +type StatementBuilder struct { + SQL bytes.Buffer + Vars []interface{} + NamedVars []sql.NamedArg +} + +// Write write string +func (stmt Statement) Write(sql ...string) (err error) { + for _, s := range sql { + _, err = stmt.SQL.WriteString(s) + } + return +} + +// WriteQuoted write quoted field +func (stmt Statement) WriteQuoted(field interface{}) (err error) { + _, err = stmt.SQL.WriteString(stmt.Quote(field)) + return +} + +// Write write string +func (stmt Statement) AddVar(vars ...interface{}) string { + var placeholders []string + for _, v := range vars { + if namedArg, ok := v.(sql.NamedArg); ok && len(namedArg.Name) > 0 { + stmt.NamedVars = append(stmt.NamedVars, namedArg) + placeholders = append(placeholders, "@"+namedArg.Name) + } else { + placeholders = append(placeholders, stmt.DB.Dialector.BindVar(stmt, v)) + } + } + return strings.Join(placeholders, ",") +} + +// Quote returns quoted value +func (stmt Statement) Quote(field interface{}) (str string) { + return fmt.Sprint(field) +} + +// AddClause add clause +func (s Statement) AddClause(clause clause.Interface) { + s.Clauses[clause.Name()] = append(s.Clauses[clause.Name()], clause) +}