forked from mirror/gorm
Work on clauses
This commit is contained in:
parent
8cb15cadde
commit
d833efe8b9
13
callbacks.go
13
callbacks.go
|
@ -1,9 +1,11 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/logger"
|
"github.com/jinzhu/gorm/logger"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
"github.com/jinzhu/gorm/utils"
|
"github.com/jinzhu/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Execute(db *DB) {
|
func (p *processor) Execute(db *DB) {
|
||||||
|
if stmt := db.Statement; stmt != nil && stmt.Dest != nil {
|
||||||
|
var err error
|
||||||
|
stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||||
|
db.AddError(err)
|
||||||
|
} else if stmt.Table == "" && stmt.Schema != nil {
|
||||||
|
stmt.Table = stmt.Schema.Table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, f := range p.fns {
|
for _, f := range p.fns {
|
||||||
f(db)
|
f(db)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import "github.com/jinzhu/gorm"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
// before save
|
// before save
|
||||||
|
@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(db *gorm.DB) {
|
func Create(db *gorm.DB) {
|
||||||
|
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
|
||||||
|
|
||||||
|
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(db *gorm.DB) {
|
func SaveAfterAssociations(db *gorm.DB) {
|
||||||
|
@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) {
|
||||||
// after save
|
// after save
|
||||||
// after create
|
// after create
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func objectToFieldsMap(stmt *gorm.Statement) {
|
||||||
|
if stmt.Schema != nil {
|
||||||
|
if s, ok := stmt.Clauses["SELECT"]; ok {
|
||||||
|
s.Attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, ok := stmt.Clauses["OMIT"]; ok {
|
||||||
|
s.Attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt.Schema.LookUpField(s.S)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
package callbacks
|
||||||
|
|
||||||
|
import "github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
func Query(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func Preload(db *gorm.DB) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func AfterQuery(db *gorm.DB) {
|
||||||
|
// after find
|
||||||
|
}
|
127
clause/clause.go
127
clause/clause.go
|
@ -51,124 +51,21 @@ type OverrideNameInterface interface {
|
||||||
OverrideName() string
|
OverrideName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
// Column quote with name
|
||||||
// Predefined Clauses
|
type Column struct {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Where where clause
|
|
||||||
type Where struct {
|
|
||||||
AndConditions AddConditions
|
|
||||||
ORConditions []ORConditions
|
|
||||||
builders []Expression
|
|
||||||
}
|
|
||||||
|
|
||||||
func (where Where) Name() string {
|
|
||||||
return "WHERE"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (where Where) Build(builder Builder) {
|
|
||||||
var withConditions bool
|
|
||||||
|
|
||||||
if len(where.AndConditions) > 0 {
|
|
||||||
withConditions = true
|
|
||||||
where.AndConditions.Build(builder)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(where.builders) > 0 {
|
|
||||||
for _, b := range where.builders {
|
|
||||||
if withConditions {
|
|
||||||
builder.Write(" AND ")
|
|
||||||
}
|
|
||||||
withConditions = true
|
|
||||||
b.Build(builder)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var singleOrConditions []ORConditions
|
|
||||||
for _, or := range where.ORConditions {
|
|
||||||
if len(or) == 1 {
|
|
||||||
if withConditions {
|
|
||||||
builder.Write(" OR ")
|
|
||||||
or.Build(builder)
|
|
||||||
} else {
|
|
||||||
singleOrConditions = append(singleOrConditions, or)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
withConditions = true
|
|
||||||
builder.Write(" AND (")
|
|
||||||
or.Build(builder)
|
|
||||||
builder.WriteByte(')')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, or := range singleOrConditions {
|
|
||||||
if withConditions {
|
|
||||||
builder.Write(" AND ")
|
|
||||||
or.Build(builder)
|
|
||||||
} else {
|
|
||||||
withConditions = true
|
|
||||||
or.Build(builder)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !withConditions {
|
|
||||||
builder.Write(" FALSE")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (where Where) MergeExpression(expr Expression) {
|
|
||||||
if w, ok := expr.(Where); ok {
|
|
||||||
where.AndConditions = append(where.AndConditions, w.AndConditions...)
|
|
||||||
where.ORConditions = append(where.ORConditions, w.ORConditions...)
|
|
||||||
where.builders = append(where.builders, w.builders...)
|
|
||||||
} else {
|
|
||||||
where.builders = append(where.builders, expr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select select attrs when querying, updating, creating
|
|
||||||
type Select struct {
|
|
||||||
Omit bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join join clause
|
|
||||||
type Join struct {
|
|
||||||
Table string
|
Table string
|
||||||
Type string // left join books on
|
Name string
|
||||||
ON []Expression
|
Alias string
|
||||||
builders []Expression
|
Raw bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (join Join) Build(builder Builder) {
|
func ToColumns(value ...interface{}) []Column {
|
||||||
// TODO
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (join Join) MergeExpression(expr Expression) {
|
// Table quote with name
|
||||||
if j, ok := expr.(Join); ok {
|
type Table struct {
|
||||||
join.builders = append(join.builders, j.builders...)
|
Table string
|
||||||
} else {
|
Alias string
|
||||||
join.builders = append(join.builders, expr)
|
Raw bool
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupBy group by clause
|
|
||||||
type GroupBy struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Having having clause
|
|
||||||
type Having struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Order order clause
|
|
||||||
type Order struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit limit clause
|
|
||||||
type Limit struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset offset clause
|
|
||||||
type Offset struct {
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// From from clause
|
||||||
|
type From struct {
|
||||||
|
Tables []Table
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name from clause name
|
||||||
|
func (From) Name() string {
|
||||||
|
return "FROM"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build build from clause
|
||||||
|
func (from From) Build(builder Builder) {
|
||||||
|
for idx, table := range from.Tables {
|
||||||
|
if idx > 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteQuoted(table)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// GroupBy group by clause
|
||||||
|
type GroupBy struct {
|
||||||
|
Having Where
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// Join join clause
|
||||||
|
type Join struct {
|
||||||
|
Table From // From
|
||||||
|
Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN
|
||||||
|
Using []Column
|
||||||
|
ON Where
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO multiple joins
|
||||||
|
|
||||||
|
func (join Join) Build(builder Builder) {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
func (join Join) MergeExpression(expr Expression) {
|
||||||
|
// if j, ok := expr.(Join); ok {
|
||||||
|
// join.builders = append(join.builders, j.builders...)
|
||||||
|
// } else {
|
||||||
|
// join.builders = append(join.builders, expr)
|
||||||
|
// }
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// Limit limit clause
|
||||||
|
type Limit struct {
|
||||||
|
Offset uint
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
type OrderBy struct {
|
||||||
|
}
|
|
@ -2,12 +2,6 @@ package clause
|
||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|
||||||
// Column quote with name
|
|
||||||
type Column struct {
|
|
||||||
Table string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Query Expressions
|
// Query Expressions
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// Select select attrs when querying, updating, creating
|
||||||
|
type Select struct {
|
||||||
|
SelectColumns []Column
|
||||||
|
OmitColumns []Column
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectInterface select clause interface
|
||||||
|
type SelectInterface interface {
|
||||||
|
Selects() []Column
|
||||||
|
Omits() []Column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Select) Selects() []Column {
|
||||||
|
return s.SelectColumns
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Select) Omits() []Column {
|
||||||
|
return s.OmitColumns
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Select) Build(builder Builder) {
|
||||||
|
if len(s.SelectColumns) > 0 {
|
||||||
|
for idx, column := range s.SelectColumns {
|
||||||
|
if idx > 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
builder.WriteQuoted(column)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
builder.WriteByte('*')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Select) MergeExpression(expr Expression) {
|
||||||
|
if v, ok := expr.(SelectInterface); ok {
|
||||||
|
if len(s.SelectColumns) == 0 {
|
||||||
|
s.SelectColumns = v.Selects()
|
||||||
|
}
|
||||||
|
if len(s.OmitColumns) == 0 {
|
||||||
|
s.OmitColumns = v.Omits()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
// Where where clause
|
||||||
|
type Where struct {
|
||||||
|
AndConditions AddConditions
|
||||||
|
ORConditions []ORConditions
|
||||||
|
builders []Expression
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name where clause name
|
||||||
|
func (where Where) Name() string {
|
||||||
|
return "WHERE"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build build where clause
|
||||||
|
func (where Where) Build(builder Builder) {
|
||||||
|
var withConditions bool
|
||||||
|
|
||||||
|
if len(where.AndConditions) > 0 {
|
||||||
|
withConditions = true
|
||||||
|
where.AndConditions.Build(builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(where.builders) > 0 {
|
||||||
|
for _, b := range where.builders {
|
||||||
|
if withConditions {
|
||||||
|
builder.Write(" AND ")
|
||||||
|
}
|
||||||
|
withConditions = true
|
||||||
|
b.Build(builder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var singleOrConditions []ORConditions
|
||||||
|
for _, or := range where.ORConditions {
|
||||||
|
if len(or) == 1 {
|
||||||
|
if withConditions {
|
||||||
|
builder.Write(" OR ")
|
||||||
|
or.Build(builder)
|
||||||
|
} else {
|
||||||
|
singleOrConditions = append(singleOrConditions, or)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
withConditions = true
|
||||||
|
builder.Write(" AND (")
|
||||||
|
or.Build(builder)
|
||||||
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, or := range singleOrConditions {
|
||||||
|
if withConditions {
|
||||||
|
builder.Write(" AND ")
|
||||||
|
or.Build(builder)
|
||||||
|
} else {
|
||||||
|
withConditions = true
|
||||||
|
or.Build(builder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !withConditions {
|
||||||
|
builder.Write(" FALSE")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeExpression merge where clauses
|
||||||
|
func (where Where) MergeExpression(expr Expression) {
|
||||||
|
if w, ok := expr.(Where); ok {
|
||||||
|
where.AndConditions = append(where.AndConditions, w.AndConditions...)
|
||||||
|
where.ORConditions = append(where.ORConditions, w.ORConditions...)
|
||||||
|
where.builders = append(where.builders, w.builders...)
|
||||||
|
} else {
|
||||||
|
where.builders = append(where.builders, expr)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
package clause
|
||||||
|
|
||||||
|
type With struct {
|
||||||
|
}
|
|
@ -1,7 +1,5 @@
|
||||||
module github.com/jinzhu/gorm/dialects/mysql
|
module github.com/jinzhu/gorm/dialects/sqlite
|
||||||
|
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require (
|
require github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
|
||||||
)
|
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||||
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
|
@ -1,6 +1,7 @@
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/jinzhu/gorm/callbacks"
|
"github.com/jinzhu/gorm/callbacks"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,15 +1,27 @@
|
||||||
package sqlite_test
|
package sqlite_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
"github.com/jinzhu/gorm/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
var (
|
||||||
|
DB *gorm.DB
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
func TestOpen(t *testing.T) {
|
func init() {
|
||||||
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to initialize database, got error %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite(t *testing.T) {
|
||||||
|
tests.RunTestsSuit(t, DB)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
|
||||||
// First find first record that match given conditions, order by primary key
|
// First find first record that match given conditions, order by primary key
|
||||||
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.callbacks.Create().Execute(tx.Limit(1).Order("id"))
|
tx.Statement.Dest = out
|
||||||
|
tx.Limit(1)
|
||||||
|
tx.callbacks.Query().Execute(tx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Row() *sql.Row {
|
func (db *DB) Row() *sql.Row {
|
||||||
// TODO
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Rows() (*sql.Rows, error) {
|
func (db *DB) Rows() (*sql.Rows, error) {
|
||||||
// TODO
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -2,4 +2,7 @@ module github.com/jinzhu/gorm
|
||||||
|
|
||||||
go 1.13
|
go 1.13
|
||||||
|
|
||||||
require github.com/jinzhu/inflection v1.0.0
|
require (
|
||||||
|
github.com/jinzhu/inflection v1.0.0
|
||||||
|
gopkg.in/errgo.v2 v2.1.0
|
||||||
|
)
|
||||||
|
|
31
gorm.go
31
gorm.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
@ -12,36 +13,28 @@ import (
|
||||||
// Config GORM config
|
// Config GORM config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||||
// You can cancel it by setting `SkipDefaultTransaction` to true
|
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||||
SkipDefaultTransaction bool // TODO
|
SkipDefaultTransaction bool
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
|
|
||||||
// Logger
|
// Logger
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
|
|
||||||
// NowFunc the function to be used when creating a new timestamp
|
// NowFunc the function to be used when creating a new timestamp
|
||||||
NowFunc func() time.Time
|
NowFunc func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialector GORM database dialector
|
|
||||||
type Dialector interface {
|
|
||||||
Initialize(*DB) error
|
|
||||||
Migrator() Migrator
|
|
||||||
BindVar(stmt Statement, v interface{}) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// DB GORM DB definition
|
// DB GORM DB definition
|
||||||
type DB struct {
|
type DB struct {
|
||||||
*Config
|
*Config
|
||||||
Dialector
|
Dialector
|
||||||
Instance
|
Instance
|
||||||
|
DB CommonDB
|
||||||
clone bool
|
clone bool
|
||||||
callbacks *callbacks
|
callbacks *callbacks
|
||||||
|
cacheStore *sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session session config when create new session
|
// Session session config when create session with Session() method
|
||||||
type Session struct {
|
type Session struct {
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
|
@ -71,6 +64,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||||
Dialector: dialector,
|
Dialector: dialector,
|
||||||
clone: true,
|
clone: true,
|
||||||
callbacks: InitializeCallbacks(),
|
callbacks: InitializeCallbacks(),
|
||||||
|
cacheStore: &sync.Map{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if dialector != nil {
|
if dialector != nil {
|
||||||
|
@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) {
|
||||||
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
|
return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set store value with key into current db instance's context
|
// Set store value with key into current db instance's context
|
||||||
func (db *DB) Set(key string, value interface{}) *DB {
|
func (db *DB) Set(key string, value interface{}) *DB {
|
||||||
tx := db.getInstance()
|
tx := db.getInstance()
|
||||||
|
@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DB{
|
return &DB{
|
||||||
Config: db.Config,
|
|
||||||
Dialector: db.Dialector,
|
|
||||||
Instance: Instance{
|
Instance: Instance{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
|
||||||
},
|
},
|
||||||
|
Config: db.Config,
|
||||||
|
Dialector: db.Dialector,
|
||||||
|
DB: db.DB,
|
||||||
|
callbacks: db.callbacks,
|
||||||
|
cacheStore: db.cacheStore,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialector GORM database dialector
|
||||||
|
type Dialector interface {
|
||||||
|
Initialize(*DB) error
|
||||||
|
Migrator() Migrator
|
||||||
|
BindVar(stmt Statement, v interface{}) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommonDB common db interface
|
||||||
|
type CommonDB interface {
|
||||||
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -9,6 +10,9 @@ import (
|
||||||
"github.com/jinzhu/gorm/logger"
|
"github.com/jinzhu/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrUnsupportedDataType unsupported data type
|
||||||
|
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||||
|
|
||||||
type Schema struct {
|
type Schema struct {
|
||||||
Name string
|
Name string
|
||||||
ModelType reflect.Type
|
ModelType reflect.Type
|
||||||
|
@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||||
|
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
if modelType.PkgPath() == "" {
|
if modelType.PkgPath() == "" {
|
||||||
return nil, fmt.Errorf("unsupported data %+v when parsing model", dest)
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath())
|
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := cacheStore.Load(modelType); ok {
|
if v, ok := cacheStore.Load(modelType); ok {
|
||||||
|
@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.DBName == "" {
|
if field.DBName == "" && field.DataType != "" {
|
||||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,24 +2,16 @@ package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/schema"
|
"github.com/jinzhu/gorm/schema"
|
||||||
|
"github.com/jinzhu/gorm/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
|
func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
|
||||||
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
|
t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
|
||||||
equalFieldNames := []string{"Name", "Table"}
|
tests.AssertEqual(t, s, v, "Name", "Table")
|
||||||
|
|
||||||
for _, name := range equalFieldNames {
|
|
||||||
got := reflect.ValueOf(s).Elem().FieldByName(name).Interface()
|
|
||||||
expects := reflect.ValueOf(v).FieldByName(name).Interface()
|
|
||||||
if !reflect.DeepEqual(got, expects) {
|
|
||||||
t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx, field := range primaryFields {
|
for idx, field := range primaryFields {
|
||||||
var found bool
|
var found bool
|
||||||
|
@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*
|
||||||
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
if parsedField, ok := s.FieldsByName[f.Name]; !ok {
|
||||||
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
|
||||||
} else {
|
} else {
|
||||||
equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"}
|
tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings")
|
||||||
|
|
||||||
for _, name := range equalFieldNames {
|
|
||||||
got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface()
|
|
||||||
expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface()
|
|
||||||
if !reflect.DeepEqual(got, expects) {
|
|
||||||
t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
|
||||||
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
|
||||||
|
|
31
statement.go
31
statement.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/clause"
|
"github.com/jinzhu/gorm/clause"
|
||||||
|
"github.com/jinzhu/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Instance db instance
|
// Instance db instance
|
||||||
|
@ -37,6 +38,7 @@ type Statement struct {
|
||||||
Clauses map[string]clause.Clause
|
Clauses map[string]clause.Clause
|
||||||
Settings sync.Map
|
Settings sync.Map
|
||||||
DB *DB
|
DB *DB
|
||||||
|
Schema *schema.Schema
|
||||||
|
|
||||||
// SQL Builder
|
// SQL Builder
|
||||||
SQL strings.Builder
|
SQL strings.Builder
|
||||||
|
@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quote returns quoted value
|
// Quote returns quoted value
|
||||||
func (stmt Statement) Quote(field interface{}) (str string) {
|
func (stmt Statement) Quote(field interface{}) string {
|
||||||
// FIXME
|
var str strings.Builder
|
||||||
return fmt.Sprint(field)
|
|
||||||
|
switch v := field.(type) {
|
||||||
|
case clause.Table:
|
||||||
|
str.WriteString(v.Table)
|
||||||
|
if v.Alias != "" {
|
||||||
|
str.WriteString(" AS ")
|
||||||
|
str.WriteString(v.Alias)
|
||||||
|
}
|
||||||
|
case clause.Column:
|
||||||
|
if v.Table != "" {
|
||||||
|
str.WriteString(v.Table)
|
||||||
|
str.WriteByte('.')
|
||||||
|
}
|
||||||
|
|
||||||
|
str.WriteString(v.Name)
|
||||||
|
if v.Alias != "" {
|
||||||
|
str.WriteString(" AS ")
|
||||||
|
str.WriteString(v.Alias)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Sprint(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
return str.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write write string
|
// Write write string
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
package tests
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Now() *time.Time {
|
||||||
|
now := time.Now()
|
||||||
|
return &now
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunTestsSuit(t *testing.T, db *gorm.DB) {
|
||||||
|
TestCreate(t, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreate(t *testing.T, db *gorm.DB) {
|
||||||
|
t.Run("Create", func(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "create",
|
||||||
|
Age: 18,
|
||||||
|
Birthday: Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(&user).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.ID == 0 {
|
||||||
|
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newUser User
|
||||||
|
if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||||
|
t.Errorf("errors happened when query: %v", err)
|
||||||
|
} else {
|
||||||
|
AssertEqual(t, newUser, user, "Name", "Age", "Birthday")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AssertEqual(t *testing.T, r, e interface{}, names ...string) {
|
||||||
|
for _, name := range names {
|
||||||
|
got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface()
|
||||||
|
expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface()
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, expects) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Errorf("expects: %v, got %v", expects, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue