Add more clauses

This commit is contained in:
Jinzhu 2020-02-04 08:56:15 +08:00
parent d52ee0aa44
commit 46b1c85f88
14 changed files with 160 additions and 52 deletions

View File

@ -69,16 +69,22 @@ func (cs *callbacks) Raw() *processor {
} }
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
if stmt := db.Statement; stmt != nil && stmt.Dest != nil { if stmt := db.Statement; stmt != nil {
var err error if stmt.Model == nil {
stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) stmt.Model = stmt.Dest
}
if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { if stmt.Model != nil {
var err error
stmt.Schema, err = schema.Parse(stmt.Model, db.cacheStore, db.NamingStrategy)
if err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || stmt.Table == "") {
db.AddError(err) db.AddError(err)
} else if stmt.Table == "" && stmt.Schema != nil { } else if stmt.Table == "" && stmt.Schema != nil {
stmt.Table = stmt.Schema.Table stmt.Table = stmt.Schema.Table
} }
} }
}
for _, f := range p.fns { for _, f := range p.fns {
f(db) f(db)

View File

@ -1,6 +1,8 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"github.com/jinzhu/gorm"
)
func RegisterDefaultCallbacks(db *gorm.DB) { func RegisterDefaultCallbacks(db *gorm.DB) {
enableTransaction := func(db *gorm.DB) bool { enableTransaction := func(db *gorm.DB) bool {
@ -17,7 +19,7 @@ func RegisterDefaultCallbacks(db *gorm.DB) {
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", BeforeCreate) queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register("gorm:after_query", AfterQuery)

View File

@ -22,7 +22,7 @@ func Create(db *gorm.DB) {
Table: clause.Table{Table: db.Statement.Table}, Table: clause.Table{Table: db.Statement.Table},
}) })
db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING") db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err) fmt.Println(err)
fmt.Println(result) fmt.Println(result)

View File

@ -1,8 +1,23 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"fmt"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{
Tables: []clause.Table{{Table: clause.CurrentTable}},
})
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
} }
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {

View File

@ -1,6 +1,10 @@
package gorm package gorm
import "github.com/jinzhu/gorm/clause" import (
"fmt"
"github.com/jinzhu/gorm/clause"
)
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
// // update all users's name to `hello` // // update all users's name to `hello`
@ -107,6 +111,19 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression // db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) {
case clause.OrderBy:
db.Statement.AddClause(clause.OrderByClause{
Columns: []clause.OrderBy{v},
})
default:
db.Statement.AddClause(clause.OrderByClause{
Columns: []clause.OrderBy{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
}},
})
}
return return
} }

View File

@ -11,11 +11,6 @@ type Clause struct {
Builder ClauseBuilder Builder ClauseBuilder
} }
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder interface {
Build(Clause, Builder)
}
// Build build clause // Build build clause
func (c Clause) Build(builder Builder) { func (c Clause) Build(builder Builder) {
if c.Builder != nil { if c.Builder != nil {
@ -47,25 +42,21 @@ type Interface interface {
MergeExpression(Expression) MergeExpression(Expression)
} }
// OverrideNameInterface override name interface
type OverrideNameInterface interface { type OverrideNameInterface interface {
OverrideName() string OverrideName() string
} }
// Column quote with name // ClauseBuilder clause builder, allows to custmize how to build clause
type Column struct { type ClauseBuilder interface {
Table string Build(Clause, Builder)
Name string
Alias string
Raw bool
} }
func ToColumns(value ...interface{}) []Column { // Builder builder interface
return nil type Builder interface {
} WriteByte(byte) error
Write(sql ...string) error
// Table quote with name WriteQuoted(field interface{}) error
type Table struct { AddVar(vars ...interface{}) string
Table string Quote(field interface{}) string
Alias string
Raw bool
} }

View File

@ -1,5 +1,10 @@
package clause package clause
const (
PrimaryKey string = "@@@priamry_key@@@"
CurrentTable string = "@@@table@@@"
)
// Expression expression interface // Expression expression interface
type Expression interface { type Expression interface {
Build(builder Builder) Build(builder Builder)
@ -10,13 +15,19 @@ type NegationExpressionBuilder interface {
NegationBuild(builder Builder) NegationBuild(builder Builder)
} }
// Builder builder interface // Column quote with name
type Builder interface { type Column struct {
WriteByte(byte) error Table string
Write(sql ...string) error Name string
WriteQuoted(field interface{}) error Alias string
AddVar(vars ...interface{}) string Raw bool
Quote(field interface{}) string }
// Table quote with name
type Table struct {
Table string
Alias string
Raw bool
} }
// Expr raw expression // Expr raw expression

View File

@ -20,3 +20,10 @@ func (from From) Build(builder Builder) {
builder.WriteQuoted(table) builder.WriteQuoted(table)
} }
} }
// MergeExpression merge order by clauses
func (from From) MergeExpression(expr Expression) {
if v, ok := expr.(From); ok {
from.Tables = append(v.Tables, from.Tables...)
}
}

6
clause/on_conflict.go Normal file
View File

@ -0,0 +1,6 @@
package clause
type OnConflict struct {
ON string // duplicate key
Values *Values // update c=c+1
}

View File

@ -1,4 +1,38 @@
package clause package clause
type OrderBy struct { type OrderBy struct {
Column Column
Desc bool
Reorder bool
}
type OrderByClause struct {
Columns []OrderBy
}
// Name where clause name
func (orderBy OrderByClause) Name() string {
return "ORDER BY"
}
// Build build where clause
func (orderBy OrderByClause) Build(builder Builder) {
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
builder.WriteQuoted(orderBy.Columns[i].Column)
if orderBy.Columns[i].Desc {
builder.Write(" DESC")
}
if orderBy.Columns[i].Reorder {
break
}
}
}
// MergeExpression merge order by clauses
func (orderBy OrderByClause) MergeExpression(expr Expression) {
if v, ok := expr.(OrderByClause); ok {
orderBy.Columns = append(v.Columns, orderBy.Columns...)
}
} }

View File

@ -1,15 +1,19 @@
package clause package clause
// SelectInterface select clause interface
type SelectInterface interface {
Selects() []Column
Omits() []Column
}
// Select select attrs when querying, updating, creating // Select select attrs when querying, updating, creating
type Select struct { type Select struct {
SelectColumns []Column SelectColumns []Column
OmitColumns []Column OmitColumns []Column
} }
// SelectInterface select clause interface func (s Select) Name() string {
type SelectInterface interface { return "SELECT"
Selects() []Column
Omits() []Column
} }
func (s Select) Selects() []Column { func (s Select) Selects() []Column {

View File

@ -2,6 +2,8 @@ package gorm
import ( import (
"database/sql" "database/sql"
"github.com/jinzhu/gorm/clause"
) )
// Create insert the value into database // Create insert the value into database
@ -20,9 +22,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance().Limit(1).Order(clause.OrderBy{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
tx.Statement.Dest = out tx.Statement.Dest = out
tx.Limit(1)
tx.callbacks.Query().Execute(tx) tx.callbacks.Query().Execute(tx)
return return
} }

View File

@ -63,6 +63,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
db = &DB{ db = &DB{
Config: config, Config: config,
Dialector: dialector, Dialector: dialector,
ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true, clone: true,
cacheStore: &sync.Map{}, cacheStore: &sync.Map{},
} }

View File

@ -84,18 +84,28 @@ func (stmt Statement) Quote(field interface{}) string {
switch v := field.(type) { switch v := field.(type) {
case clause.Table: case clause.Table:
str.WriteString(v.Table)
if v.Alias != "" { if v.Alias != "" {
str.WriteString(" AS ") str.WriteString(" AS ")
str.WriteString(v.Alias) str.WriteString(v.Alias)
} }
case clause.Column: case clause.Column:
if v.Table != "" { if v.Table != "" {
if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table)
} else {
str.WriteString(v.Table) str.WriteString(v.Table)
}
str.WriteByte('.') str.WriteByte('.')
} }
if v.Name == clause.PrimaryKey {
if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil {
str.WriteString(stmt.Schema.PrioritizedPrimaryField.DBName)
}
} else {
str.WriteString(v.Name) str.WriteString(v.Name)
}
if v.Alias != "" { if v.Alias != "" {
str.WriteString(" AS ") str.WriteString(" AS ")
str.WriteString(v.Alias) str.WriteString(v.Alias)