diff --git a/callbacks/create.go b/callbacks/create.go index 028cdbc4..983b95ce 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" ) func BeforeCreate(db *gorm.DB) { @@ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) { } func Create(db *gorm.DB) { - db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") - db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Table: db.Statement.Table}, + }) + + db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + fmt.Println(err) + fmt.Println(result) fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } diff --git a/chainable_api.go b/chainable_api.go index 95d5975c..b577d5cf 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -55,7 +55,9 @@ func (db *DB) Omit(columns ...string) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)}) + tx.Statement.AddClause(clause.Where{ + AndConditions: tx.Statement.BuildCondtion(query, args...), + }) return } @@ -63,7 +65,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))}, + AndConditions: []clause.Expression{ + clause.NotConditions(tx.Statement.BuildCondtion(query, args...)), + }, }) return } @@ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)}, + ORConditions: []clause.ORConditions{ + tx.Statement.BuildCondtion(query, args...), + }, }) return } diff --git a/clause/insert.go b/clause/insert.go new file mode 100644 index 00000000..e056b35e --- /dev/null +++ b/clause/insert.go @@ -0,0 +1,34 @@ +package clause + +type Insert struct { + Table Table + Priority string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Priority != "" { + builder.Write(insert.Priority) + builder.WriteByte(' ') + } + + builder.Write("INTO ") + builder.WriteQuoted(insert.Table) +} + +// MergeExpression merge insert clauses +func (insert Insert) MergeExpression(expr Expression) { + if v, ok := expr.(Insert); ok { + if insert.Priority == "" { + insert.Priority = v.Priority + } + if insert.Table.Table == "" { + insert.Table = v.Table + } + } +} diff --git a/clause/value.go b/clause/value.go new file mode 100644 index 00000000..4de0d91e --- /dev/null +++ b/clause/value.go @@ -0,0 +1,39 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name from clause name +func (Values) Name() string { + return "" +} + +// Build build from clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.Write(" VALUES ") + + for idx, value := range values.Values { + builder.WriteByte('(') + if idx > 0 { + builder.WriteByte(',') + } + + builder.Write(builder.AddVar(value...)) + builder.WriteByte(')') + } + } else { + builder.Write("DEFAULT VALUES") + } +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go new file mode 100644 index 00000000..3abf05e3 --- /dev/null +++ b/dialects/postgres/postgres.go @@ -0,0 +1,33 @@ +package postgres + +import ( + "database/sql" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/callbacks" + _ "github.com/lib/pq" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db) + + db.DB, err = sql.Open("postgres", dialector.DSN) + return +} + +func (Dialector) Migrator() gorm.Migrator { + return nil +} + +func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index bcd6bd5c..91c3389e 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,29 +1,33 @@ package sqlite import ( + "database/sql" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" _ "github.com/mattn/go-sqlite3" ) type Dialector struct { + DSN string } func Open(dsn string) gorm.Dialector { - return &Dialector{} + return &Dialector{DSN: dsn} } -func (Dialector) Initialize(db *gorm.DB) error { +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { // register callbacks callbacks.RegisterDefaultCallbacks(db) - return nil + db.DB, err = sql.Open("sqlite3", dialector.DSN) + return } func (Dialector) Migrator() gorm.Migrator { return nil } -func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { +func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } diff --git a/finisher_api.go b/finisher_api.go index c79915d2..a311ca78 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -4,7 +4,16 @@ import ( "database/sql" ) -func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() return } @@ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { return } -func (db *DB) Row() *sql.Row { - return nil -} - -func (db *DB) Rows() (*sql.Rows, error) { - return nil, nil -} - -// Scan scan value to a struct -func (db *DB) Scan(dest interface{}) (tx *DB) { +func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } -func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { - return nil -} - -// Create insert the value into database -func (db *DB) Create(value interface{}) (tx *DB) { - tx = db.getInstance() - return -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (db *DB) Save(value interface{}) (tx *DB) { +func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() return } @@ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { return } -func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) { +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() return } @@ -88,16 +77,6 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return } -func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() - return -} - -func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) { - tx = db.getInstance() - return -} - // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() @@ -119,6 +98,29 @@ func (db *DB) Association(column string) *Association { return nil } +func (db *DB) Count(value interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) Row() *sql.Row { + return nil +} + +func (db *DB) Rows() (*sql.Rows, error) { + return nil, nil +} + +// Scan scan value to a struct +func (db *DB) Scan(dest interface{}) (tx *DB) { + tx = db.getInstance() + return +} + +func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error { + return nil +} + func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true tx := db.Begin(opts...) diff --git a/go.mod b/go.mod index 516a9759..1f4d31a2 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module github.com/jinzhu/gorm go 1.13 -require github.com/jinzhu/inflection v1.0.0 +require ( + github.com/jinzhu/inflection v1.0.0 + github.com/lib/pq v1.3.0 + github.com/mattn/go-sqlite3 v2.0.3+incompatible +) diff --git a/gorm.go b/gorm.go index 8ac7e057..a72314bd 100644 --- a/gorm.go +++ b/gorm.go @@ -28,10 +28,11 @@ type DB struct { *Config Dialector Instance - DB CommonDB - clone bool - callbacks *callbacks - cacheStore *sync.Map + DB CommonDB + ClauseBuilders map[string]clause.ClauseBuilder + clone bool + callbacks *callbacks + cacheStore *sync.Map } // Session session config when create session with Session() method @@ -140,11 +141,12 @@ func (db *DB) getInstance() *DB { Context: ctx, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, }, - Config: db.Config, - Dialector: db.Dialector, - DB: db.DB, - callbacks: db.callbacks, - cacheStore: db.cacheStore, + Config: db.Config, + Dialector: db.Dialector, + ClauseBuilders: db.ClauseBuilders, + DB: db.DB, + callbacks: db.callbacks, + cacheStore: db.cacheStore, } } diff --git a/interfaces.go b/interfaces.go index 98d04592..6ba24dc4 100644 --- a/interfaces.go +++ b/interfaces.go @@ -9,7 +9,7 @@ import ( type Dialector interface { Initialize(*DB) error Migrator() Migrator - BindVar(stmt Statement, v interface{}) string + BindVar(stmt *Statement, v interface{}) string } // CommonDB common db interface diff --git a/statement.go b/statement.go index 4d959cbb..c01be0f5 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "log" "strconv" "strings" "sync" @@ -21,7 +22,7 @@ type Instance struct { Statement *Statement } -func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { +func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) { if len(clauses) > 0 { instance.Statement.Build(clauses...) } @@ -29,7 +30,7 @@ func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) { } // AddError add error to instance -func (inst Instance) AddError(err error) { +func (inst *Instance) AddError(err error) { if inst.Error == nil { inst.Error = err } else { @@ -55,11 +56,11 @@ type Statement struct { // StatementOptimizer statement optimizer interface type StatementOptimizer interface { - OptimizeStatement(Statement) + OptimizeStatement(*Statement) } // Write write string -func (stmt Statement) Write(sql ...string) (err error) { +func (stmt *Statement) Write(sql ...string) (err error) { for _, s := range sql { _, err = stmt.SQL.WriteString(s) } @@ -67,12 +68,12 @@ func (stmt Statement) Write(sql ...string) (err error) { } // Write write string -func (stmt Statement) WriteByte(c byte) (err error) { +func (stmt *Statement) WriteByte(c byte) (err error) { return stmt.SQL.WriteByte(c) } // WriteQuoted write quoted field -func (stmt Statement) WriteQuoted(field interface{}) (err error) { +func (stmt *Statement) WriteQuoted(field interface{}) (err error) { _, err = stmt.SQL.WriteString(stmt.Quote(field)) return } @@ -107,7 +108,7 @@ func (stmt Statement) Quote(field interface{}) string { } // Write write string -func (stmt Statement) AddVar(vars ...interface{}) string { +func (stmt *Statement) AddVar(vars ...interface{}) string { var placeholders strings.Builder for idx, v := range vars { if idx > 0 { @@ -134,7 +135,7 @@ func (stmt Statement) AddVar(vars ...interface{}) string { } // AddClause add clause -func (stmt Statement) AddClause(v clause.Interface) { +func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementOptimizer); ok { optimizer.OptimizeStatement(stmt) } @@ -154,6 +155,30 @@ func (stmt Statement) AddClause(v clause.Interface) { stmt.Clauses[v.Name()] = c } +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { + if optimizer, ok := v.(StatementOptimizer); ok { + optimizer.OptimizeStatement(stmt) + } + + log.Println(v.Name()) + if c, ok := stmt.Clauses[v.Name()]; !ok { + if namer, ok := v.(clause.OverrideNameInterface); ok { + c.Name = namer.OverrideName() + } else { + c.Name = v.Name() + } + + if c.Expression != nil { + v.MergeExpression(c.Expression) + } + + c.Expression = v + stmt.Clauses[v.Name()] = c + log.Println(stmt.Clauses[v.Name()]) + } +} + // BuildCondtion build condition func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) { if sql, ok := query.(string); ok { @@ -211,7 +236,7 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con } // Build build sql with clauses names -func (stmt Statement) Build(clauses ...string) { +func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool for _, name := range clauses { @@ -221,7 +246,11 @@ func (stmt Statement) Build(clauses ...string) { } firstClauseWritten = true - c.Build(stmt) + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b.Build(c, stmt) + } else { + c.Build(stmt) + } } } // TODO handle named vars