diff --git a/callbacks.go b/callbacks.go index a7f30612..22d2eda3 100644 --- a/callbacks.go +++ b/callbacks.go @@ -1,9 +1,11 @@ package gorm import ( + "errors" "fmt" "github.com/jinzhu/gorm/logger" + "github.com/jinzhu/gorm/schema" "github.com/jinzhu/gorm/utils" ) @@ -67,6 +69,17 @@ func (cs *callbacks) Raw() *processor { } func (p *processor) Execute(db *DB) { + if stmt := db.Statement; stmt != nil && stmt.Dest != nil { + var err error + stmt.Schema, err = schema.Parse(stmt.Dest, db.cacheStore, db.NamingStrategy) + + if err != nil && !errors.Is(err, schema.ErrUnsupportedDataType) { + db.AddError(err) + } else if stmt.Table == "" && stmt.Schema != nil { + stmt.Table = stmt.Schema.Table + } + } + for _, f := range p.fns { f(db) } diff --git a/callbacks/create.go b/callbacks/create.go index 2fe27140..5a3aaa24 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,6 +1,10 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "fmt" + + "github.com/jinzhu/gorm" +) func BeforeCreate(db *gorm.DB) { // before save @@ -13,6 +17,9 @@ func SaveBeforeAssociations(db *gorm.DB) { } func Create(db *gorm.DB) { + db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING") + + fmt.Println(db.Statement.SQL.String(), db.Statement.Vars) } func SaveAfterAssociations(db *gorm.DB) { @@ -22,3 +29,17 @@ func AfterCreate(db *gorm.DB) { // after save // after create } + +func objectToFieldsMap(stmt *gorm.Statement) { + if stmt.Schema != nil { + if s, ok := stmt.Clauses["SELECT"]; ok { + s.Attrs + } + + if s, ok := stmt.Clauses["OMIT"]; ok { + s.Attrs + } + + stmt.Schema.LookUpField(s.S) + } +} diff --git a/callbacks/query.go b/callbacks/query.go new file mode 100644 index 00000000..5d27ea17 --- /dev/null +++ b/callbacks/query.go @@ -0,0 +1,13 @@ +package callbacks + +import "github.com/jinzhu/gorm" + +func Query(db *gorm.DB) { +} + +func Preload(db *gorm.DB) { +} + +func AfterQuery(db *gorm.DB) { + // after find +} diff --git a/clause/clause.go b/clause/clause.go index 1b4a7e85..c0ebe7e2 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -51,124 +51,21 @@ type OverrideNameInterface interface { OverrideName() string } -//////////////////////////////////////////////////////////////////////////////// -// Predefined Clauses -//////////////////////////////////////////////////////////////////////////////// - -// Where where clause -type Where struct { - AndConditions AddConditions - ORConditions []ORConditions - builders []Expression +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool } -func (where Where) Name() string { - return "WHERE" +func ToColumns(value ...interface{}) []Column { + return nil } -func (where Where) Build(builder Builder) { - var withConditions bool - - if len(where.AndConditions) > 0 { - withConditions = true - where.AndConditions.Build(builder) - } - - if len(where.builders) > 0 { - for _, b := range where.builders { - if withConditions { - builder.Write(" AND ") - } - withConditions = true - b.Build(builder) - } - } - - var singleOrConditions []ORConditions - for _, or := range where.ORConditions { - if len(or) == 1 { - if withConditions { - builder.Write(" OR ") - or.Build(builder) - } else { - singleOrConditions = append(singleOrConditions, or) - } - } else { - withConditions = true - builder.Write(" AND (") - or.Build(builder) - builder.WriteByte(')') - } - } - - for _, or := range singleOrConditions { - if withConditions { - builder.Write(" AND ") - or.Build(builder) - } else { - withConditions = true - or.Build(builder) - } - } - - if !withConditions { - builder.Write(" FALSE") - } - - return -} - -func (where Where) MergeExpression(expr Expression) { - if w, ok := expr.(Where); ok { - where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.ORConditions = append(where.ORConditions, w.ORConditions...) - where.builders = append(where.builders, w.builders...) - } else { - where.builders = append(where.builders, expr) - } -} - -// Select select attrs when querying, updating, creating -type Select struct { - Omit bool -} - -// Join join clause -type Join struct { - Table string - Type string // left join books on - ON []Expression - builders []Expression -} - -func (join Join) Build(builder Builder) { - // TODO -} - -func (join Join) MergeExpression(expr Expression) { - if j, ok := expr.(Join); ok { - join.builders = append(join.builders, j.builders...) - } else { - join.builders = append(join.builders, expr) - } -} - -// GroupBy group by clause -type GroupBy struct { -} - -// Having having clause -type Having struct { -} - -// Order order clause -type Order struct { -} - -// Limit limit clause -type Limit struct { -} - -// Offset offset clause -type Offset struct { +// Table quote with name +type Table struct { + Table string + Alias string + Raw bool } diff --git a/clause/from.go b/clause/from.go new file mode 100644 index 00000000..610d69a4 --- /dev/null +++ b/clause/from.go @@ -0,0 +1,22 @@ +package clause + +// From from clause +type From struct { + Tables []Table +} + +// Name from clause name +func (From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } +} diff --git a/clause/group_by.go b/clause/group_by.go new file mode 100644 index 00000000..bce94109 --- /dev/null +++ b/clause/group_by.go @@ -0,0 +1,6 @@ +package clause + +// GroupBy group by clause +type GroupBy struct { + Having Where +} diff --git a/clause/join.go b/clause/join.go new file mode 100644 index 00000000..6b0e8f97 --- /dev/null +++ b/clause/join.go @@ -0,0 +1,23 @@ +package clause + +// Join join clause +type Join struct { + Table From // From + Type string // INNER, LEFT, RIGHT, FULL, CROSS JOIN + Using []Column + ON Where +} + +// TODO multiple joins + +func (join Join) Build(builder Builder) { + // TODO +} + +func (join Join) MergeExpression(expr Expression) { + // if j, ok := expr.(Join); ok { + // join.builders = append(join.builders, j.builders...) + // } else { + // join.builders = append(join.builders, expr) + // } +} diff --git a/clause/limit.go b/clause/limit.go new file mode 100644 index 00000000..8fbc0055 --- /dev/null +++ b/clause/limit.go @@ -0,0 +1,6 @@ +package clause + +// Limit limit clause +type Limit struct { + Offset uint +} diff --git a/clause/order_by.go b/clause/order_by.go new file mode 100644 index 00000000..a11a3c48 --- /dev/null +++ b/clause/order_by.go @@ -0,0 +1,4 @@ +package clause + +type OrderBy struct { +} diff --git a/clause/query.go b/clause/query.go index 7b5491e5..949678d9 100644 --- a/clause/query.go +++ b/clause/query.go @@ -2,12 +2,6 @@ package clause import "strings" -// Column quote with name -type Column struct { - Table string - Name string -} - //////////////////////////////////////////////////////////////////////////////// // Query Expressions //////////////////////////////////////////////////////////////////////////////// diff --git a/clause/select.go b/clause/select.go new file mode 100644 index 00000000..1342c411 --- /dev/null +++ b/clause/select.go @@ -0,0 +1,45 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + SelectColumns []Column + OmitColumns []Column +} + +// SelectInterface select clause interface +type SelectInterface interface { + Selects() []Column + Omits() []Column +} + +func (s Select) Selects() []Column { + return s.SelectColumns +} + +func (s Select) Omits() []Column { + return s.OmitColumns +} + +func (s Select) Build(builder Builder) { + if len(s.SelectColumns) > 0 { + for idx, column := range s.SelectColumns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeExpression(expr Expression) { + if v, ok := expr.(SelectInterface); ok { + if len(s.SelectColumns) == 0 { + s.SelectColumns = v.Selects() + } + if len(s.OmitColumns) == 0 { + s.OmitColumns = v.Omits() + } + } +} diff --git a/clause/where.go b/clause/where.go new file mode 100644 index 00000000..888b9d07 --- /dev/null +++ b/clause/where.go @@ -0,0 +1,77 @@ +package clause + +// Where where clause +type Where struct { + AndConditions AddConditions + ORConditions []ORConditions + builders []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause +func (where Where) Build(builder Builder) { + var withConditions bool + + if len(where.AndConditions) > 0 { + withConditions = true + where.AndConditions.Build(builder) + } + + if len(where.builders) > 0 { + for _, b := range where.builders { + if withConditions { + builder.Write(" AND ") + } + withConditions = true + b.Build(builder) + } + } + + var singleOrConditions []ORConditions + for _, or := range where.ORConditions { + if len(or) == 1 { + if withConditions { + builder.Write(" OR ") + or.Build(builder) + } else { + singleOrConditions = append(singleOrConditions, or) + } + } else { + withConditions = true + builder.Write(" AND (") + or.Build(builder) + builder.WriteByte(')') + } + } + + for _, or := range singleOrConditions { + if withConditions { + builder.Write(" AND ") + or.Build(builder) + } else { + withConditions = true + or.Build(builder) + } + } + + if !withConditions { + builder.Write(" FALSE") + } + + return +} + +// MergeExpression merge where clauses +func (where Where) MergeExpression(expr Expression) { + if w, ok := expr.(Where); ok { + where.AndConditions = append(where.AndConditions, w.AndConditions...) + where.ORConditions = append(where.ORConditions, w.ORConditions...) + where.builders = append(where.builders, w.builders...) + } else { + where.builders = append(where.builders, expr) + } +} diff --git a/clause/with.go b/clause/with.go new file mode 100644 index 00000000..7e9eaef1 --- /dev/null +++ b/clause/with.go @@ -0,0 +1,4 @@ +package clause + +type With struct { +} diff --git a/dialects/sqlite/go.mod b/dialects/sqlite/go.mod index db3370e9..79d48da8 100644 --- a/dialects/sqlite/go.mod +++ b/dialects/sqlite/go.mod @@ -1,7 +1,5 @@ -module github.com/jinzhu/gorm/dialects/mysql +module github.com/jinzhu/gorm/dialects/sqlite go 1.13 -require ( - github.com/mattn/go-sqlite3 v2.0.3+incompatible -) +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/dialects/sqlite/go.sum b/dialects/sqlite/go.sum new file mode 100644 index 00000000..d6744290 --- /dev/null +++ b/dialects/sqlite/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index f3c3f0c7..bcd6bd5c 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/callbacks" _ "github.com/mattn/go-sqlite3" ) diff --git a/dialects/sqlite/sqlite_test.go b/dialects/sqlite/sqlite_test.go index f0429a12..51c1def0 100644 --- a/dialects/sqlite/sqlite_test.go +++ b/dialects/sqlite/sqlite_test.go @@ -1,15 +1,27 @@ package sqlite_test import ( + "fmt" "os" "path/filepath" "testing" "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm/tests" ) -var DB *gorm.DB +var ( + DB *gorm.DB + err error +) -func TestOpen(t *testing.T) { - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) +func init() { + if DB, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}); err != nil { + panic(fmt.Sprintf("failed to initialize database, got error %v", err)) + } +} + +func TestSqlite(t *testing.T) { + tests.RunTestsSuit(t, DB) } diff --git a/finisher_api.go b/finisher_api.go index b155e90d..c79915d2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -12,7 +12,9 @@ func (db *DB) Count(sql string, values ...interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(out interface{}, where ...interface{}) (tx *DB) { tx = db.getInstance() - tx.callbacks.Create().Execute(tx.Limit(1).Order("id")) + tx.Statement.Dest = out + tx.Limit(1) + tx.callbacks.Query().Execute(tx) return } @@ -35,12 +37,10 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) { } func (db *DB) Row() *sql.Row { - // TODO return nil } func (db *DB) Rows() (*sql.Rows, error) { - // TODO return nil, nil } diff --git a/go.mod b/go.mod index 516a9759..820046ba 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/jinzhu/gorm go 1.13 -require github.com/jinzhu/inflection v1.0.0 +require ( + github.com/jinzhu/inflection v1.0.0 + gopkg.in/errgo.v2 v2.1.0 +) diff --git a/gorm.go b/gorm.go index 896d07f9..2264b9ae 100644 --- a/gorm.go +++ b/gorm.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "sync" "time" "github.com/jinzhu/gorm/clause" @@ -12,36 +13,28 @@ import ( // Config GORM config type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity - // You can cancel it by setting `SkipDefaultTransaction` to true - SkipDefaultTransaction bool // TODO - + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer - // Logger Logger logger.Interface - // NowFunc the function to be used when creating a new timestamp NowFunc func() time.Time } -// Dialector GORM database dialector -type Dialector interface { - Initialize(*DB) error - Migrator() Migrator - BindVar(stmt Statement, v interface{}) string -} - // DB GORM DB definition type DB struct { *Config Dialector Instance - clone bool - callbacks *callbacks + DB CommonDB + clone bool + callbacks *callbacks + cacheStore *sync.Map } -// Session session config when create new session +// Session session config when create session with Session() method type Session struct { Context context.Context Logger logger.Interface @@ -67,10 +60,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { } db = &DB{ - Config: config, - Dialector: dialector, - clone: true, - callbacks: InitializeCallbacks(), + Config: config, + Dialector: dialector, + clone: true, + callbacks: InitializeCallbacks(), + cacheStore: &sync.Map{}, } if dialector != nil { @@ -113,10 +107,6 @@ func (db *DB) Debug() (tx *DB) { return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) } -func (db *DB) Close() error { - return nil -} - // Set store value with key into current db instance's context func (db *DB) Set(key string, value interface{}) *DB { tx := db.getInstance() @@ -145,12 +135,15 @@ func (db *DB) getInstance() *DB { } return &DB{ - Config: db.Config, - Dialector: db.Dialector, Instance: Instance{ Context: ctx, Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}}, }, + Config: db.Config, + Dialector: db.Dialector, + DB: db.DB, + callbacks: db.callbacks, + cacheStore: db.cacheStore, } } diff --git a/interfaces.go b/interfaces.go new file mode 100644 index 00000000..98d04592 --- /dev/null +++ b/interfaces.go @@ -0,0 +1,21 @@ +package gorm + +import ( + "context" + "database/sql" +) + +// Dialector GORM database dialector +type Dialector interface { + Initialize(*DB) error + Migrator() Migrator + BindVar(stmt Statement, v interface{}) string +} + +// CommonDB common db interface +type CommonDB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} diff --git a/schema/schema.go b/schema/schema.go index 5cd6146b..53170e18 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,6 +1,7 @@ package schema import ( + "errors" "fmt" "go/ast" "reflect" @@ -9,6 +10,9 @@ import ( "github.com/jinzhu/gorm/logger" ) +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + type Schema struct { Name string ModelType reflect.Type @@ -50,9 +54,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { - return nil, fmt.Errorf("unsupported data %+v when parsing model", dest) + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("unsupported data type %v when parsing model", modelType.PkgPath()) + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { @@ -88,7 +92,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } for _, field := range schema.Fields { - if field.DBName == "" { + if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 05f41131..db38355d 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -2,24 +2,16 @@ package schema_test import ( "fmt" - "reflect" "strings" "testing" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) { - equalFieldNames := []string{"Name", "Table"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(s).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(v).FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("schema %v %v is not equal, expects: %v, got %v", s, name, expects, got) - } - } + tests.AssertEqual(t, s, v, "Name", "Table") for idx, field := range primaryFields { var found bool @@ -59,15 +51,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if parsedField, ok := s.FieldsByName[f.Name]; !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - equalFieldNames := []string{"Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings"} - - for _, name := range equalFieldNames { - got := reflect.ValueOf(parsedField).Elem().FieldByName(name).Interface() - expects := reflect.ValueOf(f).Elem().FieldByName(name).Interface() - if !reflect.DeepEqual(got, expects) { - t.Errorf("%v is not equal, expects: %v, got %v", name, expects, got) - } - } + tests.AssertEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "DBDataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) diff --git a/statement.go b/statement.go index 30d45b98..86359177 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" ) // Instance db instance @@ -37,6 +38,7 @@ type Statement struct { Clauses map[string]clause.Clause Settings sync.Map DB *DB + Schema *schema.Schema // SQL Builder SQL strings.Builder @@ -69,9 +71,32 @@ func (stmt Statement) WriteQuoted(field interface{}) (err error) { } // Quote returns quoted value -func (stmt Statement) Quote(field interface{}) (str string) { - // FIXME - return fmt.Sprint(field) +func (stmt Statement) Quote(field interface{}) string { + var str strings.Builder + + switch v := field.(type) { + case clause.Table: + str.WriteString(v.Table) + if v.Alias != "" { + str.WriteString(" AS ") + str.WriteString(v.Alias) + } + case clause.Column: + if v.Table != "" { + str.WriteString(v.Table) + str.WriteByte('.') + } + + str.WriteString(v.Name) + if v.Alias != "" { + str.WriteString(" AS ") + str.WriteString(v.Alias) + } + default: + fmt.Sprint(field) + } + + return str.String() } // Write write string diff --git a/tests/create_test.go b/tests/create_test.go deleted file mode 100644 index ca8701d2..00000000 --- a/tests/create_test.go +++ /dev/null @@ -1 +0,0 @@ -package tests diff --git a/tests/tests.go b/tests/tests.go new file mode 100644 index 00000000..b3246a79 --- /dev/null +++ b/tests/tests.go @@ -0,0 +1,42 @@ +package tests + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" +) + +func Now() *time.Time { + now := time.Now() + return &now +} + +func RunTestsSuit(t *testing.T, db *gorm.DB) { + TestCreate(t, db) +} + +func TestCreate(t *testing.T, db *gorm.DB) { + t.Run("Create", func(t *testing.T) { + var user = User{ + Name: "create", + Age: 18, + Birthday: Now(), + } + + if err := db.Create(&user).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + var newUser User + if err := db.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + AssertEqual(t, newUser, user, "Name", "Age", "Birthday") + } + }) +} diff --git a/tests/utils.go b/tests/utils.go new file mode 100644 index 00000000..d12df2dc --- /dev/null +++ b/tests/utils.go @@ -0,0 +1,19 @@ +package tests + +import ( + "reflect" + "testing" +) + +func AssertEqual(t *testing.T, r, e interface{}, names ...string) { + for _, name := range names { + got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() + expects := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + + if !reflect.DeepEqual(got, expects) { + t.Run(name, func(t *testing.T) { + t.Errorf("expects: %v, got %v", expects, got) + }) + } + } +}