Work on clauses

This commit is contained in:
Jinzhu 2020-02-02 14:40:44 +08:00
parent 8cb15cadde
commit d833efe8b9
27 changed files with 413 additions and 185 deletions

View File

@ -1,9 +1,11 @@
package gorm
import (
"errors"
"fmt"
"github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/utils"
)
@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor {
}
func (p *processor) Execute(db *DB) {
if stmt := db.Statement; stmt != nil && stmt.Dest != nil {
var err error
stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy)
if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) {
db.AddError(err)
} else if stmt.Table == "" && stmt.Schema != nil {
stmt.Table = stmt.Schema.Table
}
}
for _, f := range p.fns {
f(db)
}

View File

@ -1,6 +1,10 @@
package callbacks
import "github.com/jinzhu/gorm"
import (
"fmt"
"github.com/jinzhu/gorm"
)
func BeforeCreate(db *gorm.DB) {
// before save
@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
}
func Create(db *gorm.DB) {
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
}
func SaveAfterAssociations(db *gorm.DB) {
@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) {
// after save
// after create
}
func objectToFieldsMap(stmt *gorm.Statement) {
if stmt.Schema != nil {
if s, ok := stmt.Clauses["SELECT"]; ok {
s.Attrs
}
if s, ok := stmt.Clauses["OMIT"]; ok {
s.Attrs
}
stmt.Schema.LookUpField(s.S)
}
}

13
callbacks/query.go Normal file
View File

@ -0,0 +1,13 @@
package callbacks
import "github.com/jinzhu/gorm"
func Query(db *gorm.DB) {
}
func Preload(db *gorm.DB) {
}
func AfterQuery(db *gorm.DB) {
// after find
}

View File

@ -51,124 +51,21 @@ type OverrideNameInterface interface {
OverrideName() string
}
////////////////////////////////////////////////////////////////////////////////
// Predefined Clauses
////////////////////////////////////////////////////////////////////////////////
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
builders []Expression
// Column quote with name
type Column struct {
Table string
Name string
Alias string
Raw bool
}
func (where Where) Name() string {
return "WHERE"
func ToColumns(value ...interface{}) []Column {
return nil
}
func (where Where) Build(builder Builder) {
var withConditions bool
if len(where.AndConditions) > 0 {
withConditions = true
where.AndConditions.Build(builder)
}
if len(where.builders) > 0 {
for _, b := range where.builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)
}
}
// Select select attrs when querying, updating, creating
type Select struct {
Omit bool
}
// Join join clause
type Join struct {
Table string
Type string // left join books on
ON []Expression
builders []Expression
}
func (join Join) Build(builder Builder) {
// TODO
}
func (join Join) MergeExpression(expr Expression) {
if j, ok := expr.(Join); ok {
join.builders = append(join.builders, j.builders...)
} else {
join.builders = append(join.builders, expr)
}
}
// GroupBy group by clause
type GroupBy struct {
}
// Having having clause
type Having struct {
}
// Order order clause
type Order struct {
}
// Limit limit clause
type Limit struct {
}
// Offset offset clause
type Offset struct {
// Table quote with name
type Table struct {
Table string
Alias string
Raw bool
}

22
clause/from.go Normal file
View File

@ -0,0 +1,22 @@
package clause
// From from clause
type From struct {
Tables []Table
}
// Name from clause name
func (From) Name() string {
return "FROM"
}
// Build build from clause
func (from From) Build(builder Builder) {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(table)
}
}

6
clause/group_by.go Normal file
View File

@ -0,0 +1,6 @@
package clause
// GroupBy group by clause
type GroupBy struct {
Having Where
}

23
clause/join.go Normal file
View File

@ -0,0 +1,23 @@
package clause
// Join join clause
type Join struct {
Table From // From
Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
Using []Column
ON Where
}
// TODO multiple joins
func (join Join) Build(builder Builder) {
// TODO
}
func (join Join) MergeExpression(expr Expression) {
// if j, ok := expr.(Join); ok {
// join.builders = append(join.builders, j.builders...)
// } else {
// join.builders = append(join.builders, expr)
// }
}

6
clause/limit.go Normal file
View File

@ -0,0 +1,6 @@
package clause
// Limit limit clause
type Limit struct {
Offset uint
}

4
clause/order_by.go Normal file
View File

@ -0,0 +1,4 @@
package clause
type OrderBy struct {
}

View File

@ -2,12 +2,6 @@ package clause
import "strings"
// Column quote with name
type Column struct {
Table string
Name string
}
////////////////////////////////////////////////////////////////////////////////
// Query Expressions
////////////////////////////////////////////////////////////////////////////////

45
clause/select.go Normal file
View File

@ -0,0 +1,45 @@
package clause
// Select select attrs when querying, updating, creating
type Select struct {
SelectColumns []Column
OmitColumns []Column
}
// SelectInterface select clause interface
type SelectInterface interface {
Selects() []Column
Omits() []Column
}
func (s Select) Selects() []Column {
return s.SelectColumns
}
func (s Select) Omits() []Column {
return s.OmitColumns
}
func (s Select) Build(builder Builder) {
if len(s.SelectColumns) > 0 {
for idx, column := range s.SelectColumns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
func (s Select) MergeExpression(expr Expression) {
if v, ok := expr.(SelectInterface); ok {
if len(s.SelectColumns) == 0 {
s.SelectColumns = v.Selects()
}
if len(s.OmitColumns) == 0 {
s.OmitColumns = v.Omits()
}
}
}

77
clause/where.go Normal file
View File

@ -0,0 +1,77 @@
package clause
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
builders []Expression
}
// Name where clause name
func (where Where) Name() string {
return "WHERE"
}
// Build build where clause
func (where Where) Build(builder Builder) {
var withConditions bool
if len(where.AndConditions) > 0 {
withConditions = true
where.AndConditions.Build(builder)
}
if len(where.builders) > 0 {
for _, b := range where.builders {
if withConditions {
builder.Write(" AND ")
}
withConditions = true
b.Build(builder)
}
}
var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
or.Build(builder)
} else {
singleOrConditions = append(singleOrConditions, or)
}
} else {
withConditions = true
builder.Write(" AND (")
or.Build(builder)
builder.WriteByte(')')
}
}
for _, or := range singleOrConditions {
if withConditions {
builder.Write(" AND ")
or.Build(builder)
} else {
withConditions = true
or.Build(builder)
}
}
if !withConditions {
builder.Write(" FALSE")
}
return
}
// MergeExpression merge where clauses
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)
}
}

4
clause/with.go Normal file
View File

@ -0,0 +1,4 @@
package clause
type With struct {
}

View File

@ -1,7 +1,5 @@
module github.com/jinzhu/gorm/dialects/mysql
module github.com/jinzhu/gorm/dialects/sqlite
go 1.13
require (
github.com/mattn/go-sqlite3 v2.0.3+incompatible
)
require github.com/mattn/go-sqlite3 v2.0.3+incompatible

2
dialects/sqlite/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=

View File

@ -1,6 +1,7 @@
package sqlite
import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
_ "github.com/mattn/go-sqlite3"
)

View File

@ -1,15 +1,27 @@
package sqlite_test
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/sqlite"
"github.com/jinzhu/gorm/tests"
)
var DB *gorm.DB
var (
DB *gorm.DB
err error
)
func TestOpen(t *testing.T) {
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
func init() {
if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil {
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
}
}
func TestSqlite(t *testing.T) {
tests.RunTestsSuit(t, DB)
}

View File

@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.callbacks.Create().Execute(tx.Limit(1).Order("id"))
tx.Statement.Dest = out
tx.Limit(1)
tx.callbacks.Query().Execute(tx)
return
}
@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
}
func (db *DB) Row() *sql.Row {
// TODO
return nil
}
func (db *DB) Rows() (*sql.Rows, error) {
// TODO
return nil, nil
}

5
go.mod
View File

@ -2,4 +2,7 @@ module github.com/jinzhu/gorm
go 1.13
require github.com/jinzhu/inflection v1.0.0
require (
github.com/jinzhu/inflection v1.0.0
gopkg.in/errgo.v2 v2.1.0
)

43
gorm.go
View File

@ -2,6 +2,7 @@ package gorm
import (
"context"
"sync"
"time"
"github.com/jinzhu/gorm/clause"
@ -12,36 +13,28 @@ import (
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can cancel it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool // TODO
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
}
// Dialector GORM database dialector
type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
}
// DB GORM DB definition
type DB struct {
*Config
Dialector
Instance
clone bool
callbacks *callbacks
DB CommonDB
clone bool
callbacks *callbacks
cacheStore *sync.Map
}
// Session session config when create new session
// Session session config when create session with Session() method
type Session struct {
Context context.Context
Logger logger.Interface
@ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
}
db = &DB{
Config: config,
Dialector: dialector,
clone: true,
callbacks: InitializeCallbacks(),
Config: config,
Dialector: dialector,
clone: true,
callbacks: InitializeCallbacks(),
cacheStore: &sync.Map{},
}
if dialector != nil {
@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
}
func (db *DB) Close() error {
return nil
}
// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance()
@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB {
}
return &DB{
Config: db.Config,
Dialector: db.Dialector,
Instance: Instance{
Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
},
Config: db.Config,
Dialector: db.Dialector,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
}
}

21
interfaces.go Normal file
View File

@ -0,0 +1,21 @@
package gorm
import (
"context"
"database/sql"
)
// Dialector GORM database dialector
type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
}
// CommonDB common db interface
type CommonDB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

View File

@ -1,6 +1,7 @@
package schema
import (
"errors"
"fmt"
"go/ast"
"reflect"
@ -9,6 +10,9 @@ import (
"github.com/jinzhu/gorm/logger"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
type Schema struct {
Name string
ModelType reflect.Type
@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("unsupported data %+v when parsing model", dest)
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath())
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
}
for _, field := range schema.Fields {
if field.DBName == "" {
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}

View File

@ -2,24 +2,16 @@ package schema_test
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
)
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
equalFieldNames := []string{"Name", "Table"}
for _, name := range equalFieldNames {
got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(v).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
}
}
tests.AssertEqual(t, s, v, "Name", "Table")
for idx, field := range primaryFields {
var found bool
@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
} else {
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
for _, name := range equalFieldNames {
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
}
}
tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)

View File

@ -10,6 +10,7 @@ import (
"sync"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
)
// Instance db instance
@ -37,6 +38,7 @@ type Statement struct {
Clauses map[string]clause.Clause
Settings sync.Map
DB *DB
Schema *schema.Schema
// SQL Builder
SQL strings.Builder
@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
}
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) (str string) {
// FIXME
return fmt.Sprint(field)
func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder
switch v := field.(type) {
case clause.Table:
str.WriteString(v.Table)
if v.Alias != "" {
str.WriteString(" AS ")
str.WriteString(v.Alias)
}
case clause.Column:
if v.Table != "" {
str.WriteString(v.Table)
str.WriteByte('.')
}
str.WriteString(v.Name)
if v.Alias != "" {
str.WriteString(" AS ")
str.WriteString(v.Alias)
}
default:
fmt.Sprint(field)
}
return str.String()
}
// Write write string

View File

@ -1 +0,0 @@
package tests

42
tests/tests.go Normal file
View File

@ -0,0 +1,42 @@
package tests
import (
"testing"
"time"
"github.com/jinzhu/gorm"
)
func Now() *time.Time {
now := time.Now()
return &now
}
func RunTestsSuit(t *testing.T, db *gorm.DB) {
TestCreate(t, db)
}
func TestCreate(t *testing.T, db *gorm.DB) {
t.Run("Create", func(t *testing.T) {
var user = User{
Name: "create",
Age: 18,
Birthday: Now(),
}
if err := db.Create(&user).Error; err != nil {
t.Errorf("errors happened when create: %v", err)
}
if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}
var newUser User
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Errorf("errors happened when query: %v", err)
} else {
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
}
})
}

19
tests/utils.go Normal file
View File

@ -0,0 +1,19 @@
package tests
import (
"reflect"
"testing"
)
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
for _, name := range names {
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
if !reflect.DeepEqual(got, expects) {
t.Run(name, func(t *testing.T) {
t.Errorf("expects: %v, got %v", expects, got)
})
}
}
}