mirror of https://github.com/go-gorm/gorm.git
Work on clauses
This commit is contained in:
parent
8cb15cadde
commit
d833efe8b9
13
callbacks.go
13
callbacks.go
|
@ -1,9 +1,11 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm/logger"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/utils"
|
||||
)
|
||||
|
||||
|
@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor {
|
|||
}
|
||||
|
||||
func (p *processor) Execute(db *DB) {
|
||||
if stmt := db.Statement; stmt != nil && stmt.Dest != nil {
|
||||
var err error
|
||||
stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy)
|
||||
|
||||
if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||
db.AddError(err)
|
||||
} else if stmt.Table == "" && stmt.Schema != nil {
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
}
|
||||
|
||||
for _, f := range p.fns {
|
||||
f(db)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
// before save
|
||||
|
@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
|||
}
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
||||
|
||||
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(db *gorm.DB) {
|
||||
|
@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) {
|
|||
// after save
|
||||
// after create
|
||||
}
|
||||
|
||||
func objectToFieldsMap(stmt *gorm.Statement) {
|
||||
if stmt.Schema != nil {
|
||||
if s, ok := stmt.Clauses["SELECT"]; ok {
|
||||
s.Attrs
|
||||
}
|
||||
|
||||
if s, ok := stmt.Clauses["OMIT"]; ok {
|
||||
s.Attrs
|
||||
}
|
||||
|
||||
stmt.Schema.LookUpField(s.S)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
package callbacks
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// after find
|
||||
}
|
127
clause/clause.go
127
clause/clause.go
|
@ -51,124 +51,21 @@ type OverrideNameInterface interface {
|
|||
OverrideName() string
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Predefined Clauses
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
AndConditions AddConditions
|
||||
ORConditions []ORConditions
|
||||
builders []Expression
|
||||
}
|
||||
|
||||
func (where Where) Name() string {
|
||||
return "WHERE"
|
||||
}
|
||||
|
||||
func (where Where) Build(builder Builder) {
|
||||
var withConditions bool
|
||||
|
||||
if len(where.AndConditions) > 0 {
|
||||
withConditions = true
|
||||
where.AndConditions.Build(builder)
|
||||
}
|
||||
|
||||
if len(where.builders) > 0 {
|
||||
for _, b := range where.builders {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
}
|
||||
withConditions = true
|
||||
b.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var singleOrConditions []ORConditions
|
||||
for _, or := range where.ORConditions {
|
||||
if len(or) == 1 {
|
||||
if withConditions {
|
||||
builder.Write(" OR ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
singleOrConditions = append(singleOrConditions, or)
|
||||
}
|
||||
} else {
|
||||
withConditions = true
|
||||
builder.Write(" AND (")
|
||||
or.Build(builder)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
for _, or := range singleOrConditions {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
withConditions = true
|
||||
or.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
if !withConditions {
|
||||
builder.Write(" FALSE")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (where Where) MergeExpression(expr Expression) {
|
||||
if w, ok := expr.(Where); ok {
|
||||
where.AndConditions = append(where.AndConditions, w.AndConditions...)
|
||||
where.ORConditions = append(where.ORConditions, w.ORConditions...)
|
||||
where.builders = append(where.builders, w.builders...)
|
||||
} else {
|
||||
where.builders = append(where.builders, expr)
|
||||
}
|
||||
}
|
||||
|
||||
// Select select attrs when querying, updating, creating
|
||||
type Select struct {
|
||||
Omit bool
|
||||
}
|
||||
|
||||
// Join join clause
|
||||
type Join struct {
|
||||
// Column quote with name
|
||||
type Column struct {
|
||||
Table string
|
||||
Type string // left join books on
|
||||
ON []Expression
|
||||
builders []Expression
|
||||
Name string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
// TODO
|
||||
func ToColumns(value ...interface{}) []Column {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (join Join) MergeExpression(expr Expression) {
|
||||
if j, ok := expr.(Join); ok {
|
||||
join.builders = append(join.builders, j.builders...)
|
||||
} else {
|
||||
join.builders = append(join.builders, expr)
|
||||
}
|
||||
}
|
||||
|
||||
// GroupBy group by clause
|
||||
type GroupBy struct {
|
||||
}
|
||||
|
||||
// Having having clause
|
||||
type Having struct {
|
||||
}
|
||||
|
||||
// Order order clause
|
||||
type Order struct {
|
||||
}
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
}
|
||||
|
||||
// Offset offset clause
|
||||
type Offset struct {
|
||||
// Table quote with name
|
||||
type Table struct {
|
||||
Table string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
package clause
|
||||
|
||||
// From from clause
|
||||
type From struct {
|
||||
Tables []Table
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (From) Name() string {
|
||||
return "FROM"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (from From) Build(builder Builder) {
|
||||
for idx, table := range from.Tables {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(table)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package clause
|
||||
|
||||
// GroupBy group by clause
|
||||
type GroupBy struct {
|
||||
Having Where
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package clause
|
||||
|
||||
// Join join clause
|
||||
type Join struct {
|
||||
Table From // From
|
||||
Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
|
||||
Using []Column
|
||||
ON Where
|
||||
}
|
||||
|
||||
// TODO multiple joins
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func (join Join) MergeExpression(expr Expression) {
|
||||
// if j, ok := expr.(Join); ok {
|
||||
// join.builders = append(join.builders, j.builders...)
|
||||
// } else {
|
||||
// join.builders = append(join.builders, expr)
|
||||
// }
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package clause
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Offset uint
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
package clause
|
||||
|
||||
type OrderBy struct {
|
||||
}
|
|
@ -2,12 +2,6 @@ package clause
|
|||
|
||||
import "strings"
|
||||
|
||||
// Column quote with name
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Query Expressions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
package clause
|
||||
|
||||
// Select select attrs when querying, updating, creating
|
||||
type Select struct {
|
||||
SelectColumns []Column
|
||||
OmitColumns []Column
|
||||
}
|
||||
|
||||
// SelectInterface select clause interface
|
||||
type SelectInterface interface {
|
||||
Selects() []Column
|
||||
Omits() []Column
|
||||
}
|
||||
|
||||
func (s Select) Selects() []Column {
|
||||
return s.SelectColumns
|
||||
}
|
||||
|
||||
func (s Select) Omits() []Column {
|
||||
return s.OmitColumns
|
||||
}
|
||||
|
||||
func (s Select) Build(builder Builder) {
|
||||
if len(s.SelectColumns) > 0 {
|
||||
for idx, column := range s.SelectColumns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
} else {
|
||||
builder.WriteByte('*')
|
||||
}
|
||||
}
|
||||
|
||||
func (s Select) MergeExpression(expr Expression) {
|
||||
if v, ok := expr.(SelectInterface); ok {
|
||||
if len(s.SelectColumns) == 0 {
|
||||
s.SelectColumns = v.Selects()
|
||||
}
|
||||
if len(s.OmitColumns) == 0 {
|
||||
s.OmitColumns = v.Omits()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package clause
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
AndConditions AddConditions
|
||||
ORConditions []ORConditions
|
||||
builders []Expression
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (where Where) Name() string {
|
||||
return "WHERE"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
var withConditions bool
|
||||
|
||||
if len(where.AndConditions) > 0 {
|
||||
withConditions = true
|
||||
where.AndConditions.Build(builder)
|
||||
}
|
||||
|
||||
if len(where.builders) > 0 {
|
||||
for _, b := range where.builders {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
}
|
||||
withConditions = true
|
||||
b.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var singleOrConditions []ORConditions
|
||||
for _, or := range where.ORConditions {
|
||||
if len(or) == 1 {
|
||||
if withConditions {
|
||||
builder.Write(" OR ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
singleOrConditions = append(singleOrConditions, or)
|
||||
}
|
||||
} else {
|
||||
withConditions = true
|
||||
builder.Write(" AND (")
|
||||
or.Build(builder)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
for _, or := range singleOrConditions {
|
||||
if withConditions {
|
||||
builder.Write(" AND ")
|
||||
or.Build(builder)
|
||||
} else {
|
||||
withConditions = true
|
||||
or.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
if !withConditions {
|
||||
builder.Write(" FALSE")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MergeExpression merge where clauses
|
||||
func (where Where) MergeExpression(expr Expression) {
|
||||
if w, ok := expr.(Where); ok {
|
||||
where.AndConditions = append(where.AndConditions, w.AndConditions...)
|
||||
where.ORConditions = append(where.ORConditions, w.ORConditions...)
|
||||
where.builders = append(where.builders, w.builders...)
|
||||
} else {
|
||||
where.builders = append(where.builders, expr)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
package clause
|
||||
|
||||
type With struct {
|
||||
}
|
|
@ -1,7 +1,5 @@
|
|||
module github.com/jinzhu/gorm/dialects/mysql
|
||||
module github.com/jinzhu/gorm/dialects/sqlite
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
)
|
||||
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
|
@ -1,6 +1,7 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/callbacks"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
|
|
@ -1,15 +1,27 @@
|
|||
package sqlite_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
var (
|
||||
DB *gorm.DB
|
||||
err error
|
||||
)
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
||||
func init() {
|
||||
if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlite(t *testing.T) {
|
||||
tests.RunTestsSuit(t, DB)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
|
|||
// First find first record that match given conditions, order by primary key
|
||||
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.callbacks.Create().Execute(tx.Limit(1).Order("id"))
|
||||
tx.Statement.Dest = out
|
||||
tx.Limit(1)
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
|
|||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
// TODO
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
5
go.mod
5
go.mod
|
@ -2,4 +2,7 @@ module github.com/jinzhu/gorm
|
|||
|
||||
go 1.13
|
||||
|
||||
require github.com/jinzhu/inflection v1.0.0
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
gopkg.in/errgo.v2 v2.1.0
|
||||
)
|
||||
|
|
31
gorm.go
31
gorm.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
|
@ -12,36 +13,28 @@ import (
|
|||
// Config GORM config
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can cancel it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool // TODO
|
||||
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Initialize(*DB) error
|
||||
Migrator() Migrator
|
||||
BindVar(stmt Statement, v interface{}) string
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Dialector
|
||||
Instance
|
||||
DB CommonDB
|
||||
clone bool
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
// Session session config when create new session
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
|
@ -71,6 +64,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
Dialector: dialector,
|
||||
clone: true,
|
||||
callbacks: InitializeCallbacks(),
|
||||
cacheStore: &sync.Map{},
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
|
@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) {
|
|||
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
|
||||
}
|
||||
|
||||
func (db *DB) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
func (db *DB) Set(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
|
@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB {
|
|||
}
|
||||
|
||||
return &DB{
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
Instance: Instance{
|
||||
Context: ctx,
|
||||
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
||||
},
|
||||
Config: db.Config,
|
||||
Dialector: db.Dialector,
|
||||
DB: db.DB,
|
||||
callbacks: db.callbacks,
|
||||
cacheStore: db.cacheStore,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Initialize(*DB) error
|
||||
Migrator() Migrator
|
||||
BindVar(stmt Statement, v interface{}) string
|
||||
}
|
||||
|
||||
// CommonDB common db interface
|
||||
type CommonDB interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
|
@ -9,6 +10,9 @@ import (
|
|||
"github.com/jinzhu/gorm/logger"
|
||||
)
|
||||
|
||||
// ErrUnsupportedDataType unsupported data type
|
||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||
|
||||
type Schema struct {
|
||||
Name string
|
||||
ModelType reflect.Type
|
||||
|
@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("unsupported data %+v when parsing model", dest)
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath())
|
||||
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
|
@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.DBName == "" {
|
||||
if field.DBName == "" && field.DataType != "" {
|
||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||
}
|
||||
|
||||
|
|
|
@ -2,24 +2,16 @@ package schema_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
"github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
|
||||
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
|
||||
equalFieldNames := []string{"Name", "Table"}
|
||||
|
||||
for _, name := range equalFieldNames {
|
||||
got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
|
||||
expects := reflect.ValueOf(v).FieldByName(name).Interface()
|
||||
if !reflect.DeepEqual(got, expects) {
|
||||
t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
|
||||
}
|
||||
}
|
||||
tests.AssertEqual(t, s, v, "Name", "Table")
|
||||
|
||||
for idx, field := range primaryFields {
|
||||
var found bool
|
||||
|
@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
|||
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
||||
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
||||
} else {
|
||||
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
|
||||
|
||||
for _, name := range equalFieldNames {
|
||||
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
|
||||
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
|
||||
if !reflect.DeepEqual(got, expects) {
|
||||
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
|
||||
}
|
||||
}
|
||||
tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
||||
|
||||
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
||||
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
||||
|
|
31
statement.go
31
statement.go
|
@ -10,6 +10,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/jinzhu/gorm/clause"
|
||||
"github.com/jinzhu/gorm/schema"
|
||||
)
|
||||
|
||||
// Instance db instance
|
||||
|
@ -37,6 +38,7 @@ type Statement struct {
|
|||
Clauses map[string]clause.Clause
|
||||
Settings sync.Map
|
||||
DB *DB
|
||||
Schema *schema.Schema
|
||||
|
||||
// SQL Builder
|
||||
SQL strings.Builder
|
||||
|
@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
|
|||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt Statement) Quote(field interface{}) (str string) {
|
||||
// FIXME
|
||||
return fmt.Sprint(field)
|
||||
func (stmt Statement) Quote(field interface{}) string {
|
||||
var str strings.Builder
|
||||
|
||||
switch v := field.(type) {
|
||||
case clause.Table:
|
||||
str.WriteString(v.Table)
|
||||
if v.Alias != "" {
|
||||
str.WriteString(" AS ")
|
||||
str.WriteString(v.Alias)
|
||||
}
|
||||
case clause.Column:
|
||||
if v.Table != "" {
|
||||
str.WriteString(v.Table)
|
||||
str.WriteByte('.')
|
||||
}
|
||||
|
||||
str.WriteString(v.Name)
|
||||
if v.Alias != "" {
|
||||
str.WriteString(" AS ")
|
||||
str.WriteString(v.Alias)
|
||||
}
|
||||
default:
|
||||
fmt.Sprint(field)
|
||||
}
|
||||
|
||||
return str.String()
|
||||
}
|
||||
|
||||
// Write write string
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
package tests
|
|
@ -0,0 +1,42 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func Now() *time.Time {
|
||||
now := time.Now()
|
||||
return &now
|
||||
}
|
||||
|
||||
func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
||||
TestCreate(t, db)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
var user = User{
|
||||
Name: "create",
|
||||
Age: 18,
|
||||
Birthday: Now(),
|
||||
}
|
||||
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
t.Errorf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
||||
}
|
||||
|
||||
var newUser User
|
||||
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Errorf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||
for _, name := range names {
|
||||
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||
|
||||
if !reflect.DeepEqual(got, expects) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Errorf("expects: %v, got %v", expects, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue