Work on clauses

This commit is contained in:
Jinzhu 2020-02-02 14:40:44 +08:00
parent 8cb15cadde
commit d833efe8b9
27 changed files with 413 additions and 185 deletions

View File

@ -1,9 +1,11 @@
package gorm package gorm
import ( import (
"errors"
"fmt" "fmt"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/utils" "github.com/jinzhu/gorm/utils"
) )
@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor {
} }
func (p *processor) Execute(db *DB) { func (p *processor) Execute(db *DB) {
if stmt := db.Statement; stmt != nil && stmt.Dest != nil {
var err error
stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy)
if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) {
db.AddError(err)
} else if stmt.Table == "" && stmt.Schema != nil {
stmt.Table = stmt.Schema.Table
}
}
for _, f := range p.fns { for _, f := range p.fns {
f(db) f(db)
} }

View File

@ -1,6 +1,10 @@
package callbacks package callbacks
import "github.com/jinzhu/gorm" import (
"fmt"
"github.com/jinzhu/gorm"
)
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
// before save // before save
@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
} }
func Create(db *gorm.DB) { func Create(db *gorm.DB) {
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
} }
func SaveAfterAssociations(db *gorm.DB) { func SaveAfterAssociations(db *gorm.DB) {
@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) {
// after save // after save
// after create // after create
} }
func objectToFieldsMap(stmt *gorm.Statement) {
if stmt.Schema != nil {
if s, ok := stmt.Clauses["SELECT"]; ok {
s.Attrs
}
if s, ok := stmt.Clauses["OMIT"]; ok {
s.Attrs
}
stmt.Schema.LookUpField(s.S)
}
}

13
callbacks/query.go Normal file
View File

@ -0,0 +1,13 @@
package callbacks
import "github.com/jinzhu/gorm"
func Query(db *gorm.DB) {
}
func Preload(db *gorm.DB) {
}
func AfterQuery(db *gorm.DB) {
// after find
}

View File

@ -51,124 +51,21 @@ type OverrideNameInterface interface {
OverrideName() string OverrideName() string
} }
//////////////////////////////////////////////////////////////////////////////// // Column quote with name
// Predefined Clauses type Column struct {
//////////////////////////////////////////////////////////////////////////////// Table string
Name string
// Where where clause Alias string
type Where struct { Raw bool
AndConditions AddConditions
ORConditions []ORConditions
builders []Expression
} }
func (where Where) Name() string { func ToColumns(value ...interface{}) []Column {
return "WHERE" return nil
} }
func (where Where) Build(builder Builder) { // Table quote with name
var withConditions bool type Table struct {
Table string
if len(where.AndConditions) > 0 { Alias string
withConditions = true Raw bool
where.AndConditions.Build(builder)
}
if len(where.builders) > 0 {
for _, b := range where.builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)
}
}
// Select select attrs when querying, updating, creating
type Select struct {
Omit bool
}
// Join join clause
type Join struct {
Table string
Type string // left join books on
ON []Expression
builders []Expression
}
func (join Join) Build(builder Builder) {
// TODO
}
func (join Join) MergeExpression(expr Expression) {
if j, ok := expr.(Join); ok {
join.builders = append(join.builders, j.builders...)
} else {
join.builders = append(join.builders, expr)
}
}
// GroupBy group by clause
type GroupBy struct {
}
// Having having clause
type Having struct {
}
// Order order clause
type Order struct {
}
// Limit limit clause
type Limit struct {
}
// Offset offset clause
type Offset struct {
} }

22
clause/from.go Normal file
View File

@ -0,0 +1,22 @@
package clause
// From from clause
type From struct {
Tables []Table
}
// Name from clause name
func (From) Name() string {
return "FROM"
}
// Build build from clause
func (from From) Build(builder Builder) {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(table)
}
}

6
clause/group_by.go Normal file
View File

@ -0,0 +1,6 @@
package clause
// GroupBy group by clause
type GroupBy struct {
Having Where
}

23
clause/join.go Normal file
View File

@ -0,0 +1,23 @@
package clause
// Join join clause
type Join struct {
Table From // From
Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
Using []Column
ON Where
}
// TODO multiple joins
func (join Join) Build(builder Builder) {
// TODO
}
func (join Join) MergeExpression(expr Expression) {
// if j, ok := expr.(Join); ok {
// join.builders = append(join.builders, j.builders...)
// } else {
// join.builders = append(join.builders, expr)
// }
}

6
clause/limit.go Normal file
View File

@ -0,0 +1,6 @@
package clause
// Limit limit clause
type Limit struct {
Offset uint
}

4
clause/order_by.go Normal file
View File

@ -0,0 +1,4 @@
package clause
type OrderBy struct {
}

View File

@ -2,12 +2,6 @@ package clause
import "strings" import "strings"
// Column quote with name
type Column struct {
Table string
Name string
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Query Expressions // Query Expressions
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

45
clause/select.go Normal file
View File

@ -0,0 +1,45 @@
package clause
// Select select attrs when querying, updating, creating
type Select struct {
SelectColumns []Column
OmitColumns []Column
}
// SelectInterface select clause interface
type SelectInterface interface {
Selects() []Column
Omits() []Column
}
func (s Select) Selects() []Column {
return s.SelectColumns
}
func (s Select) Omits() []Column {
return s.OmitColumns
}
func (s Select) Build(builder Builder) {
if len(s.SelectColumns) > 0 {
for idx, column := range s.SelectColumns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
func (s Select) MergeExpression(expr Expression) {
if v, ok := expr.(SelectInterface); ok {
if len(s.SelectColumns) == 0 {
s.SelectColumns = v.Selects()
}
if len(s.OmitColumns) == 0 {
s.OmitColumns = v.Omits()
}
}
}

77
clause/where.go Normal file
View File

@ -0,0 +1,77 @@
package clause
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
builders []Expression
}
// Name where clause name
func (where Where) Name() string {
return "WHERE"
}
// Build build where clause
func (where Where) Build(builder Builder) {
var withConditions bool
if len(where.AndConditions) > 0 {
withConditions = true
where.AndConditions.Build(builder)
}
if len(where.builders) > 0 {
for _, b := range where.builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
// MergeExpression merge where clauses
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)
}
}

4
clause/with.go Normal file
View File

@ -0,0 +1,4 @@
package clause
type With struct {
}

View File

@ -1,7 +1,5 @@
module github.com/jinzhu/gorm/dialects/mysql module github.com/jinzhu/gorm/dialects/sqlite
go 1.13 go 1.13
require ( require github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/mattn/go-sqlite3 v2.0.3+incompatible
)

2
dialects/sqlite/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=

View File

@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks" "github.com/jinzhu/gorm/callbacks"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )

View File

@ -1,15 +1,27 @@
package sqlite_test package sqlite_test
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/sqlite"
"github.com/jinzhu/gorm/tests"
) )
var DB *gorm.DB var (
DB *gorm.DB
err error
)
func TestOpen(t *testing.T) { func init() {
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
}
}
func TestSqlite(t *testing.T) {
tests.RunTestsSuit(t, DB)
} }

View File

@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key // First find first record that match given conditions, order by primary key
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) tx.Statement.Dest = out
tx.Limit(1)
tx.callbacks.Query().Execute(tx)
return return
} }
@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
} }
func (db *DB) Row() *sql.Row { func (db *DB) Row() *sql.Row {
// TODO
return nil return nil
} }
func (db *DB) Rows() (*sql.Rows, error) { func (db *DB) Rows() (*sql.Rows, error) {
// TODO
return nil, nil return nil, nil
} }

5
go.mod
View File

@ -2,4 +2,7 @@ module github.com/jinzhu/gorm
go 1.13 go 1.13
require github.com/jinzhu/inflection v1.0.0 require (
github.com/jinzhu/inflection v1.0.0
gopkg.in/errgo.v2 v2.1.0
)

43
gorm.go
View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
@ -12,36 +13,28 @@ import (
// Config GORM config // Config GORM config
type Config struct { type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can cancel it by setting `SkipDefaultTransaction` to true // You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool // TODO SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
// Logger // Logger
Logger logger.Interface Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp // NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time NowFunc func() time.Time
} }
// Dialector GORM database dialector
type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
}
// DB GORM DB definition // DB GORM DB definition
type DB struct { type DB struct {
*Config *Config
Dialector Dialector
Instance Instance
clone bool DB CommonDB
callbacks *callbacks clone bool
callbacks *callbacks
cacheStore *sync.Map
} }
// Session session config when create new session // Session session config when create session with Session() method
type Session struct { type Session struct {
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
@ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
} }
db = &DB{ db = &DB{
Config: config, Config: config,
Dialector: dialector, Dialector: dialector,
clone: true, clone: true,
callbacks: InitializeCallbacks(), callbacks: InitializeCallbacks(),
cacheStore: &sync.Map{},
} }
if dialector != nil { if dialector != nil {
@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
} }
func (db *DB) Close() error {
return nil
}
// Set store value with key into current db instance's context // Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB { func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance() tx := db.getInstance()
@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB {
} }
return &DB{ return &DB{
Config: db.Config,
Dialector: db.Dialector,
Instance: Instance{ Instance: Instance{
Context: ctx, Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
}, },
Config: db.Config,
Dialector: db.Dialector,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
} }
} }

21
interfaces.go Normal file
View File

@ -0,0 +1,21 @@
package gorm
import (
"context"
"database/sql"
)
// Dialector GORM database dialector
type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
}
// CommonDB common db interface
type CommonDB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"errors"
"fmt" "fmt"
"go/ast" "go/ast"
"reflect" "reflect"
@ -9,6 +10,9 @@ import (
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
) )
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
type Schema struct { type Schema struct {
Name string Name string
ModelType reflect.Type ModelType reflect.Type
@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" { if modelType.PkgPath() == "" {
return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
if v, ok := cacheStore.Load(modelType); ok { if v, ok := cacheStore.Load(modelType); ok {
@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
} }
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DBName == "" { if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name) field.DBName = namer.ColumnName(schema.Table, field.Name)
} }

View File

@ -2,24 +2,16 @@ package schema_test
import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"
"testing" "testing"
"github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
) )
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
t.Run("CheckSchema/"+s.Name, func(t *testing.T) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
equalFieldNames := []string{"Name", "Table"} tests.AssertEqual(t, s, v, "Name", "Table")
for _, name := range equalFieldNames {
got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(v).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
}
}
for idx, field := range primaryFields { for idx, field := range primaryFields {
var found bool var found bool
@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
if parsedField, ok := s.FieldsByName[f.Name]; !ok { if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name) t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else { } else {
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
for _, name := range equalFieldNames {
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
}
}
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
) )
// Instance db instance // Instance db instance
@ -37,6 +38,7 @@ type Statement struct {
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
Settings sync.Map Settings sync.Map
DB *DB DB *DB
Schema *schema.Schema
// SQL Builder // SQL Builder
SQL strings.Builder SQL strings.Builder
@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
} }
// Quote returns quoted value // Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) { func (stmt Statement) Quote(field interface{}) string {
// FIXME var str strings.Builder
return fmt.Sprint(field)
switch v := field.(type) {
case clause.Table:
str.WriteString(v.Table)
if v.Alias != "" {
str.WriteString(" AS ")
str.WriteString(v.Alias)
}
case clause.Column:
if v.Table != "" {
str.WriteString(v.Table)
str.WriteByte('.')
}
str.WriteString(v.Name)
if v.Alias != "" {
str.WriteString(" AS ")
str.WriteString(v.Alias)
}
default:
fmt.Sprint(field)
}
return str.String()
} }
// Write write string // Write write string

View File

@ -1 +0,0 @@
package tests

42
tests/tests.go Normal file
View File

@ -0,0 +1,42 @@
package tests
import (
"testing"
"time"
"github.com/jinzhu/gorm"
)
func Now() *time.Time {
now := time.Now()
return &now
}
func RunTestsSuit(t *testing.T, db *gorm.DB) {
TestCreate(t, db)
}
func TestCreate(t *testing.T, db *gorm.DB) {
t.Run("Create", func(t *testing.T) {
var user = User{
Name: "create",
Age: 18,
Birthday: Now(),
}
if err := db.Create(&user).Error; err != nil {
t.Errorf("errors happened when create: %v", err)
}
if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}
var newUser User
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Errorf("errors happened when query: %v", err)
} else {
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
}
})
}

19
tests/utils.go Normal file
View File

@ -0,0 +1,19 @@
package tests
import (
"reflect"
"testing"
)
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Run(name, func(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, got)
})
}
}
}