diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..01dc5ce0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +documents +_book diff --git a/README.md b/README.md index f1e6e4b5..c3f209c9 100644 --- a/README.md +++ b/README.md @@ -1,1339 +1,33 @@ # GORM -[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) - The fantastic ORM library for Golang, aims to be developer friendly. +[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview * Full-Featured ORM (almost) -* Chainable API -* Auto Migrations -* Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism)) +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) * Callbacks (Before/After Create/Save/Update/Delete/Find) * Preloading (eager loading) * Transactions -* Embed Anonymous Struct -* Soft Deletes -* Customizable Logger -* Iteration Support via [Rows](#row--rows) +* Composite Primary Key +* SQL Builder +* Auto Migrations +* Logger +* Extendable, write Plugins based on GORM callbacks * Every feature comes with tests * Developer Friendly -# Getting Started +## Getting Started -## Install +* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) -``` -go get -u github.com/jinzhu/gorm -``` +## Upgrading To V1.0 -## Table of Contents - -- [Define Models (Structs)](#define-models-structs) -- [Conventions](#conventions) -- [Initialize Database](#initialize-database) -- [Migration](#migration) -- [Basic CRUD](#basic-crud) - - [Create](#create-record) - - [Query](#query) - - [Query With Where (Plain SQL)](#query-with-where-plain-sql) - - [Query With Where (Struct & Map)](#query-with-where-struct--map) - - [Query With Not](#query-with-not) - - [Query With Inline Condition](#query-with-inline-condition) - - [Query With Or](#query-with-or) - - [Query Chains](#query-chains) - - [Preloading (Eager loading)](#preloading-eager-loading) - - [Update](#update) - - [Update Without Callbacks](#update-without-callbacks) - - [Batch Updates](#batch-updates) - - [Update with SQL Expression](#update-with-sql-expression) - - [Delete](#delete) - - [Batch Delete](#batch-delete) - - [Soft Delete](#soft-delete) -- [Associations](#associations) - - [Has One](#has-one) - - [Belongs To](#belongs-to) - - [Has Many](#has-many) - - [Many To Many](#many-to-many) - - [Polymorphism](#polymorphism) - - [Association Mode](#association-mode) -- [Advanced Usage](#advanced-usage) - - [FirstOrInit](#firstorinit) - - [FirstOrCreate](#firstorcreate) - - [Select](#select) - - [Order](#order) - - [Limit](#limit) - - [Offset](#offset) - - [Count](#count) - - [Pluck](#pluck) - - [Raw SQL](#raw-sql) - - [Row & Rows](#row--rows) - - [Scan](#scan) - - [Group & Having](#group--having) - - [Joins](#joins) - - [Transactions](#transactions) - - [Scopes](#scopes) - - [Callbacks](#callbacks) - - [Specifying The Table Name](#specifying-the-table-name) - - [Error Handling](#error-handling) - - [Logger](#logger) - - [Existing Schema](#existing-schema) - - [Composite Primary Key](#composite-primary-key) - - [Database Indexes & Foreign Key](#database-indexes--foreign-key) - - [Default values](#default-values) - - [More examples with query chain](#more-examples-with-query-chain) - -## Define Models (Structs) - -```go -type User struct { - ID int - Birthday time.Time - Age int - Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag - Num int `sql:"AUTO_INCREMENT"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time - - CreditCard CreditCard // One-To-One relationship (has one - use CreditCard's UserID as foreign key) - Emails []Email // One-To-Many relationship (has many - use Email's UserID as foreign key) - - BillingAddress Address // One-To-One relationship (belongs to - use BillingAddressID as foreign key) - BillingAddressID sql.NullInt64 - - ShippingAddress Address // One-To-One relationship (belongs to - use ShippingAddressID as foreign key) - ShippingAddressID int - - IgnoreMe int `sql:"-"` // Ignore this field - Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table -} - -type Email struct { - ID int - UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate - Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index - Subscribed bool -} - -type Address struct { - ID int - Address1 string `sql:"not null;unique"` // Set field as not nullable and unique - Address2 string `sql:"type:varchar(100);unique"` - Post sql.NullString `sql:"not null"` -} - -type Language struct { - ID int - Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name - Code string `sql:"index:idx_name_code"` // `unique_index` also works -} - -type CreditCard struct { - gorm.Model - UserID uint - Number string -} -``` - -## Conventions - -* Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename) - -```go -type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation -``` - -* Column name is the snake case of field's name -* Use `ID` field as primary key -* Use `CreatedAt` to store record's created time if field exists -* Use `UpdatedAt` to store record's updated time if field exists -* Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) -* Gorm provide a default model struct, you could embed it in your struct - -```go -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type User struct { - gorm.Model - Name string -} -``` - -## Initialize Database - -```go -import ( - "github.com/jinzhu/gorm" - _ "github.com/lib/pq" - _ "github.com/go-sql-driver/mysql" - _ "github.com/mattn/go-sqlite3" -) - -db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// db, err := gorm.Open("sqlite3", "/tmp/gorm.db") - -// You can also use an existing database connection handle -// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db, _ := gorm.Open("postgres", dbSql) - -// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) -db.DB() - -// Then you could invoke `*sql.DB`'s functions with it -db.DB().Ping() -db.DB().SetMaxIdleConns(10) -db.DB().SetMaxOpenConns(100) - -// Disable table name's pluralization -db.SingularTable(true) -``` - -## Migration - -```go -// Create table -db.CreateTable(&User{}) -db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{}) - -// Drop table -db.DropTable(&User{}) - -// ModifyColumn -db.Model(&User{}).ModifyColumn("description", "text") - -// DropColumn -db.Model(&User{}).DropColumn("description") - -// Automating Migration -db.AutoMigrate(&User{}) -db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}) -db.AutoMigrate(&User{}, &Product{}, &Order{}) -// Feel free to change your struct, AutoMigrate will keep your database up-to-date. -// AutoMigrate will ONLY add *new columns* and *new indexes*, -// WON'T update current column's type or delete unused columns, to protect your data. -// If the table is not existing, AutoMigrate will create the table automatically. -``` - -# Basic CRUD - -## Create Record - -```go -user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()} - -db.NewRecord(user) // => returns `true` if primary key is blank - -db.Create(&user) - -db.NewRecord(user) // => return `false` after `user` created - -// Associations will be inserted automatically when save the record -user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - Languages: []Language{{Name: "ZH"}, {Name: "EN"}}, -} - -db.Create(&user) -//// BEGIN TRANSACTION; -//// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1"); -//// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1"); -//// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2); -//// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com"); -//// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com"); -//// INSERT INTO "languages" ("name") VALUES ('ZH'); -//// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1); -//// INSERT INTO "languages" ("name") VALUES ('EN'); -//// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2); -//// COMMIT; -``` - -Refer [Associations](#associations) for more details - -## Query - -```go -// Get the first record -db.First(&user) -//// SELECT * FROM users ORDER BY id LIMIT 1; - -// Get the last record -db.Last(&user) -//// SELECT * FROM users ORDER BY id DESC LIMIT 1; - -// Get all records -db.Find(&users) -//// SELECT * FROM users; - -// Get record with primary key -db.First(&user, 10) -//// SELECT * FROM users WHERE id = 10; -``` - -### Query With Where (Plain SQL) - -```go -// Get the first matched record -db.Where("name = ?", "jinzhu").First(&user) -//// SELECT * FROM users WHERE name = 'jinzhu' limit 1; - -// Get all matched records -db.Where("name = ?", "jinzhu").Find(&users) -//// SELECT * FROM users WHERE name = 'jinzhu'; - -db.Where("name <> ?", "jinzhu").Find(&users) - -// IN -db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users) - -// LIKE -db.Where("name LIKE ?", "%jin%").Find(&users) - -// AND -db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users) - -// Time -db.Where("updated_at > ?", lastWeek).Find(&users) - -db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users) -``` - -### Query With Where (Struct & Map) - -```go -// Struct -db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) -//// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1; - -// Map -db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users) -//// SELECT * FROM users WHERE name = "jinzhu" AND age = 20; - -// Slice of primary keys -db.Where([]int64{20, 21, 22}).Find(&users) -//// SELECT * FROM users WHERE id IN (20, 21, 22); -``` - -### Query With Not - -```go -db.Not("name", "jinzhu").First(&user) -//// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1; - -// Not In -db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users) -//// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2"); - -// Not In slice of primary keys -db.Not([]int64{1,2,3}).First(&user) -//// SELECT * FROM users WHERE id NOT IN (1,2,3); - -db.Not([]int64{}).First(&user) -//// SELECT * FROM users; - -// Plain SQL -db.Not("name = ?", "jinzhu").First(&user) -//// SELECT * FROM users WHERE NOT(name = "jinzhu"); - -// Struct -db.Not(User{Name: "jinzhu"}).First(&user) -//// SELECT * FROM users WHERE name <> "jinzhu"; -``` - -### Query With Inline Condition - -```go -// Get by primary key -db.First(&user, 23) -//// SELECT * FROM users WHERE id = 23 LIMIT 1; - -// Plain SQL -db.Find(&user, "name = ?", "jinzhu") -//// SELECT * FROM users WHERE name = "jinzhu"; - -db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20) -//// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20; - -// Struct -db.Find(&users, User{Age: 20}) -//// SELECT * FROM users WHERE age = 20; - -// Map -db.Find(&users, map[string]interface{}{"age": 20}) -//// SELECT * FROM users WHERE age = 20; -``` - -### Query With Or - -```go -db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users) -//// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin'; - -// Struct -db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users) -//// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; - -// Map -db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users) -``` - -### Query Chains - -Gorm has a chainable API, you could use it like this - -```go -db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users) -//// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin'; - -db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users) -``` - -### Preloading (Eager loading) - -```go -db.Preload("Orders").Find(&users) -//// SELECT * FROM users; -//// SELECT * FROM orders WHERE user_id IN (1,2,3,4); - -db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -//// SELECT * FROM users; -//// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled'); - -db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -//// SELECT * FROM users WHERE state = 'active'; -//// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled'); - -db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) -//// SELECT * FROM users; -//// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many -//// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one -//// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to -``` - -#### Nested Preloading - -```go -db.Preload("Orders.OrderItems").Find(&users) -db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users) -``` - -## Update - -```go -// Update an existing struct -db.First(&user) -user.Name = "jinzhu 2" -user.Age = 100 -db.Save(&user) -//// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111; - -db.Where("active = ?", true).Save(&user) -//// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; - -// Update an attribute if it is changed -db.Model(&user).Update("name", "hello") -//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; - -db.Model(&user).Where("active = ?", true).Update("name", "hello") -//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true; - -db.First(&user, 111).Update("name", "hello") -//// SELECT * FROM users LIMIT 1; -//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111; - -// Update multiple attributes if they are changed -db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false}) - -// Update multiple attributes if they are changed (update with struct only works with none zero values) -db.Model(&user).Updates(User{Name: "hello", Age: 18}) -//// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111; -``` - -### Update Without Callbacks - -By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations: - -```go -db.Model(&user).UpdateColumn("name", "hello") -//// UPDATE users SET name='hello' WHERE id = 111; - -// Update with struct only works with none zero values, or use map[string]interface{} -db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18}) -//// UPDATE users SET name='hello', age=18 WHERE id = 111; -``` - -### Batch Updates - -```go -db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18}) -//// UPDATE users SET name='hello', age=18 WHERE id = 10; - -// Update with struct only works with none zero values, or use map[string]interface{} -db.Model(User{}).Updates(User{Name: "hello", Age: 18}) -//// UPDATE users SET name='hello', age=18; - -// Callbacks won't run when do batch updates - -// Use `RowsAffected` to get the count of affected records -db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected -``` - -### Update with SQL Expression - -```go -DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -//// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; - -DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)}) -//// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2'; - -DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) -//// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2'; - -DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1)) -//// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1; -``` - -## Delete - -```go -// Delete an existing record -db.Delete(&email) -//// DELETE from emails where id=10; -``` - -### Batch Delete - -```go -db.Where("email LIKE ?", "%jinzhu%").Delete(Email{}) -//// DELETE from emails where email LIKE "%jinhu%"; -``` - -### Soft Delete - -If struct has `DeletedAt` field, it will get soft delete ability automatically! -Then it won't be deleted from database permanently when call `Delete`. - -```go -db.Delete(&user) -//// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111; - -// Batch Delete -db.Where("age = ?", 20).Delete(&User{}) -//// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20; - -// Soft deleted records will be ignored when query them -db.Where("age = 20").Find(&user) -//// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02'); - -// Find soft deleted records with Unscoped -db.Unscoped().Where("age = 20").Find(&users) -//// SELECT * FROM users WHERE age = 20; - -// Delete record permanently with Unscoped -db.Unscoped().Delete(&order) -//// DELETE FROM orders WHERE id=10; -``` - -## Associations - -### Has One - -```go -// User has one CreditCard, UserID is the foreign key -type User struct { - gorm.Model - CreditCard CreditCard -} - -type CreditCard struct { - gorm.Model - UserID uint - Number string -} - -var card CreditCard -db.Model(&user).Related(&card, "CreditCard") -//// SELECT * FROM credit_cards WHERE user_id = 123; // 123 is user's primary key -// CreditCard is user's field name, it means get user's CreditCard relations and fill it into variable card -// If the field name is same as the variable's type name, like above example, it could be omitted, like: -db.Model(&user).Related(&card) -``` - -### Belongs To - -```go -// User belongs to a profile, ProfileID is the foreign key -type User struct { - gorm.Model - Profile Profile - ProfileID int -} - -type Profile struct { - gorm.Model - Name string -} - -db.Model(&user).Related(&profile) -//// SELECT * FROM profiles WHERE id = 111; // 111 is user's foreign key ProfileID -``` - -### Has Many - -```go -// User has many emails, UserID is the foreign key -type User struct { - gorm.Model - Emails []Email -} - -type Email struct { - gorm.Model - Email string - UserID uint -} - -db.Model(&user).Related(&emails) -//// SELECT * FROM emails WHERE user_id = 111; // 111 is user's primary key -``` - -### Many To Many - -```go -// User has and belongs to many languages, use `user_languages` as join table -type User struct { - gorm.Model - Languages []Language `gorm:"many2many:user_languages;"` -} - -type Language struct { - gorm.Model - Name string -} - -db.Model(&user).Related(&languages) -//// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111 -``` - -### Polymorphism - -Supports polymorphic has-many and has-one associations. - -```go - type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` - } - - type Dog struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` - } - - type Toy struct { - Id int - Name string - OwnerId int - OwnerType string - } -``` -Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors. - -## Association Mode - -Association Mode contains some helper methods to handle relationship things easily. - -```go -// Start Association Mode -var user User -db.Model(&user).Association("Languages") -// `user` is the source, it need to be a valid record (contains primary key) -// `Languages` is source's field name for a relationship. -// If those conditions not matched, will return an error, check it with: -// db.Model(&user).Association("Languages").Error - - -// Query - Find out all related associations -db.Model(&user).Association("Languages").Find(&languages) - - -// Append - Append new associations for many2many, has_many, will replace current association for has_one, belongs_to -db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN}) -db.Model(&user).Association("Languages").Append(Language{Name: "DE"}) - - -// Delete - Remove relationship between source & passed arguments, won't delete those arguments -db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN}) -db.Model(&user).Association("Languages").Delete(languageZH, languageEN) - - -// Replace - Replace current associations with new one -db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN}) -db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN) - - -// Count - Return the count of current associations -db.Model(&user).Association("Languages").Count() - - -// Clear - Remove relationship between source & current associations, won't delete those associations -db.Model(&user).Association("Languages").Clear() -``` - -## Advanced Usage - -## FirstOrInit - -Get the first matched record, or initialize a record with search conditions. - -```go -// Unfound -db.FirstOrInit(&user, User{Name: "non_existing"}) -//// user -> User{Name: "non_existing"} - -// Found -db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user) -//// user -> User{Id: 111, Name: "Jinzhu", Age: 20} -db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"}) -//// user -> User{Id: 111, Name: "Jinzhu", Age: 20} -``` - -### Attrs - -Ignore some values when searching, but use them to initialize the struct if record is not found. - -```go -// Unfound -db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user) -//// SELECT * FROM USERS WHERE name = 'non_existing'; -//// user -> User{Name: "non_existing", Age: 20} - -db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user) -//// SELECT * FROM USERS WHERE name = 'non_existing'; -//// user -> User{Name: "non_existing", Age: 20} - -// Found -db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user) -//// SELECT * FROM USERS WHERE name = jinzhu'; -//// user -> User{Id: 111, Name: "Jinzhu", Age: 20} -``` - -### Assign - -Ignore some values when searching, but assign it to the result regardless it is found or not. - -```go -// Unfound -db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user) -//// user -> User{Name: "non_existing", Age: 20} - -// Found -db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user) -//// SELECT * FROM USERS WHERE name = jinzhu'; -//// user -> User{Id: 111, Name: "Jinzhu", Age: 30} -``` - -## FirstOrCreate - -Get the first matched record, or create with search conditions. - -```go -// Unfound -db.FirstOrCreate(&user, User{Name: "non_existing"}) -//// INSERT INTO "users" (name) VALUES ("non_existing"); -//// user -> User{Id: 112, Name: "non_existing"} - -// Found -db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user) -//// user -> User{Id: 111, Name: "Jinzhu"} -``` - -### Attrs - -Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit` - -```go -// Unfound -db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user) -//// SELECT * FROM users WHERE name = 'non_existing'; -//// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); -//// user -> User{Id: 112, Name: "non_existing", Age: 20} - -// Found -db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user) -//// SELECT * FROM users WHERE name = 'jinzhu'; -//// user -> User{Id: 111, Name: "jinzhu", Age: 20} -``` - -### Assign - -Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit` - -```go -// Unfound -db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user) -//// SELECT * FROM users WHERE name = 'non_existing'; -//// INSERT INTO "users" (name, age) VALUES ("non_existing", 20); -//// user -> User{Id: 112, Name: "non_existing", Age: 20} - -// Found -db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user) -//// SELECT * FROM users WHERE name = 'jinzhu'; -//// UPDATE users SET age=30 WHERE id = 111; -//// user -> User{Id: 111, Name: "jinzhu", Age: 30} -``` - -## Select - -```go -db.Select("name, age").Find(&users) -//// SELECT name, age FROM users; - -db.Select([]string{"name", "age"}).Find(&users) -//// SELECT name, age FROM users; - -db.Table("users").Select("COALESCE(age,?)", 42).Rows() -//// SELECT COALESCE(age,'42') FROM users; -``` - -## Order - -```go -db.Order("age desc, name").Find(&users) -//// SELECT * FROM users ORDER BY age desc, name; - -// Multiple orders -db.Order("age desc").Order("name").Find(&users) -//// SELECT * FROM users ORDER BY age desc, name; - -// ReOrder -db.Order("age desc").Find(&users1).Order("age", true).Find(&users2) -//// SELECT * FROM users ORDER BY age desc; (users1) -//// SELECT * FROM users ORDER BY age; (users2) -``` - -## Limit - -```go -db.Limit(3).Find(&users) -//// SELECT * FROM users LIMIT 3; - -// Cancel limit condition with -1 -db.Limit(10).Find(&users1).Limit(-1).Find(&users2) -//// SELECT * FROM users LIMIT 10; (users1) -//// SELECT * FROM users; (users2) -``` - -## Offset - -```go -db.Offset(3).Find(&users) -//// SELECT * FROM users OFFSET 3; - -// Cancel offset condition with -1 -db.Offset(10).Find(&users1).Offset(-1).Find(&users2) -//// SELECT * FROM users OFFSET 10; (users1) -//// SELECT * FROM users; (users2) -``` - -## Count - -```go -db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count) -//// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users) -//// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count) - -db.Model(&User{}).Where("name = ?", "jinzhu").Count(&count) -//// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count) - -db.Table("deleted_users").Count(&count) -//// SELECT count(*) FROM deleted_users; -``` - -## Pluck - -Get selected attributes as map - -```go -var ages []int64 -db.Find(&users).Pluck("age", &ages) - -var names []string -db.Model(&User{}).Pluck("name", &names) - -db.Table("deleted_users").Pluck("name", &names) - -// Requesting more than one column? Do it like this: -db.Select("name, age").Find(&users) -``` - -## Raw SQL - -```go -db.Exec("DROP TABLE users;") -db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33}) -``` - -## Row & Rows - -It is even possible to get query result as `*sql.Row` or `*sql.Rows` - -```go -row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row) -row.Scan(&name, &age) - -rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error) -defer rows.Close() -for rows.Next() { - ... - rows.Scan(&name, &age, &email) - ... -} - -// Raw SQL -rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error) -defer rows.Close() -for rows.Next() { - ... - rows.Scan(&name, &age, &email) - ... -} -``` - -## Scan - -Scan results into another struct. - -```go -type Result struct { - Name string - Age int -} - -var result Result -db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result) - -// Raw SQL -db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -``` - -## Group & Having - -```go -rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows() -for rows.Next() { - ... -} - -rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows() -for rows.Next() { - ... -} - -type Result struct { - Date time.Time - Total int64 -} -db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results) -``` - -## Joins - -```go -rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows() -for rows.Next() { - ... -} - -db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results) - -// find a user by email address -db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) - -// find all email addresses for a user -db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) -``` - -## Transactions - -To perform a set of operations within a transaction, the general flow is as below. -The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction. -(Note that all individual save and delete operations are run in a transaction by default.) - -```go -// begin -tx := db.Begin() - -// do some database operations (use 'tx' from this point, not 'db') -tx.Create(...) -... - -// rollback in case of error -tx.Rollback() - -// Or commit if all is ok -tx.Commit() -``` - -### A Specific Example -``` -func CreateAnimals(db *gorm.DB) err { - tx := db.Begin() - // Note the use of tx as the database handle once you are within a transaction - - if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { - tx.Rollback() - return err - } - - if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { - tx.Rollback() - return err - } - - tx.Commit() - return nil -} -``` - -## Scopes - -```go -func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { - return db.Where("amount > ?", 1000) -} - -func PaidWithCreditCard(db *gorm.DB) *gorm.DB { - return db.Where("pay_mode_sign = ?", "C") -} - -func PaidWithCod(db *gorm.DB) *gorm.DB { - return db.Where("pay_mode_sign = ?", "C") -} - -func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { - return func (db *gorm.DB) *gorm.DB { - return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) - } -} - -db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders) -// Find all credit card orders and amount greater than 1000 - -db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders) -// Find all COD orders and amount greater than 1000 - -db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Find all paid, shipped orders -``` - -## Callbacks - -Callbacks are methods defined on the pointer of struct. -If any callback returns an error, gorm will stop future operations and rollback all changes. - -Here is the list of all available callbacks: -(listed in the same order in which they will get called during the respective operations) - -### Creating An Object - -```go -BeforeSave -BeforeCreate -// save before associations -// save self -// save after associations -AfterCreate -AfterSave -``` -### Updating An Object - -```go -BeforeSave -BeforeUpdate -// save before associations -// save self -// save after associations -AfterUpdate -AfterSave -``` - -### Destroying An Object - -```go -BeforeDelete -// delete self -AfterDelete -``` - -### After Find - -```go -// load data from database -AfterFind -``` - -### Example - -```go -func (u *User) BeforeUpdate() (err error) { - if u.readonly() { - err = errors.New("read only user") - } - return -} - -// Rollback the insertion if user's id greater than 1000 -func (u *User) AfterCreate() (err error) { - if (u.Id > 1000) { - err = errors.New("user id is already greater than 1000") - } - return -} -``` - -Save/delete operations in gorm are running in a transaction. -Changes made in that transaction are not visible unless it is commited. -So if you want to use those changes in your callbacks, you need to run your SQL in the same transaction. -For this Gorm supports passing transactions to callbacks like this: - -```go -func (u *User) AfterCreate(tx *gorm.DB) (err error) { - tx.Model(u).Update("role", "admin") - return -} -``` - -## Specifying The Table Name - -```go -// Create `deleted_users` table with struct User's definition -db.Table("deleted_users").CreateTable(&User{}) - -var deleted_users []User -db.Table("deleted_users").Find(&deleted_users) -//// SELECT * FROM deleted_users; - -db.Table("deleted_users").Where("name = ?", "jinzhu").Delete() -//// DELETE FROM deleted_users WHERE name = 'jinzhu'; -``` - -### Specifying The Table Name For A Struct Permanently with TableName - -```go -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func (u User) TableName() string { - if u.Role == "admin" { - return "admin_users" - } else { - return "users" - } -} -``` - -## Error Handling - -```go -query := db.Where("name = ?", "jinzhu").First(&user) -query := db.First(&user).Limit(10).Find(&users) -// query.Error will return the last happened error - -// So you could do error handing in your application like this: -if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil { - // error handling... -} - -// RecordNotFound -// If no record found when you query data, gorm will return RecordNotFound error, you could check it like this: -db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound -// Or use the shortcut method -db.Where("name = ?", "hello world").First(&user).RecordNotFound() - -if db.Model(&user).Related(&credit_card).RecordNotFound() { - // no credit card found error handling -} -``` - -## Logger - -Gorm has built-in logger support - -```go -// Enable Logger -db.LogMode(true) - -// Diable Logger -db.LogMode(false) - -// Debug a single operation -db.Debug().Where("name = ?", "jinzhu").First(&User{}) -``` - -![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png) - -### Customize Logger - -```go -// Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files -db.SetLogger(gorm.Logger{revel.TRACE}) -db.SetLogger(log.New(os.Stdout, "\r\n", 0)) -``` - -## Existing Schema - -If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key. - -```go -type Animal struct { - AnimalId int64 `gorm:"primary_key"` - Birthday time.Time `sql:"DEFAULT:current_timestamp"` - Name string `sql:"default:'galeone'"` - Age int64 -} -``` - -If your column names differ from the struct fields, you can specify them like this: - -```go -type Animal struct { - AnimalId int64 `gorm:"column:beast_id;primary_key"` - Birthday time.Time `gorm:"column:day_of_the_beast"` - Age int64 `gorm:"column:age_of_the_beast"` -} -``` - -## Composite Primary Key - -```go -type Product struct { - ID string `gorm:"primary_key"` - LanguageCode string `gorm:"primary_key"` -} -``` - -## Database Indexes & Foreign Key - -```go -// Add foreign key -// 1st param : foreignkey field -// 2nd param : destination table(id) -// 3rd param : ONDELETE -// 4th param : ONUPDATE -db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") - -// Add index -db.Model(&User{}).AddIndex("idx_user_name", "name") - -// Multiple column index -db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age") - -// Add unique index -db.Model(&User{}).AddUniqueIndex("idx_user_name", "name") - -// Multiple column unique index -db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age") - -// Remove index -db.Model(&User{}).RemoveIndex("idx_user_name") -``` - -## Default values - -```go -type Animal struct { - ID int64 - Name string `sql:"default:'galeone'"` - Age int64 -} -``` - -If you have defined a default value in the `sql` tag, the generated create SQL will ignore these fields if it is blank. - -Eg. - -```go -db.Create(&Animal{Age: 99, Name: ""}) -``` - -The generated SQL will be: - -```sql -INSERT INTO animals("age") values('99'); -``` - -The same thing occurs in update statements. - -## More examples with query chain - -```go -db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles) -//// SELECT * FROM articles LIMIT 1; (first_article) -//// SELECT count(*) FROM articles; (total_count) -//// SELECT * FROM articles LIMIT 10; (first_page_articles) -//// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles) - - -db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped") -//// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders) -//// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders) - - -// Use variables to keep query chain -todays_orders := db.Where("created_at > ?", "2013-10-29") -cancelled_orders := todays_orders.Where("state = ?", "cancelled") -shipped_orders := todays_orders.Where("state = ?", "shipped") - - -// Search with shared conditions for different tables -db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts) -//// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders) -//// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts) - - -// Search with shared conditions from different tables with specified table -db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2) -//// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1) -//// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2) - - -// FirstOrCreate example -db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user) -//// SELECT * FROM users WHERE email = 'x@example.org'; -//// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found -``` - -## Documentation - -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) - -`go doc` format documentation for this project can be viewed online without -installing the package by using the GoDoc page at: -http://godoc.org/github.com/jinzhu/gorm - -## TODO -* Github Pages +* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) # Author diff --git a/association.go b/association.go index 30ea36b2..cd8fd912 100644 --- a/association.go +++ b/association.go @@ -4,32 +4,289 @@ import ( "errors" "fmt" "reflect" - "strings" ) +// Association Mode contains some helper methods to handle relationship things easily. type Association struct { - Scope *Scope - Column string Error error - Field *Field + scope *Scope + column string + field *Field } -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err +// Find find out all related associations +func (association *Association) Find(value interface{}) *Association { + association.scope.related(value, association.column) + return association.setErr(association.scope.db.Error) +} + +// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to +func (association *Association) Append(values ...interface{}) *Association { + if relationship := association.field.Relationship; relationship.Kind == "has_one" { + return association.Replace(values...) + } + return association.saveAssociations(values...) +} + +// Replace replace current associations with new one +func (association *Association) Replace(values ...interface{}) *Association { + var ( + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field + newDB = scope.NewDB() + ) + + // Append new values + association.field.Set(reflect.Zero(association.field.Field.Type())) + association.saveAssociations(values...) + + // Belongs To + if relationship.Kind == "belongs_to" { + // Set foreign key to be null when clearing value (length equals 0) + if len(values) == 0 { + // Set foreign key to be nil + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) + } + } else { + // Polymorphic Relations + if relationship.PolymorphicDBName != "" { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + } + + // Delete Relations except new created + if len(values) > 0 { + var associationForeignFieldNames []string + if relationship.Kind == "many_to_many" { + // if many to many relations, get association fields name from association foreign keys + associationScope := scope.New(reflect.New(field.Type()).Interface()) + for _, dbName := range relationship.AssociationForeignFieldNames { + if field, ok := associationScope.FieldByName(dbName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + } else { + // If other relations, use primary keys + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + + newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) + + if len(newPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) + newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) + } + } + + if relationship.Kind == "many_to_many" { + // if many to many relations, delete related relations from join table + var sourceForeignFieldNames []string + + for _, dbName := range relationship.ForeignFieldNames { + if field, ok := scope.FieldByName(dbName); ok { + sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) + } + } + + if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { + newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) + + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + } + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) + var foreignKeyMap = map[string]interface{}{} + for idx, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + fieldValue := reflect.New(association.field.Field.Type()).Interface() + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + } } return association } -func (association *Association) Find(value interface{}) *Association { - association.Scope.related(value, association.Column) - return association.setErr(association.Scope.db.Error) +// Delete remove relationship between source & passed arguments, but won't delete those arguments +func (association *Association) Delete(values ...interface{}) *Association { + var ( + relationship = association.field.Relationship + scope = association.scope + field = association.field.Field + newDB = scope.NewDB() + ) + + if len(values) == 0 { + return association + } + + var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string + for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { + deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) + deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) + } + + deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) + + if relationship.Kind == "many_to_many" { + // source value's foreign keys + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + // get association's foreign fields name + var associationScope = scope.New(reflect.New(field.Type()).Interface()) + var associationForeignFieldNames []string + for _, associationDBName := range relationship.AssociationForeignFieldNames { + if field, ok := associationScope.FieldByName(associationDBName); ok { + associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + } + } + + // association value's foreign keys + deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) + newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) + + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + } else { + var foreignKeyMap = map[string]interface{}{} + for _, foreignKey := range relationship.ForeignDBNames { + foreignKeyMap[foreignKey] = nil + } + + if relationship.Kind == "belongs_to" { + // find with deleting relation's foreign keys + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) + newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + + // set foreign key to be null if there are some records affected + modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() + if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { + if results.RowsAffected > 0 { + scope.updatedAttrsWithValues(foreignKeyMap) + } + } else { + association.setErr(results.Error) + } + } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { + // find all relations + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + + // only include those deleting relations + newDB = newDB.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), + toQueryValues(deletingPrimaryKeys)..., + ) + + // set matched relation's foreign key to be null + fieldValue := reflect.New(association.field.Field.Type()).Interface() + association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) + } + } + + // Remove deleted records from source's field + if association.Error == nil { + if field.Kind() == reflect.Slice { + leftValues := reflect.Zero(field.Type()) + + for i := 0; i < field.Len(); i++ { + reflectValue := field.Index(i) + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] + var isDeleted = false + for _, pk := range deletingPrimaryKeys { + if equalAsString(primaryKey, pk) { + isDeleted = true + break + } + } + if !isDeleted { + leftValues = reflect.Append(leftValues, reflectValue) + } + } + + association.field.Set(leftValues) + } else if field.Kind() == reflect.Struct { + primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] + for _, pk := range deletingPrimaryKeys { + if equalAsString(primaryKey, pk) { + association.field.Set(reflect.Zero(field.Type())) + break + } + } + } + } + + return association } +// Clear remove relationship between source & current associations, won't delete those associations +func (association *Association) Clear() *Association { + return association.Replace() +} + +// Count return the count of current associations +func (association *Association) Count() int { + var ( + count = 0 + relationship = association.field.Relationship + scope = association.scope + fieldValue = association.field.Field.Interface() + query = scope.DB() + ) + + if relationship.Kind == "many_to_many" { + query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) + } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + } else if relationship.Kind == "belongs_to" { + primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) + query = query.Where( + fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), + toQueryValues(primaryKeys)..., + ) + } + + if relationship.PolymorphicType != "" { + query = query.Where( + fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), + scope.TableName(), + ) + } + + query.Model(fieldValue).Count(&count) + return count +} + +// saveAssociations save passed values as associations func (association *Association) saveAssociations(values ...interface{}) *Association { - scope := association.Scope - field := association.Field - relationship := association.Field.Relationship + var ( + scope = association.scope + field = association.field + relationship = field.Relationship + ) saveAssociation := func(reflectValue reflect.Value) { // value has to been pointer @@ -94,318 +351,9 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa return association } -func (association *Association) Append(values ...interface{}) *Association { - if relationship := association.Field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -func (association *Association) Replace(values ...interface{}) *Association { - var ( - relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.Field.Set(reflect.Zero(association.Field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null only when clearing value - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - - // Relations except new created - if len(values) > 0 { - var newPrimaryKeys [][]interface{} - var associationForeignFieldNames []string - - if relationship.Kind == "many_to_many" { - // If many to many relations, get it from foreign key - associationForeignFieldNames = relationship.AssociationForeignFieldNames - } else { - // If other relations, get real primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { - if field.IsPrimaryKey { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - } - - newPrimaryKeys = association.getPrimaryKeys(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } +func (association *Association) setErr(err error) *Association { + if err != nil { + association.Error = err } return association } - -func (association *Association) Delete(values ...interface{}) *Association { - var ( - relationship = association.Field.Relationship - scope = association.Scope - field = association.Field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).Fields() { - if field.IsPrimaryKey { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - } - - deletingPrimaryKeys := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // association value's foreign keys - deletingPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap, false) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.Field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from field - if association.Error == nil { - if association.Field.Field.Kind() == reflect.Slice { - leftValues := reflect.Zero(association.Field.Field.Type()) - - for i := 0; i < association.Field.Field.Len(); i++ { - reflectValue := association.Field.Field.Index(i) - primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var included = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.Field.Set(leftValues) - } else if association.Field.Field.Kind() == reflect.Struct { - primaryKey := association.getPrimaryKeys(deletingResourcePrimaryFieldNames, association.Field.Field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.Field.Set(reflect.Zero(association.Field.Field.Type())) - break - } - } - } - } - - return association -} - -func (association *Association) Clear() *Association { - return association.Replace() -} - -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.Field.Relationship - scope = association.Scope - fieldValue = association.Field.Field.Interface() - newScope = scope.New(fieldValue) - ) - - if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.DB(), association.Scope.Value).Model(fieldValue).Count(&count) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := scope.DB() - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), - field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) - } - query.Model(fieldValue).Count(&count) - } else if relationship.Kind == "belongs_to" { - query := scope.DB() - for idx, primaryKey := range relationship.AssociationForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(primaryKey)), - field.Field.Interface()) - } - } - query.Model(fieldValue).Count(&count) - } - - return count -} - -func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) (results [][]interface{}) { - scope := association.Scope - - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - primaryKeys := []interface{}{} - newScope := scope.New(reflectValue.Index(i).Interface()) - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - results = append(results, primaryKeys) - } - } else if reflectValue.Kind() == reflect.Struct { - newScope := scope.New(value) - var primaryKeys []interface{} - for _, column := range columns { - if field, ok := newScope.FieldByName(column); ok { - primaryKeys = append(primaryKeys, field.Field.Interface()) - } else { - primaryKeys = append(primaryKeys, "") - } - } - - results = append(results, primaryKeys) - } - } - - return -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for _, _ = range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } else { - return strings.Join(newColumns, ",") - } -} - -func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { - for _, primaryValue := range primaryValues { - for _, value := range primaryValue { - values = append(values, value) - } - } - return values -} diff --git a/association_test.go b/association_test.go index f02d4620..52d2303f 100644 --- a/association_test.go +++ b/association_test.go @@ -5,6 +5,8 @@ import ( "reflect" "sort" "testing" + + "github.com/jinzhu/gorm" ) func TestBelongsTo(t *testing.T) { @@ -16,7 +18,7 @@ func TestBelongsTo(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Error("Got errors when save post", err) } if post.Category.ID == 0 || post.MainCategory.ID == 0 { @@ -177,6 +179,49 @@ func TestBelongsTo(t *testing.T) { } } +func TestBelongsToOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestBelongsToOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` + ProfileID int + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "belongs_to" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + func TestHasOne(t *testing.T) { user := User{ Name: "has one", @@ -184,7 +229,7 @@ func TestHasOne(t *testing.T) { } if err := DB.Save(&user).Error; err != nil { - t.Errorf("Got errors when save user", err.Error()) + t.Error("Got errors when save user", err.Error()) } if user.CreditCard.UserId.Int64 == 0 { @@ -323,6 +368,49 @@ func TestHasOne(t *testing.T) { } } +func TestHasOneOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasOneOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_one" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + func TestHasMany(t *testing.T) { post := Post{ Title: "post has many", @@ -331,7 +419,7 @@ func TestHasMany(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post", err.Error()) + t.Error("Got errors when save post", err) } for _, comment := range post.Comments { @@ -462,6 +550,49 @@ func TestHasMany(t *testing.T) { } } +func TestHasManyOverrideForeignKey1(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + +func TestHasManyOverrideForeignKey2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` + } + + if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { + if relation.Relationship.Kind != "has_many" || + !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || + !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { + t.Errorf("Override belongs to foreign key with tag") + } + } +} + func TestManyToMany(t *testing.T) { DB.Raw("delete from languages") var languages = []Language{{Name: "ZH"}, {Name: "EN"}} diff --git a/callback.go b/callback.go index 603e5111..84526236 100644 --- a/callback.go +++ b/callback.go @@ -4,34 +4,39 @@ import ( "fmt" ) -type callback struct { +// DefaultCallback default callbacks defined by gorm +var DefaultCallback = &Callback{} + +// Callback is a struct that contains all CURD callbacks +// Field `creates` contains callbacks will be call when creating object +// Field `updates` contains callbacks will be call when updating object +// Field `deletes` contains callbacks will be call when deleting object +// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... +// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... +// Field `processors` contains all callback processors, will be used to generate above callbacks in order +type Callback struct { creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) queries []*func(scope *Scope) rowQueries []*func(scope *Scope) - processors []*callbackProcessor + processors []*CallbackProcessor } -type callbackProcessor struct { - name string - before string - after string - replace bool - remove bool - typ string - processor *func(scope *Scope) - callback *callback +// CallbackProcessor contains callback informations +type CallbackProcessor struct { + name string // current callback's name + before string // register current callback before a callback + after string // register current callback after a callback + replace bool // replace callbacks with same name + remove bool // delete callbacks with same name + kind string // callback type: create, update, delete, query, row_query + processor *func(scope *Scope) // callback handler + parent *Callback } -func (c *callback) addProcessor(typ string) *callbackProcessor { - cp := &callbackProcessor{typ: typ, callback: c} - c.processors = append(c.processors, cp) - return cp -} - -func (c *callback) clone() *callback { - return &callback{ +func (c *Callback) clone() *Callback { + return &Callback{ creates: c.creates, updates: c.updates, deletes: c.deletes, @@ -40,57 +45,95 @@ func (c *callback) clone() *callback { } } -func (c *callback) Create() *callbackProcessor { - return c.addProcessor("create") +// Create could be used to register callbacks for creating object +// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { +// // business logic +// ... +// +// // set error if some thing wrong happened, will rollback the creating +// scope.Err(errors.New("error")) +// }) +func (c *Callback) Create() *CallbackProcessor { + return &CallbackProcessor{kind: "create", parent: c} } -func (c *callback) Update() *callbackProcessor { - return c.addProcessor("update") +// Update could be used to register callbacks for updating object, refer `Create` for usage +func (c *Callback) Update() *CallbackProcessor { + return &CallbackProcessor{kind: "update", parent: c} } -func (c *callback) Delete() *callbackProcessor { - return c.addProcessor("delete") +// Delete could be used to register callbacks for deleting object, refer `Create` for usage +func (c *Callback) Delete() *CallbackProcessor { + return &CallbackProcessor{kind: "delete", parent: c} } -func (c *callback) Query() *callbackProcessor { - return c.addProcessor("query") +// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... +// Refer `Create` for usage +func (c *Callback) Query() *CallbackProcessor { + return &CallbackProcessor{kind: "query", parent: c} } -func (c *callback) RowQuery() *callbackProcessor { - return c.addProcessor("row_query") +// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage +func (c *Callback) RowQuery() *CallbackProcessor { + return &CallbackProcessor{kind: "row_query", parent: c} } -func (cp *callbackProcessor) Before(name string) *callbackProcessor { - cp.before = name +// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` +func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { + cp.after = callbackName return cp } -func (cp *callbackProcessor) After(name string) *callbackProcessor { - cp.after = name +// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` +func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { + cp.before = callbackName return cp } -func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) { - cp.name = name - cp.processor = &fc - cp.callback.sort() +// Register a new callback, refer `Callbacks.Create` +func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { + cp.name = callbackName + cp.processor = &callback + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() } -func (cp *callbackProcessor) Remove(name string) { - fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum()) - cp.name = name +// Remove a registered callback +// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") +func (cp *CallbackProcessor) Remove(callbackName string) { + fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.name = callbackName cp.remove = true - cp.callback.sort() + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() } -func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum()) - cp.name = name - cp.processor = &fc +// Replace a registered callback with new callback +// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { +// scope.SetColumn("Created", now) +// scope.SetColumn("Updated", now) +// }) +func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { + fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.name = callbackName + cp.processor = &callback cp.replace = true - cp.callback.sort() + cp.parent.processors = append(cp.parent.processors, cp) + cp.parent.reorder() } +// Get registered callback +// db.Callback().Create().Get("gorm:create") +func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { + for _, p := range cp.parent.processors { + if p.name == callbackName && p.kind == cp.kind && !cp.remove { + return *p.processor + } + } + return nil +} + +// getRIndex get right index from string slice func getRIndex(strs []string, str string) int { for i := len(strs) - 1; i >= 0; i-- { if strs[i] == str { @@ -100,93 +143,88 @@ func getRIndex(strs []string, str string) int { return -1 } -func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { - var sortCallbackProcessor func(c *callbackProcessor) - var names, sortedNames = []string{}, []string{} +// sortProcessors sort callback processors based on its before, after, remove, replace +func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { + var ( + allNames, sortedNames []string + sortCallbackProcessor func(c *CallbackProcessor) + ) for _, cp := range cps { - if index := getRIndex(names, cp.name); index > -1 { - if !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } + // show warning message the callback name already exists + if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { + fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) } - names = append(names, cp.name) + allNames = append(allNames, cp.name) } - sortCallbackProcessor = func(c *callbackProcessor) { - if getRIndex(sortedNames, c.name) > -1 { - return - } - - if len(c.before) > 0 { - if index := getRIndex(sortedNames, c.before); index > -1 { - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(names, c.before); index > -1 { - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } else { - sortedNames = append(sortedNames, c.name) - } - } - - if len(c.after) > 0 { - if index := getRIndex(sortedNames, c.after); index > -1 { - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(names, c.after); index > -1 { - cp := cps[index] - if len(cp.before) == 0 { - cp.before = c.name + sortCallbackProcessor = func(c *CallbackProcessor) { + if getRIndex(sortedNames, c.name) == -1 { // if not sorted + if c.before != "" { // if defined before callback + if index := getRIndex(sortedNames, c.before); index != -1 { + // if before callback already sorted, append current callback just after it + sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) + } else if index := getRIndex(allNames, c.before); index != -1 { + // if before callback exists but haven't sorted, append current callback to last + sortedNames = append(sortedNames, c.name) + sortCallbackProcessor(cps[index]) } - sortCallbackProcessor(cp) - } else { + } + + if c.after != "" { // if defined after callback + if index := getRIndex(sortedNames, c.after); index != -1 { + // if after callback already sorted, append current callback just before it + sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) + } else if index := getRIndex(allNames, c.after); index != -1 { + // if after callback exists but haven't sorted + cp := cps[index] + // set after callback's before callback to current callback + if cp.before == "" { + cp.before = c.name + } + sortCallbackProcessor(cp) + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sortedNames, c.name) == -1 { sortedNames = append(sortedNames, c.name) } } - - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } } for _, cp := range cps { sortCallbackProcessor(cp) } - var funcs = []*func(scope *Scope){} - var sortedFuncs = []*func(scope *Scope){} + var sortedFuncs []*func(scope *Scope) for _, name := range sortedNames { - index := getRIndex(names, name) - if !cps[index].remove { + if index := getRIndex(allNames, name); !cps[index].remove { sortedFuncs = append(sortedFuncs, cps[index].processor) } } - for _, cp := range cps { - if sindex := getRIndex(sortedNames, cp.name); sindex == -1 { - if !cp.remove { - funcs = append(funcs, cp.processor) - } - } - } - - return append(sortedFuncs, funcs...) + return sortedFuncs } -func (c *callback) sort() { - var creates, updates, deletes, queries, rowQueries []*callbackProcessor +// reorder all registered processors, and reset CURD callbacks +func (c *Callback) reorder() { + var creates, updates, deletes, queries, rowQueries []*CallbackProcessor for _, processor := range c.processors { - switch processor.typ { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) + if processor.name != "" { + switch processor.kind { + case "create": + creates = append(creates, processor) + case "update": + updates = append(updates, processor) + case "delete": + deletes = append(deletes, processor) + case "query": + queries = append(queries, processor) + case "row_query": + rowQueries = append(rowQueries, processor) + } } } @@ -196,5 +234,3 @@ func (c *callback) sort() { c.queries = sortProcessors(queries) c.rowQueries = sortProcessors(rowQueries) } - -var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_create.go b/callback_create.go index d13a71be..e3cd2f0b 100644 --- a/callback_create.go +++ b/callback_create.go @@ -5,12 +5,31 @@ import ( "strings" ) -func BeforeCreate(scope *Scope) { - scope.CallMethodWithErrorCheck("BeforeSave") - scope.CallMethodWithErrorCheck("BeforeCreate") +// Define callbacks for creating +func init() { + DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) + DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) + DefaultCallback.Create().Register("gorm:create", createCallback) + DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) + DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) + DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } -func UpdateTimeStampWhenCreate(scope *Scope) { +// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating +func beforeCreateCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("BeforeSave") + } + if !scope.HasError() { + scope.CallMethod("BeforeCreate") + } +} + +// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating +func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { now := NowFunc() scope.SetColumn("CreatedAt", now) @@ -18,109 +37,108 @@ func UpdateTimeStampWhenCreate(scope *Scope) { } } -func Create(scope *Scope) { - defer scope.Trace(NowFunc()) - +// createCallback the callback used to insert data into database +func createCallback(scope *Scope) { if !scope.HasError() { - // set create sql - var sqls, columns []string - fields := scope.Fields() - for _, field := range fields { + defer scope.trace(NowFunc()) + + var ( + columns, placeholders []string + blankColumnsWithDefaultValue []string + ) + + for _, field := range scope.Fields() { if scope.changeableField(field) { if field.IsNormal { - if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) { - if !field.IsBlank || !field.HasDefaultValue { + if !field.IsPrimaryKey || !field.IsBlank { + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else { columns = append(columns, scope.Quote(field.DBName)) - sqls = append(sqls, scope.AddToVars(field.Field.Interface())) - } else if field.HasDefaultValue { - var hasDefaultValueColumns []string - if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { - hasDefaultValueColumns = oldHasDefaultValueColumns.([]string) - } - hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName) - scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns) + placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) } } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, dbName := range relationship.ForeignDBNames { - if relationField := fields[dbName]; !scope.changeableField(relationField) { - columns = append(columns, scope.Quote(relationField.DBName)) - sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { + for _, foreignKey := range field.Relationship.ForeignDBNames { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { + columns = append(columns, scope.Quote(foreignField.DBName)) + placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) } } } } } - returningKey := "*" - primaryField := scope.PrimaryField() - if primaryField != nil { - returningKey = scope.Quote(primaryField.DBName) + var ( + returningColumn = "*" + quotedTableName = scope.QuotedTableName() + primaryField = scope.PrimaryField() + extraOption string + ) + + if str, ok := scope.Get("gorm:insert_option"); ok { + extraOption = fmt.Sprint(str) } + if primaryField != nil { + returningColumn = scope.Quote(primaryField.DBName) + } + + lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + if len(columns) == 0 { - scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v", - scope.QuotedTableName(), - scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), + scope.Raw(fmt.Sprintf( + "INSERT INTO %v DEFAULT VALUES%v%v", + quotedTableName, + addExtraSpaceIfExist(extraOption), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } else { scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v) %v", + "INSERT INTO %v (%v) VALUES (%v)%v%v", scope.QuotedTableName(), strings.Join(columns, ","), - strings.Join(sqls, ","), - scope.Dialect().ReturningStr(scope.QuotedTableName(), returningKey), + strings.Join(placeholders, ","), + addExtraSpaceIfExist(extraOption), + addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } // execute create sql - if scope.Dialect().SupportLastInsertId() { - if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - id, err := result.LastInsertId() - if scope.Err(err) == nil { - scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil && primaryField.IsBlank { - scope.Err(scope.SetColumn(primaryField, id)) + if lastInsertIDReturningSuffix == "" || primaryField == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) } } } } else { - if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { - scope.db.RowsAffected, _ = results.RowsAffected() - } else { - scope.Err(err) - } - } else { - if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { - scope.db.RowsAffected = 1 - } else { - scope.Err(err) - } + if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + scope.db.RowsAffected = 1 } } } } -func ForceReloadAfterCreate(scope *Scope) { - if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { - scope.DB().New().Select(columns.([]string)).First(scope.Value) +// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object +func forceReloadAfterCreateCallback(scope *Scope) { + if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { + scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value) } } -func AfterCreate(scope *Scope) { - scope.CallMethodWithErrorCheck("AfterCreate") - scope.CallMethodWithErrorCheck("AfterSave") -} - -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) - DefaultCallback.Create().Register("gorm:before_create", BeforeCreate) - DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) - DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) - DefaultCallback.Create().Register("gorm:create", Create) - DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) - DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) - DefaultCallback.Create().Register("gorm:after_create", AfterCreate) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) +// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating +func afterCreateCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterCreate") + } + if !scope.HasError() { + scope.CallMethod("AfterSave") + } } diff --git a/callback_delete.go b/callback_delete.go index 72236659..c8ffcc82 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -2,35 +2,52 @@ package gorm import "fmt" -func BeforeDelete(scope *Scope) { - scope.CallMethodWithErrorCheck("BeforeDelete") +// Define callbacks for deleting +func init() { + DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) + DefaultCallback.Delete().Register("gorm:delete", deleteCallback) + DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) + DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } -func Delete(scope *Scope) { +// beforeDeleteCallback will invoke `BeforeDelete` method before deleting +func beforeDeleteCallback(scope *Scope) { if !scope.HasError() { - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { - scope.Raw( - fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", - scope.QuotedTableName(), - scope.AddToVars(NowFunc()), - scope.CombinedConditionSql(), - )) - } else { - scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql())) - } - - scope.Exec() + scope.CallMethod("BeforeDelete") } } -func AfterDelete(scope *Scope) { - scope.CallMethodWithErrorCheck("AfterDelete") +// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) +func deleteCallback(scope *Scope) { + if !scope.HasError() { + var extraOption string + if str, ok := scope.Get("gorm:delete_option"); ok { + extraOption = fmt.Sprint(str) + } + + if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET deleted_at=%v%v%v", + scope.QuotedTableName(), + scope.AddToVars(NowFunc()), + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() + } else { + scope.Raw(fmt.Sprintf( + "DELETE FROM %v%v%v", + scope.QuotedTableName(), + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() + } + } } -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) - DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) - DefaultCallback.Delete().Register("gorm:delete", Delete) - DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) +// afterDeleteCallback will invoke `AfterDelete` method after deleting +func afterDeleteCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterDelete") + } } diff --git a/callback_query.go b/callback_query.go index b80f5be5..08678f92 100644 --- a/callback_query.go +++ b/callback_query.go @@ -6,115 +6,89 @@ import ( "reflect" ) -func Query(scope *Scope) { - defer scope.Trace(NowFunc()) +// Define callbacks for querying +func init() { + DefaultCallback.Query().Register("gorm:query", queryCallback) + DefaultCallback.Query().Register("gorm:preload", preloadCallback) + DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) +} + +// queryCallback used to query data from database +func queryCallback(scope *Scope) { + defer scope.trace(NowFunc()) var ( - isSlice bool - isPtr bool - anyRecordFound bool - destType reflect.Type + isSlice bool + isPtr bool + results = scope.IndirectValue() + resultType reflect.Type ) if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryKey := scope.PrimaryKey(); primaryKey != "" { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) + if primaryField := scope.PrimaryField(); primaryField != nil { + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) } } - var dest = scope.IndirectValue() if value, ok := scope.Get("gorm:query_destination"); ok { - dest = reflect.Indirect(reflect.ValueOf(value)) + results = reflect.Indirect(reflect.ValueOf(value)) } - if kind := dest.Kind(); kind == reflect.Slice { + if kind := results.Kind(); kind == reflect.Slice { isSlice = true - destType = dest.Type().Elem() - dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + resultType = results.Type().Elem() + results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - if destType.Kind() == reflect.Ptr { + if resultType.Kind() == reflect.Ptr { isPtr = true - destType = destType.Elem() + resultType = resultType.Elem() } } else if kind != reflect.Struct { scope.Err(errors.New("unsupported destination, should be slice or struct")) return } - scope.prepareQuerySql() + scope.prepareQuerySQL() if !scope.HasError() { - rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...) scope.db.RowsAffected = 0 - - if scope.Err(err) != nil { - return + if str, ok := scope.Get("gorm:query_option"); ok { + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } - defer rows.Close() - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ + if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + defer rows.Close() - anyRecordFound = true - elem := dest - if isSlice { - elem = reflect.New(destType).Elem() - } + columns, _ := rows.Columns() + for rows.Next() { + scope.db.RowsAffected++ - var values = make([]interface{}, len(columns)) + elem := results + if isSlice { + elem = reflect.New(resultType).Elem() + } - fields := scope.New(elem.Addr().Interface()).Fields() + scope.scan(rows, columns, scope.New(elem.Addr().Interface()).fieldsMap()) - for index, column := range columns { - if field, ok := fields[column]; ok { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() + if isSlice { + if isPtr { + results.Set(reflect.Append(results, elem.Addr())) } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - } - } else { - var value interface{} - values[index] = &value - } - } - - scope.Err(rows.Scan(values...)) - - for index, column := range columns { - value := values[index] - if field, ok := fields[column]; ok { - if field.Field.Kind() == reflect.Ptr { - field.Field.Set(reflect.ValueOf(value).Elem()) - } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { - field.Field.Set(v) + results.Set(reflect.Append(results, elem)) } } } - if isSlice { - if isPtr { - dest.Set(reflect.Append(dest, elem.Addr())) - } else { - dest.Set(reflect.Append(dest, elem)) - } + if scope.db.RowsAffected == 0 && !isSlice { + scope.Err(ErrRecordNotFound) } } - - if !anyRecordFound && !isSlice { - scope.Err(RecordNotFound) - } } } -func AfterQuery(scope *Scope) { - scope.CallMethodWithErrorCheck("AfterFind") -} - -func init() { - DefaultCallback.Query().Register("gorm:query", Query) - DefaultCallback.Query().Register("gorm:preload", Preload) - DefaultCallback.Query().Register("gorm:after_query", AfterQuery) +// afterQueryCallback will invoke `AfterFind` method after querying +func afterQueryCallback(scope *Scope) { + if !scope.HasError() { + scope.CallMethod("AfterFind") + } } diff --git a/callback_query_preload.go b/callback_query_preload.go new file mode 100644 index 00000000..1c9bbc84 --- /dev/null +++ b/callback_query_preload.go @@ -0,0 +1,308 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// preloadCallback used to preload associations +func preloadCallback(scope *Scope) { + if scope.Search.preload == nil || scope.HasError() { + return + } + + var ( + preloadedMap = map[string]bool{} + fields = scope.Fields() + ) + + for _, preload := range scope.Search.preload { + var ( + preloadFields = strings.Split(preload.schema, ".") + currentScope = scope + currentFields = fields + ) + + for idx, preloadField := range preloadFields { + var currentPreloadConditions []interface{} + + // if not preloaded + if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { + + // assign search conditions to last preload + if idx == len(preloadFields)-1 { + currentPreloadConditions = preload.conditions + } + + for _, field := range currentFields { + if field.Name != preloadField || field.Relationship == nil { + continue + } + + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, currentPreloadConditions) + case "has_many": + currentScope.handleHasManyPreload(field, currentPreloadConditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, currentPreloadConditions) + case "many_to_many": + currentScope.handleManyToManyPreload(field, currentPreloadConditions) + default: + scope.Err(errors.New("unsupported relation")) + } + + preloadedMap[preloadKey] = true + break + } + + if !preloadedMap[preloadKey] { + scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) + return + } + } + + // preload next level + if idx < len(preloadFields)-1 { + currentScope = currentScope.getColumnAsScope(preloadField) + currentFields = currentScope.Fields() + } + } + } +} + +func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { + var ( + preloadDB = scope.NewDB() + preloadConditions []interface{} + ) + + for _, condition := range conditions { + if scopes, ok := condition.(func(*DB) *DB); ok { + preloadDB = scopes(preloadDB) + } else { + preloadConditions = append(preloadConditions, condition) + } + } + + return preloadDB, preloadConditions +} + +// handleHasOnePreload used to preload has one associations +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { + indirectValue.FieldByName(field.Name).Set(result) + break + } + } + } else { + scope.Err(field.Set(result)) + } + } +} + +// handleHasManyPreload used to preload has many associations +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + if indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { + objectField := object.FieldByName(field.Name) + objectField.Set(reflect.Append(objectField, result)) + break + } + } + } + } else { + scope.Err(field.Set(resultsValue)) + } +} + +// handleBelongsToPreload used to preload belongs to associations +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // get relations's primary keys + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) + if len(primaryKeys) == 0 { + return + } + + // find relations + results := makeSlice(field.Struct.Type) + scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + + // assign find results + var ( + resultsValue = indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { + value := getValueFromFields(result, relation.AssociationForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.Err(field.Set(result)) + } + } +} + +// handleManyToManyPreload used to preload many to many associations +func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { + var ( + relation = field.Relationship + joinTableHandler = relation.JoinTableHandler + fieldType = field.Struct.Type.Elem() + foreignKeyValue interface{} + foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() + linkHash = map[string][]reflect.Value{} + isPtr bool + ) + + if fieldType.Kind() == reflect.Ptr { + isPtr = true + fieldType = fieldType.Elem() + } + + var sourceKeys = []string{} + for _, key := range joinTableHandler.SourceForeignKeys() { + sourceKeys = append(sourceKeys, key.DBName) + } + + // preload conditions + preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) + + // generate query with join table + newScope := scope.New(reflect.New(fieldType).Interface()) + preloadDB = preloadDB.Table(newScope.TableName()).Select("*") + preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) + + // preload inline conditions + if len(preloadConditions) > 0 { + preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) + } + + rows, err := preloadDB.Rows() + + if scope.Err(err) != nil { + return + } + defer rows.Close() + + columns, _ := rows.Columns() + for rows.Next() { + var ( + elem = reflect.New(fieldType).Elem() + fields = scope.New(elem.Addr().Interface()).fieldsMap() + ) + + // register foreign keys in join tables + for _, sourceKey := range sourceKeys { + fields[sourceKey] = &Field{Field: reflect.New(foreignKeyType).Elem()} + } + + scope.scan(rows, columns, fields) + + // generate hashed forkey keys in join table + var foreignKeys = make([]interface{}, len(sourceKeys)) + for idx, sourceKey := range sourceKeys { + foreignKeys[idx] = fields[sourceKey].Field.Elem().Interface() + } + hashedSourceKeys := toString(foreignKeys) + + if isPtr { + linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) + } else { + linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) + } + } + + // assign find results + var ( + indirectScopeValue = scope.IndirectValue() + fieldsSourceMap = map[string]reflect.Value{} + foreignFieldNames = []string{} + fields = scope.fieldsMap() + ) + + for _, dbName := range relation.ForeignFieldNames { + if field, ok := fields[dbName]; ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } + } + + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) + } + } else if indirectScopeValue.IsValid() { + fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + } + + for source, link := range linkHash { + fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) + } +} diff --git a/callback_shared.go b/callback_save.go similarity index 93% rename from callback_shared.go rename to callback_save.go index 547059e3..a525b709 100644 --- a/callback_shared.go +++ b/callback_save.go @@ -2,15 +2,15 @@ package gorm import "reflect" -func BeginTransaction(scope *Scope) { +func beginTransactionCallback(scope *Scope) { scope.Begin() } -func CommitOrRollbackTransaction(scope *Scope) { +func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func SaveBeforeAssociations(scope *Scope) { +func saveBeforeAssociationsCallback(scope *Scope) { if !scope.shouldSaveAssociations() { return } @@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) { } } -func SaveAfterAssociations(scope *Scope) { +func saveAfterAssociationsCallback(scope *Scope) { if !scope.shouldSaveAssociations() { return } diff --git a/callback_test.go b/callback_test.go index b416d6af..13ca3f42 100644 --- a/callback_test.go +++ b/callback_test.go @@ -23,7 +23,7 @@ func afterCreate1(s *Scope) {} func afterCreate2(s *Scope) {} func TestRegisterCallback(t *testing.T) { - var callback = &callback{processors: []*callbackProcessor{}} + var callback = &Callback{} callback.Create().Register("before_create1", beforeCreate1) callback.Create().Register("before_create2", beforeCreate2) @@ -37,7 +37,7 @@ func TestRegisterCallback(t *testing.T) { } func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &callback{processors: []*callbackProcessor{}} + var callback1 = &Callback{} callback1.Create().Register("before_create1", beforeCreate1) callback1.Create().Register("create", create) callback1.Create().Register("after_create1", afterCreate1) @@ -46,7 +46,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) { t.Errorf("register callback with order") } - var callback2 = &callback{processors: []*callbackProcessor{}} + var callback2 = &Callback{} callback2.Update().Register("create", create) callback2.Update().Before("create").Register("before_create1", beforeCreate1) @@ -60,7 +60,7 @@ func TestRegisterCallbackWithOrder(t *testing.T) { } func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &callback{processors: []*callbackProcessor{}} + var callback1 = &Callback{} callback1.Query().Before("after_create1").After("before_create1").Register("create", create) callback1.Query().Register("before_create1", beforeCreate1) @@ -70,7 +70,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { t.Errorf("register callback with order") } - var callback2 = &callback{processors: []*callbackProcessor{}} + var callback2 = &Callback{} callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) callback2.Delete().Before("create").Register("before_create1", beforeCreate1) @@ -86,7 +86,7 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { func replaceCreate(s *Scope) {} func TestReplaceCallback(t *testing.T) { - var callback = &callback{processors: []*callbackProcessor{}} + var callback = &Callback{} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) @@ -99,7 +99,7 @@ func TestReplaceCallback(t *testing.T) { } func TestRemoveCallback(t *testing.T) { - var callback = &callback{processors: []*callbackProcessor{}} + var callback = &Callback{} callback.Create().Before("after_create1").After("before_create1").Register("create", create) callback.Create().Register("before_create1", beforeCreate1) diff --git a/callback_update.go b/callback_update.go index 4c9952d2..192d8a9e 100644 --- a/callback_update.go +++ b/callback_update.go @@ -5,91 +5,102 @@ import ( "strings" ) -func AssignUpdateAttributes(scope *Scope) { +// Define callbacks for updating +func init() { + DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) + DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) + DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) + DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) + DefaultCallback.Update().Register("gorm:update", updateCallback) + DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) + DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) +} + +// assignUpdatingAttributesCallback assign updating attributes to model +func assignUpdatingAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if maps := convertInterfaceToMap(attrs); len(maps) > 0 { - protected, ok := scope.Get("gorm:ignore_protected_attrs") - _, updateColumn := scope.Get("gorm:update_column") - updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool)) - - if updateColumn { - scope.InstanceSet("gorm:update_attrs", maps) - } else if len(updateAttrs) > 0 { - scope.InstanceSet("gorm:update_attrs", updateAttrs) - } else if !hasUpdate { + if updateMaps, hasUpdate := scope.updatedAttrsWithValues(maps); hasUpdate { + scope.InstanceSet("gorm:update_attrs", updateMaps) + } else { scope.SkipLeft() - return } } } } -func BeforeUpdate(scope *Scope) { +// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating +func beforeUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.CallMethodWithErrorCheck("BeforeSave") - scope.CallMethodWithErrorCheck("BeforeUpdate") + if !scope.HasError() { + scope.CallMethod("BeforeSave") + } + if !scope.HasError() { + scope.CallMethod("BeforeUpdate") + } } } -func UpdateTimeStampWhenUpdate(scope *Scope) { +// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating +func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", NowFunc()) } } -func Update(scope *Scope) { +// updateCallback the callback used to update data to database +func updateCallback(scope *Scope) { if !scope.HasError() { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for key, value := range updateAttrs.(map[string]interface{}) { - if scope.changeableDBColumn(key) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value))) - } + for column, value := range updateAttrs.(map[string]interface{}) { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { - fields := scope.Fields() - for _, field := range fields { - if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, dbName := range relationship.ForeignDBNames { - if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { - sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) - sqls = append(sqls, sql) + for _, field := range scope.Fields() { + if scope.changeableField(field) { + if !field.IsPrimaryKey && field.IsNormal { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { + for _, foreignKey := range relationship.ForeignDBNames { + if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { + sqls = append(sqls, + fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) + } } } } } } + var extraOption string + if str, ok := scope.Get("gorm:update_option"); ok { + extraOption = fmt.Sprint(str) + } + if len(sqls) > 0 { scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v %v", + "UPDATE %v SET %v%v%v", scope.QuotedTableName(), strings.Join(sqls, ", "), - scope.CombinedConditionSql(), - )) - scope.Exec() + addExtraSpaceIfExist(scope.CombinedConditionSql()), + addExtraSpaceIfExist(extraOption), + )).Exec() } } } -func AfterUpdate(scope *Scope) { +// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating +func afterUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.CallMethodWithErrorCheck("AfterUpdate") - scope.CallMethodWithErrorCheck("AfterSave") + if !scope.HasError() { + scope.CallMethod("AfterUpdate") + } + if !scope.HasError() { + scope.CallMethod("AfterSave") + } } } - -func init() { - DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) - DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) - DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate) - DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) - DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) - DefaultCallback.Update().Register("gorm:update", Update) - DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) - DefaultCallback.Update().Register("gorm:after_update", AfterUpdate) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) -} diff --git a/common_dialect.go b/common_dialect.go deleted file mode 100644 index 7f08b04f..00000000 --- a/common_dialect.go +++ /dev/null @@ -1,117 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "time" -) - -type commonDialect struct{} - -func (commonDialect) BinVar(i int) string { - return "$$" // ? -} - -func (commonDialect) SupportLastInsertId() bool { - return true -} - -func (commonDialect) HasTop() bool { - return false -} - -func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "INTEGER AUTO_INCREMENT" - } - return "INTEGER" - case reflect.Int64, reflect.Uint64: - if autoIncrease { - return "BIGINT AUTO_INCREMENT" - } - return "BIGINT" - case reflect.Float32, reflect.Float64: - return "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("VARCHAR(%d)", size) - } - return "VARCHAR(65532)" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "TIMESTAMP" - } - default: - if _, ok := value.Interface().([]byte); ok { - if size > 0 && size < 65532 { - return fmt.Sprintf("BINARY(%d)", size) - } - return "BINARY(65532)" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) -} - -func (commonDialect) ReturningStr(tableName, key string) string { - return "" -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (c commonDialect) HasTable(scope *Scope, tableName string) bool { - var ( - count int - databaseName = c.CurrentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName) - return count > 0 -} - -func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { - var ( - count int - databaseName = c.CurrentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) - return count > 0 -} - -func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { - var ( - count int - databaseName = c.CurrentDatabase(scope) - ) - c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName) - return count > 0 -} - -func (commonDialect) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error) -} - -// RawScanInt scans the first column of the first row into the `scan' int pointer. -// This function captures raw query errors and propagates them to the original scope. -func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) { - scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) -} - -// RawScanString scans the first column of the first row into the `scan' string pointer. -// This function captures raw query errors and propagates them to the original scope. -func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) { - scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr)) -} - -func (commonDialect) CurrentDatabase(scope *Scope) (name string) { - scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name)) - return -} diff --git a/customize_column_test.go b/customize_column_test.go index 93bab2e1..177b4a5d 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) { DB.AutoMigrate(&CustomizeColumn{}) scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope, scope.TableName(), col) { + if !scope.Dialect().HasColumn(scope.TableName(), col) { t.Errorf("CustomizeColumn should have column %s", col) } diff --git a/ddl_errors_test.go b/ddl_errors_test.go index aca59553..2c31b354 100644 --- a/ddl_errors_test.go +++ b/ddl_errors_test.go @@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) { } }() - DB.HasTable("foobarbaz") - if DB.Error == nil { + if err := DB.Find(&User{}).Error; err == nil { t.Errorf("Expected operation on closed db to produce an error, but err was nil") } } diff --git a/delete_test.go b/delete_test.go index e0c71660..d3de0a6d 100644 --- a/delete_test.go +++ b/delete_test.go @@ -45,7 +45,7 @@ func TestSoftDelete(t *testing.T) { type User struct { Id int64 Name string - DeletedAt time.Time + DeletedAt *time.Time } DB.AutoMigrate(&User{}) diff --git a/dialect.go b/dialect.go index 926f8a11..6c9405da 100644 --- a/dialect.go +++ b/dialect.go @@ -1,41 +1,100 @@ package gorm import ( + "database/sql" "fmt" "reflect" + "strconv" + "strings" ) +// Dialect interface contains behaviors that differ across SQL database type Dialect interface { - BinVar(i int) string - SupportLastInsertId() bool - HasTop() bool - SqlTag(value reflect.Value, size int, autoIncrease bool) string - ReturningStr(tableName, key string) string - SelectFromDummyTable() string + // GetName get dialect's name + GetName() string + + // SetDB set db for dialect + SetDB(db *sql.DB) + + // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 + BindVar(i int) string + // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string - HasTable(scope *Scope, tableName string) bool - HasColumn(scope *Scope, tableName string, columnName string) bool - HasIndex(scope *Scope, tableName string, indexName string) bool - RemoveIndex(scope *Scope, indexName string) - CurrentDatabase(scope *Scope) string + // DataTypeOf return data's sql type + DataTypeOf(field *StructField) string + + // HasIndex check has index or not + HasIndex(tableName string, indexName string) bool + // HasForeignKey check has foreign key or not + HasForeignKey(tableName string, foreignKeyName string) bool + // RemoveIndex remove index + RemoveIndex(tableName string, indexName string) error + // HasTable check has table or not + HasTable(tableName string) bool + // HasColumn check has column or not + HasColumn(tableName string, columnName string) bool + + // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case + LimitAndOffsetSQL(limit, offset int) string + // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` + SelectFromDummyTable() string + // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` + LastInsertIDReturningSuffix(tableName, columnName string) string } -func NewDialect(driver string) Dialect { - var d Dialect - switch driver { - case "postgres": - d = &postgres{} - case "foundation": - d = &foundation{} - case "mysql": - d = &mysql{} - case "sqlite3": - d = &sqlite3{} - case "mssql": - d = &mssql{} - default: - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver) - d = &commonDialect{} +var dialectsMap = map[string]Dialect{} + +func newDialect(name string, db *sql.DB) Dialect { + if value, ok := dialectsMap[name]; ok { + dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) + dialect.SetDB(db) + return dialect } - return d + + fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) + commontDialect := &commonDialect{} + commontDialect.SetDB(db) + return commontDialect +} + +// RegisterDialect register new dialect +func RegisterDialect(name string, dialect Dialect) { + dialectsMap[name] = dialect +} + +// ParseFieldStructForDialect parse field struct for dialect +func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { + // Get redirected field type + var reflectType = field.Struct.Type + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + + // Get redirected field value + fieldValue = reflect.Indirect(reflect.New(reflectType)) + + // Get scanner's real value + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } + } + getScannerValue(fieldValue) + + // Default Size + if num, ok := field.TagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } else { + size = 255 + } + + // Default type from tag setting + additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { + additionalType = additionalType + " DEFAULT " + value + } + + return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) } diff --git a/dialect_common.go b/dialect_common.go new file mode 100644 index 00000000..f009271b --- /dev/null +++ b/dialect_common.go @@ -0,0 +1,137 @@ +package gorm + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +type commonDialect struct { + db *sql.DB +} + +func init() { + RegisterDialect("common", &commonDialect{}) +} + +func (commonDialect) GetName() string { + return "common" +} + +func (s *commonDialect) SetDB(db *sql.DB) { + s.db = db +} + +func (commonDialect) BindVar(i int) string { + return "$$" // ? +} + +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) +} + +func (commonDialect) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "BOOLEAN" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + sqlType = "INTEGER AUTO_INCREMENT" + } else { + sqlType = "INTEGER" + } + case reflect.Int64, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + sqlType = "BIGINT AUTO_INCREMENT" + } else { + sqlType = "BIGINT" + } + case reflect.Float32, reflect.Float64: + sqlType = "FLOAT" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("VARCHAR(%d)", size) + } else { + sqlType = "VARCHAR(65532)" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "TIMESTAMP" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("BINARY(%d)", size) + } else { + sqlType = "BINARY(65532)" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s commonDialect) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) + return count > 0 +} + +func (s commonDialect) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) + return err +} + +func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { + return false +} + +func (s commonDialect) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) + return count > 0 +} + +func (s commonDialect) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s commonDialect) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) + return +} + +func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { + if limit > 0 || offset > 0 { + if limit >= 0 { + sql += fmt.Sprintf(" LIMIT %d", limit) + } + if offset >= 0 { + sql += fmt.Sprintf(" OFFSET %d", offset) + } + } + return +} + +func (commonDialect) SelectFromDummyTable() string { + return "" +} + +func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} diff --git a/dialect_mysql.go b/dialect_mysql.go new file mode 100644 index 00000000..a6946d5c --- /dev/null +++ b/dialect_mysql.go @@ -0,0 +1,113 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type mysql struct { + commonDialect +} + +func init() { + RegisterDialect("mysql", &mysql{}) +} + +func (mysql) GetName() string { + return "mysql" +} + +func (mysql) Quote(key string) string { + return fmt.Sprintf("`%s`", key) +} + +// Get Data Type for MySQL Dialect +func (mysql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int AUTO_INCREMENT" + } else { + sqlType = "int" + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int unsigned AUTO_INCREMENT" + } else { + sqlType = "int unsigned" + } + case reflect.Int64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint AUTO_INCREMENT" + } else { + sqlType = "bigint" + } + case reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint unsigned AUTO_INCREMENT" + } else { + sqlType = "bigint unsigned" + } + case reflect.Float32, reflect.Float64: + sqlType = "double" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "longtext" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + if _, ok := field.TagSettings["NOT NULL"]; ok { + sqlType = "timestamp" + } else { + sqlType = "timestamp NULL" + } + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varbinary(%d)", size) + } else { + sqlType = "longblob" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s mysql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), foreignKeyName).Scan(&count) + return count > 0 +} + +func (s mysql) currentDatabase() (name string) { + s.db.QueryRow("SELECT DATABASE()").Scan(&name) + return +} + +func (mysql) SelectFromDummyTable() string { + return "FROM DUAL" +} diff --git a/dialect_postgres.go b/dialect_postgres.go new file mode 100644 index 00000000..18cbe028 --- /dev/null +++ b/dialect_postgres.go @@ -0,0 +1,128 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type postgres struct { + commonDialect +} + +func init() { + RegisterDialect("postgres", &postgres{}) +} + +func (postgres) GetName() string { + return "postgres" +} + +func (postgres) BindVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (postgres) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "serial" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigserial" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "numeric" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "timestamp with time zone" + } + case reflect.Map: + if dataValue.Type().Name() == "Hstore" { + sqlType = "hstore" + } + default: + if isByteArrayOrSlice(dataValue) { + sqlType = "bytea" + } else if isUUID(dataValue) { + sqlType = "uuid" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s postgres) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + return count > 0 +} + +func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { + var count int + s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count) + return count > 0 +} + +func (s postgres) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + return count > 0 +} + +func (s postgres) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + return count > 0 +} + +func (s postgres) currentDatabase() (name string) { + s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) + return +} + +func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", tableName, key) +} + +func (postgres) SupportLastInsertID() bool { + return false +} + +func isByteArrayOrSlice(value reflect.Value) bool { + return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) +} + +func isUUID(value reflect.Value) bool { + if value.Kind() != reflect.Array || value.Type().Len() != 16 { + return false + } + typename := value.Type().Name() + lower := strings.ToLower(typename) + return "uuid" == lower || "guid" == lower +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go new file mode 100644 index 00000000..5c262aaf --- /dev/null +++ b/dialect_sqlite3.go @@ -0,0 +1,106 @@ +package gorm + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type sqlite3 struct { + commonDialect +} + +func init() { + RegisterDialect("sqlite", &sqlite3{}) + RegisterDialect("sqlite3", &sqlite3{}) +} + +func (sqlite3) GetName() string { + return "sqlite3" +} + +// Get Data Type for Sqlite Dialect +func (sqlite3) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "real" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetime" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + sqlType = "blob" + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s sqlite3) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) + return count > 0 +} + +func (s sqlite3) currentDatabase() (name string) { + var ( + ifaces = make([]interface{}, 3) + pointers = make([]*string, 3) + i int + ) + for i = 0; i < 3; i++ { + ifaces[i] = &pointers[i] + } + if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { + return + } + if pointers[1] != nil { + name = *pointers[1] + } + return +} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go new file mode 100644 index 00000000..5b994f9d --- /dev/null +++ b/dialects/mssql/mssql.go @@ -0,0 +1,150 @@ +package mssql + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "time" + + _ "github.com/denisenkom/go-mssqldb" + "github.com/jinzhu/gorm" +) + +func setIdentityInsert(scope *gorm.Scope) { + if scope.Dialect().GetName() == "mssql" { + scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) + } +} + +func init() { + gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) + gorm.RegisterDialect("mssql", &mssql{}) +} + +type mssql struct { + db *sql.DB +} + +func (mssql) GetName() string { + return "mssql" +} + +func (s *mssql) SetDB(db *sql.DB) { + s.db = db +} + +func (mssql) BindVar(i int) string { + return "$$" // ? +} + +func (mssql) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) +} + +func (mssql) DataTypeOf(field *gorm.StructField) string { + var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bit" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int IDENTITY(1,1)" + } else { + sqlType = "int" + } + case reflect.Int64, reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint IDENTITY(1,1)" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "float" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("nvarchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetime2" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + } + } + } + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) +} + +func (s mssql) HasIndex(tableName string, indexName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) + return count > 0 +} + +func (s mssql) RemoveIndex(tableName string, indexName string) error { + _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) + return err +} + +func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { + return false +} + +func (s mssql) HasTable(tableName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count) + return count > 0 +} + +func (s mssql) HasColumn(tableName string, columnName string) bool { + var count int + s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + return count > 0 +} + +func (s mssql) currentDatabase() (name string) { + s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) + return +} + +func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) { + if limit > 0 || offset > 0 { + if offset < 0 { + offset = 0 + } + + sql += fmt.Sprintf(" OFFSET %d ROWS", offset) + + if limit >= 0 { + sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit) + } + } + return +} + +func (mssql) SelectFromDummyTable() string { + return "" +} + +func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { + return "" +} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go new file mode 100644 index 00000000..9deba48a --- /dev/null +++ b/dialects/mysql/mysql.go @@ -0,0 +1,3 @@ +package mysql + +import _ "github.com/go-sql-driver/mysql" diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go new file mode 100644 index 00000000..37881090 --- /dev/null +++ b/dialects/postgres/postgres.go @@ -0,0 +1,52 @@ +package postgres + +import ( + "database/sql" + "database/sql/driver" + + _ "github.com/lib/pq" + "github.com/lib/pq/hstore" +) + +type Hstore map[string]*string + +func (h Hstore) Value() (driver.Value, error) { + hstore := hstore.Hstore{Map: map[string]sql.NullString{}} + if len(h) == 0 { + return nil, nil + } + + for key, value := range h { + var s sql.NullString + if value != nil { + s.String = *value + s.Valid = true + } + hstore.Map[key] = s + } + return hstore.Value() +} + +func (h *Hstore) Scan(value interface{}) error { + hstore := hstore.Hstore{} + + if err := hstore.Scan(value); err != nil { + return err + } + + if len(hstore.Map) == 0 { + return nil + } + + *h = Hstore{} + for k := range hstore.Map { + if hstore.Map[k].Valid { + s := hstore.Map[k].String + (*h)[k] = &s + } else { + (*h)[k] = nil + } + } + + return nil +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go new file mode 100644 index 00000000..069ad3a9 --- /dev/null +++ b/dialects/sqlite/sqlite.go @@ -0,0 +1,3 @@ +package sqlite + +import _ "github.com/mattn/go-sqlite3" diff --git a/doc/development.md b/doc/development.md deleted file mode 100644 index 08166661..00000000 --- a/doc/development.md +++ /dev/null @@ -1,68 +0,0 @@ -# Gorm Development - -## Architecture - -The most notable component of Gorm is`gorm.DB`, which hold database connection. It could be initialized like this: - - db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - -Gorm has chainable API, `gorm.DB` is the bridge of chains, it save related information and pass it to the next chain. - -Lets use below code to explain how it works: - - db.Where("name = ?", "jinzhu").Find(&users) - - // equivalent code - newdb := db.Where("name =?", "jinzhu") - newdb.Find(&user) - -`newdb` is `db`'s clone, in addition, it contains search conditions from the `Where` method. -`Find` is a query method, it creates a `Scope` instance, and pass it as argument to query callbacks. - -There are four kinds of callbacks corresponds to sql's CURD: create callbacks, update callbacks, query callbacks, delete callbacks. - -## Callbacks - -### Register a new callback - - func updateCreated(scope *Scope) { - if scope.HasColumn("Created") { - scope.SetColumn("Created", NowFunc()) - } - } - - db.Callback().Create().Register("update_created_at", updateCreated) - // register a callback for Create process - -### Delete an existing callback - - db.Callback().Create().Remove("gorm:create") - // delete callback `gorm:create` from Create callbacks - -### Replace an existing callback - - db.Callback().Create().Replace("gorm:create", newCreateFunction) - // replace callback `gorm:create` with new function `newCreateFunction` for Create process - -### Register callback orders - - db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated) - db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated) - db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery) - db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete) - db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate) - db.Callback().Create().Before("gorm:create").After("gorm:before_create").Register("my_plugin:before_create", beforeCreate) - -### Callback API - -Gorm is powered by callbacks, so you could refer below links to learn how to write callbacks - -[Create callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) - -[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) - -[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) - -[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) - -View [https://github.com/jinzhu/gorm/blob/master/scope.go](https://github.com/jinzhu/gorm/blob/master/scope.go) for all available API diff --git a/errors.go b/errors.go index 9dfcd2e9..cc66567d 100644 --- a/errors.go +++ b/errors.go @@ -6,25 +6,31 @@ import ( ) var ( - RecordNotFound = errors.New("record not found") - InvalidSql = errors.New("invalid sql") - NoNewAttrs = errors.New("no new attributes") - NoValidTransaction = errors.New("no valid transaction") - CantStartTransaction = errors.New("can't start transaction") + // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + ErrRecordNotFound = errors.New("record not found") + // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + ErrInvalidSQL = errors.New("invalid SQL") + // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + ErrInvalidTransaction = errors.New("no valid transaction") + // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` + ErrCantStartTransaction = errors.New("can't start transaction") ) type errorsInterface interface { GetErrors() []error } +// Errors contains all happened errors type Errors struct { errors []error } +// GetErrors get all happened errors func (errs Errors) GetErrors() []error { return errs.errors } +// Add add an error func (errs *Errors) Add(err error) { if errors, ok := err.(errorsInterface); ok { for _, err := range errors.GetErrors() { @@ -40,6 +46,7 @@ func (errs *Errors) Add(err error) { } } +// Error format happened errors func (errs Errors) Error() string { var errors = []string{} for _, e := range errs.errors { diff --git a/field.go b/field.go index 2ed4e732..cadc1a72 100644 --- a/field.go +++ b/field.go @@ -7,12 +7,14 @@ import ( "reflect" ) +// Field model field definition type Field struct { *StructField IsBlank bool Field reflect.Value } +// Set set a value to the field func (field *Field) Set(value interface{}) (err error) { if !field.Field.IsValid() { return errors.New("field value not valid") @@ -56,35 +58,34 @@ func (field *Field) Set(value interface{}) (err error) { } // Fields get value's fields -func (scope *Scope) Fields() map[string]*Field { - if scope.fields == nil { - fields := map[string]*Field{} - modelStruct := scope.GetModelStruct() +func (scope *Scope) Fields() []*Field { + var ( + fields []*Field + indirectScopeValue = scope.IndirectValue() + isStruct = indirectScopeValue.Kind() == reflect.Struct + ) - indirectValue := scope.IndirectValue() - isStruct := indirectValue.Kind() == reflect.Struct - for _, structField := range modelStruct.StructFields { - if field, ok := fields[structField.DBName]; !ok || field.IsIgnored { - if isStruct { - fields[structField.DBName] = getField(indirectValue, structField) - } else { - fields[structField.DBName] = &Field{StructField: structField, IsBlank: true} - } + for _, structField := range scope.GetModelStruct().StructFields { + if isStruct { + fieldValue := indirectScopeValue + for _, name := range structField.Names { + fieldValue = reflect.Indirect(fieldValue).FieldByName(name) } + fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) + } else { + fields = append(fields, &Field{StructField: structField, IsBlank: true}) } - - scope.fields = fields - return fields } - return scope.fields + + return fields } -func getField(indirectValue reflect.Value, structField *StructField) *Field { - field := &Field{StructField: structField} - for _, name := range structField.Names { - indirectValue = reflect.Indirect(indirectValue).FieldByName(name) +func (scope *Scope) fieldsMap() map[string]*Field { + var results = map[string]*Field{} + for _, field := range scope.Fields() { + if field.IsNormal { + results[field.DBName] = field + } } - field.Field = indirectValue - field.IsBlank = isBlank(indirectValue) - return field + return results } diff --git a/field_test.go b/field_test.go index 2172b059..30e9a778 100644 --- a/field_test.go +++ b/field_test.go @@ -32,12 +32,16 @@ type CalculateFieldCategory struct { func TestCalculateField(t *testing.T) { var field CalculateField - fields := DB.NewScope(&field).Fields() - if fields["children"].Relationship == nil || fields["category"].Relationship == nil { + var scope = DB.NewScope(&field) + if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { t.Errorf("Should calculate fields correctly for the first time") } - if field, ok := fields["embedded_name"]; !ok { + if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { + t.Errorf("Should calculate fields correctly for the first time") + } + + if field, ok := scope.FieldByName("embedded_name"); !ok { t.Errorf("should find embedded field") } else if _, ok := field.TagSettings["NOT NULL"]; !ok { t.Errorf("should find embedded field's tag settings") diff --git a/foundation.go b/foundation.go deleted file mode 100644 index 422fcc60..00000000 --- a/foundation.go +++ /dev/null @@ -1,83 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "time" -) - -type foundation struct { - commonDialect -} - -func (foundation) BinVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (foundation) SupportLastInsertId() bool { - return false -} - -func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "serial" - } - return "int" - case reflect.Int64, reflect.Uint64: - if autoIncrease { - return "bigserial" - } - return "bigint" - case reflect.Float32, reflect.Float64: - return "double" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "clob" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "datetime" - } - default: - if _, ok := value.Interface().([]byte); ok { - return "blob" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) -} - -func (s foundation) ReturningStr(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (s foundation) HasTable(scope *Scope, tableName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName) - return count > 0 -} - -func (s foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName) - return count > 0 -} - -func (s foundation) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", s.Quote(indexName))) -} - -func (s foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName) - return count > 0 -} - -func (s foundation) CurrentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT CURRENT_SCHEMA") - return -} diff --git a/images/logger.png b/images/logger.png deleted file mode 100644 index 8c46588f..00000000 Binary files a/images/logger.png and /dev/null differ diff --git a/join_table_handler.go b/join_table_handler.go index 006701a6..18c12a85 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -7,40 +7,54 @@ import ( "strings" ) +// JoinTableHandlerInterface is an interface for how to handle many2many relations type JoinTableHandlerInterface interface { + // initialize join table handler Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) + // Table return join table's table name Table(db *DB) string + // Add create relationship in join table for source and destination Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + // Delete delete relationship in join table for sources Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + // JoinWith query with `Join` conditions JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + // SourceForeignKeys return source foreign keys SourceForeignKeys() []JoinTableForeignKey + // DestinationForeignKeys return destination foreign keys DestinationForeignKeys() []JoinTableForeignKey } +// JoinTableForeignKey join table foreign key struct type JoinTableForeignKey struct { DBName string AssociationDBName string } +// JoinTableSource is a struct that contains model type and foreign keys type JoinTableSource struct { ModelType reflect.Type ForeignKeys []JoinTableForeignKey } +// JoinTableHandler default join table handler type JoinTableHandler struct { TableName string `sql:"-"` Source JoinTableSource `sql:"-"` Destination JoinTableSource `sql:"-"` } +// SourceForeignKeys return source foreign keys func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { return s.Source.ForeignKeys } +// DestinationForeignKeys return destination foreign keys func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { return s.Destination.ForeignKeys } +// Setup initialize a default join table handler func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { s.TableName = tableName @@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } } +// Table return join table's table name func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { +func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { values := map[string]interface{}{} for _, source := range sources { @@ -74,20 +89,25 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin if s.Source.ModelType == modelType { for _, foreignKey := range s.Source.ForeignKeys { - values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } } } else if s.Destination.ModelType == modelType { for _, foreignKey := range s.Destination.ForeignKeys { - values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + values[foreignKey.DBName] = field.Field.Interface() + } } } } return values } -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { +// Add create relationship in join table for source and destination +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { scope := db.NewScope("") - searchMap := s.GetSearchMap(db, source1, source2) + searchMap := s.getSearchMap(db, source, destination) var assignColumns, binVars, conditions []string var values []interface{} @@ -116,6 +136,7 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 return db.Exec(sql, values...).Error } +// Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( scope = db.NewScope(nil) @@ -123,7 +144,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour values []interface{} ) - for key, value := range s.GetSearchMap(db, sources...) { + for key, value := range s.getSearchMap(db, sources...) { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } @@ -131,6 +152,7 @@ func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sour return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } +// JoinWith query with `Join` conditions func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { var ( scope = db.NewScope(source) @@ -151,10 +173,12 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so for _, foreignKey := range s.Source.ForeignKeys { foreignDBNames = append(foreignDBNames, foreignKey.DBName) - foreignFieldNames = append(foreignFieldNames, scope.Fields()[foreignKey.AssociationDBName].Name) + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + foreignFieldNames = append(foreignFieldNames, field.Name) + } } - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames) + foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) var condString string if len(foreignFieldValues) > 0 { @@ -165,7 +189,7 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - keys := scope.getColumnAsArray(foreignFieldNames) + keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) values = append(values, toQueryValues(keys)) } else { condString = fmt.Sprintf("1 <> 1") @@ -173,8 +197,8 @@ func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, so return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). Where(condString, toQueryValues(foreignFieldValues)...) - } else { - db.Error = errors.New("wrong source type for join table handler") - return db } + + db.Error = errors.New("wrong source type for join table handler") + return db } diff --git a/join_table_test.go b/join_table_test.go index 70e792ed..1a83a9c8 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -18,7 +18,7 @@ type PersonAddress struct { gorm.JoinTableHandler PersonID int AddressID int - DeletedAt time.Time + DeletedAt *time.Time CreatedAt time.Time } diff --git a/logger.go b/logger.go index fcd0b57c..f4c79ead 100644 --- a/logger.go +++ b/logger.go @@ -8,25 +8,28 @@ import ( "reflect" "regexp" "time" + "unicode" +) + +var ( + defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} + sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) ) type logger interface { Print(v ...interface{}) } -type LogWriter interface { +type logWriter interface { Println(v ...interface{}) } +// Logger default logger type Logger struct { - LogWriter + logWriter } -var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - -// Format log -var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) - +// Print format & print log func (logger Logger) Print(values ...interface{}) { if len(values) > 1 { level := values[0] @@ -38,29 +41,44 @@ func (logger Logger) Print(values ...interface{}) { // duration messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) // sql - var formatedValues []interface{} + var sql string + var formattedValues []string + for _, value := range values[4].([]interface{}) { indirectValue := reflect.Indirect(reflect.ValueOf(value)) if indirectValue.IsValid() { value = indirectValue.Interface() if t, ok := value.(time.Time); ok { - formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) } else if b, ok := value.([]byte); ok { - formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b))) + if str := string(b); isPrintable(str) { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) + } else { + formattedValues = append(formattedValues, "''") + } } else if r, ok := value.(driver.Valuer); ok { if value, err := r.Value(); err == nil && value != nil { - formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } else { - formatedValues = append(formatedValues, "NULL") + formattedValues = append(formattedValues, "NULL") } } else { - formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } } else { - formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value)) + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } } - messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...)) + + var formattedValuesLength = len(formattedValues) + for index, value := range sqlRegexp.Split(values[3].(string), -1) { + sql += value + if index < formattedValuesLength { + sql += formattedValues[index] + } + } + + messages = append(messages, sql) } else { messages = append(messages, "\033[31;1m") messages = append(messages, values[2:]...) @@ -69,3 +87,12 @@ func (logger Logger) Print(values ...interface{}) { logger.Println(messages...) } } + +func isPrintable(s string) bool { + for _, r := range s { + if !unicode.IsPrint(r) { + return false + } + } + return true +} diff --git a/main.go b/main.go index f2eeb5d6..46d64b29 100644 --- a/main.go +++ b/main.go @@ -6,24 +6,14 @@ import ( "fmt" "reflect" "strings" - "time" ) -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs -// -// e.g: return time.Now().UTC() -// -var NowFunc = func() time.Time { - return time.Now() -} - +// DB contains information for current db connection type DB struct { Value interface{} Error error RowsAffected int64 - callback *callback + callbacks *Callback db sqlCommon parent *DB search *search @@ -36,7 +26,18 @@ type DB struct { joinTableHandlers map[string]JoinTableHandler } -func Open(dialect string, args ...interface{}) (DB, error) { +// Open initialize a new db connection, need to import driver first, e.g: +// +// import _ "github.com/go-sql-driver/mysql" +// func main() { +// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// } +// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with +// import _ "github.com/jinzhu/gorm/dialects/mysql" +// // import _ "github.com/jinzhu/gorm/dialects/postgres" +// // import _ "github.com/jinzhu/gorm/dialects/sqlite" +// // import _ "github.com/jinzhu/gorm/dialects/mssql" +func Open(dialect string, args ...interface{}) (*DB, error) { var db DB var err error @@ -44,7 +45,7 @@ func Open(dialect string, args ...interface{}) (DB, error) { err = errors.New("invalid database source") } else { var source string - var dbSql sqlCommon + var dbSQL sqlCommon switch value := args[0].(type) { case string: @@ -55,22 +56,19 @@ func Open(dialect string, args ...interface{}) (DB, error) { driver = value source = args[1].(string) } - if driver == "foundation" { - driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. - } - dbSql, err = sql.Open(driver, source) + dbSQL, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSql = value + dbSQL = value } db = DB{ - dialect: NewDialect(dialect), - logger: defaultLogger, - callback: DefaultCallback, - source: source, - values: map[string]interface{}{}, - db: dbSql, + dialect: newDialect(dialect, dbSQL.(*sql.DB)), + logger: defaultLogger, + callbacks: DefaultCallback, + source: source, + values: map[string]interface{}{}, + db: dbSQL, } db.parent = &db @@ -79,17 +77,20 @@ func Open(dialect string, args ...interface{}) (DB, error) { } } - return db, err + return &db, err } +// Close close current db connection func (s *DB) Close() error { return s.parent.db.(*sql.DB).Close() } +// DB get `*sql.DB` from current connection func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// New clone a new db connection without search conditions func (s *DB) New() *DB { clone := s.clone() clone.search = nil @@ -97,29 +98,32 @@ func (s *DB) New() *DB { return clone } -// NewScope create scope for callbacks, including DB's search information -func (db *DB) NewScope(value interface{}) *Scope { - dbClone := db.clone() +// NewScope create a scope for current operation +func (s *DB) NewScope(value interface{}) *Scope { + dbClone := s.clone() dbClone.Value = value return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } -// CommonDB Return the underlying sql.DB or sql.Tx instance. -// Use of this method is discouraged. It's mainly intended to allow -// coexistence with legacy non-GORM code. +// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. func (s *DB) CommonDB() sqlCommon { return s.db } -func (s *DB) Callback() *callback { - s.parent.callback = s.parent.callback.clone() - return s.parent.callback +// Callback return `Callbacks` container, you could add/change/delete callbacks with it +// db.Callback().Create().Register("update_created_at", updateCreated) +// Refer https://jinzhu.github.io/gorm/development.html#callbacks +func (s *DB) Callback() *Callback { + s.parent.callbacks = s.parent.callbacks.clone() + return s.parent.callbacks } -func (s *DB) SetLogger(l logger) { - s.logger = l +// SetLogger replace default logger +func (s *DB) SetLogger(log logger) { + s.logger = log } +// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { s.logMode = 2 @@ -129,55 +133,82 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { modelStructsMap = newModelStructsMap() s.parent.singularTable = enable } +// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db } +// Or filter records that match before conditions or this one, similar to `Where` func (s *DB) Or(query interface{}, args ...interface{}) *DB { return s.clone().search.Or(query, args...).db } +// Not filter records that don't match current conditions, similar to `Where` func (s *DB) Not(query interface{}, args ...interface{}) *DB { return s.clone().search.Not(query, args...).db } -func (s *DB) Limit(value interface{}) *DB { - return s.clone().search.Limit(value).db +// Limit specify the number of records to be retrieved +func (s *DB) Limit(limit int) *DB { + return s.clone().search.Limit(limit).db } -func (s *DB) Offset(value interface{}) *DB { - return s.clone().search.Offset(value).db +// Offset specify the number of records to skip before starting to return the records +func (s *DB) Offset(offset int) *DB { + return s.clone().search.Offset(offset).db } +// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions func (s *DB) Order(value string, reorder ...bool) *DB { return s.clone().search.Order(value, reorder...).db } +// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; +// When creating/updating, specify fields that you want to save to database func (s *DB) Select(query interface{}, args ...interface{}) *DB { return s.clone().search.Select(query, args...).db } +// Omit specify fields that you want to ignore when saving to database for creating, updating func (s *DB) Omit(columns ...string) *DB { return s.clone().search.Omit(columns...).db } +// Group specify the group method on the find func (s *DB) Group(query string) *DB { return s.clone().search.Group(query).db } +// Having specify HAVING conditions for GROUP BY func (s *DB) Having(query string, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } -func (s *DB) Joins(query string) *DB { - return s.clone().search.Joins(query).db +// Joins specify Joins conditions +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (s *DB) Joins(query string, args ...interface{}) *DB { + return s.clone().search.Joins(query, args...).db } +// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// Refer https://jinzhu.github.io/gorm/curd.html#scopes func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { s = f(s) @@ -185,60 +216,91 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { return s } +// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete func (s *DB) Unscoped() *DB { return s.clone().search.unscoped().db } +// Attrs initalize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate func (s *DB) Attrs(attrs ...interface{}) *DB { return s.clone().search.Attrs(attrs...).db } +// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate func (s *DB) Assign(attrs ...interface{}) *DB { return s.clone().search.Assign(attrs...).db } +// First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callback.queries).db + inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callback.queries).db + inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db + return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db + return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } +// Row return `*sql.Row` with given conditions func (s *DB) Row() *sql.Row { return s.NewScope(s.Value).row() } +// Rows return `*sql.Rows` with given conditions func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } +// ScanRows scan `*sql.Rows` to give struct +func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { + var ( + clone = s.clone() + scope = clone.NewScope(result) + columns, err = rows.Columns() + ) + + if clone.AddError(err) == nil { + scope.scan(rows, columns, scope.fieldsMap()) + } + + return clone.Error +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) func (s *DB) Pluck(column string, value interface{}) *DB { return s.NewScope(s.Value).pluck(column, value).db } +// Count get how many records for a model func (s *DB) Count(value interface{}) *DB { return s.NewScope(s.Value).count(value).db } +// Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } +// FirstOrInit find first matched record or initalize a new one with given conditions (only works with struct, map conditions) +// https://jinzhu.github.io/gorm/curd.html#firstorinit func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() if result := c.First(out, where...); result.Error != nil { @@ -247,82 +309,100 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize() } else { - c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs), false) + c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(c.search.assignAttrs)) } return c } +// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) +// https://jinzhu.github.io/gorm/curd.html#firstorcreate func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { c := s.clone() if result := c.First(out, where...); result.Error != nil { if !result.RecordNotFound() { return result } - c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error) + c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error) } else if len(c.search.assignAttrs) > 0 { - c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error) + c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error) } return c } +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { return s.clone().NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callback.updates).db + callCallbacks(s.parent.callbacks.updates).db } +// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update func (s *DB) UpdateColumn(attrs ...interface{}) *DB { return s.UpdateColumns(toSearchableMap(attrs...)) } +// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update func (s *DB) UpdateColumns(values interface{}) *DB { return s.clone().NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callback.updates).db + callCallbacks(s.parent.callbacks.updates).db } +// Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { scope := s.clone().NewScope(value) if scope.PrimaryKeyZero() { - return scope.callCallbacks(s.parent.callback.creates).db + return scope.callCallbacks(s.parent.callbacks.creates).db } - return scope.callCallbacks(s.parent.callback.updates).db + return scope.callCallbacks(s.parent.callbacks.updates).db } +// Create insert the value into database func (s *DB) Create(value interface{}) *DB { scope := s.clone().NewScope(value) - return scope.callCallbacks(s.parent.callback.creates).db + return scope.callCallbacks(s.parent.callbacks.creates).db } +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db + return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } +// Raw use raw sql as conditions, won't run it unless invoked by other methods +// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) func (s *DB) Raw(sql string, values ...interface{}) *DB { return s.clone().search.Raw(true).Where(sql, values...).db } +// Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { scope := s.clone().NewScope(nil) - generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) - generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")") - scope.Raw(generatedSql) + generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") + scope.Raw(generatedSQL) return scope.Exec().db } +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (s *DB) Model(value interface{}) *DB { c := s.clone() c.Value = value return c } +// Table specify the table you would like to run db operations func (s *DB) Table(name string) *DB { clone := s.clone() clone.search.Table(name) @@ -330,10 +410,12 @@ func (s *DB) Table(name string) *DB { return clone } +// Debug start debug mode func (s *DB) Debug() *DB { return s.clone().LogMode(true) } +// Begin begin a transaction func (s *DB) Begin() *DB { c := s.clone() if db, ok := c.db.(sqlDb); ok { @@ -341,46 +423,56 @@ func (s *DB) Begin() *DB { c.db = interface{}(tx).(sqlCommon) c.AddError(err) } else { - c.AddError(CantStartTransaction) + c.AddError(ErrCantStartTransaction) } return c } +// Commit commit a transaction func (s *DB) Commit() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Commit()) } else { - s.AddError(NoValidTransaction) + s.AddError(ErrInvalidTransaction) } return s } +// Rollback rollback a transaction func (s *DB) Rollback() *DB { if db, ok := s.db.(sqlTx); ok { s.AddError(db.Rollback()) } else { - s.AddError(NoValidTransaction) + s.AddError(ErrInvalidTransaction) } return s } +// NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { return s.clone().NewScope(value).PrimaryKeyZero() } +// RecordNotFound check if returning ErrRecordNotFound error func (s *DB) RecordNotFound() bool { - return s.Error == RecordNotFound + for _, err := range s.GetErrors() { + if err == ErrRecordNotFound { + return true + } + } + return false } -// Migrations -func (s *DB) CreateTable(values ...interface{}) *DB { +// CreateTable create table for models +func (s *DB) CreateTable(models ...interface{}) *DB { db := s.clone() - for _, value := range values { - db = db.NewScope(value).createTable().db + for _, model := range models { + db = db.NewScope(model).createTable().db } return db } +// DropTable drop table for models func (s *DB) DropTable(values ...interface{}) *DB { db := s.clone() for _, value := range values { @@ -393,18 +485,18 @@ func (s *DB) DropTable(values ...interface{}) *DB { return db } +// DropTableIfExists drop table if it is exist func (s *DB) DropTableIfExists(values ...interface{}) *DB { db := s.clone() for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) + if s.HasTable(value) { + db.AddError(s.DropTable(value).Error) } - - db = db.NewScope(value).dropTableIfExists().db } return db } +// HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( scope = s.clone().NewScope(value) @@ -417,69 +509,64 @@ func (s *DB) HasTable(value interface{}) bool { tableName = scope.TableName() } - has := scope.Dialect().HasTable(scope, tableName) + has := scope.Dialect().HasTable(tableName) s.AddError(scope.db.Error) return has } +// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data func (s *DB) AutoMigrate(values ...interface{}) *DB { db := s.clone() for _, value := range values { - db = db.NewScope(value).NeedPtr().autoMigrate().db + db = db.NewScope(value).autoMigrate().db } return db } +// ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { scope := s.clone().NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } +// DropColumn drop a column func (s *DB) DropColumn(column string) *DB { scope := s.clone().NewScope(s.Value) scope.dropColumn(column) return scope.db } -func (s *DB) AddIndex(indexName string, column ...string) *DB { +// AddIndex add index for columns with given name +func (s *DB) AddIndex(indexName string, columns ...string) *DB { scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, column...) + scope.addIndex(false, indexName, columns...) return scope.db } -func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { +// AddUniqueIndex add unique index for columns with given name +func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { scope := s.clone().NewScope(s.Value) - scope.addIndex(true, indexName, column...) + scope.addIndex(true, indexName, columns...) return scope.db } +// RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { scope := s.clone().NewScope(s.Value) scope.removeIndex(indexName) return scope.db } -func (s *DB) CurrentDatabase() string { - var ( - scope = s.clone().NewScope(s.Value) - name = s.dialect.CurrentDatabase(scope) - ) - return name -} - -/* -Add foreign key to the given scope - -Example: - db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -*/ +// AddForeignKey Add foreign key to the given scope, e.g: +// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { scope := s.clone().NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } +// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error scope := s.clone().NewScope(s.Value) @@ -491,7 +578,7 @@ func (s *DB) Association(column string) *Association { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { - return &Association{Scope: scope, Column: column, Field: field} + return &Association{scope: scope, column: column, field: field} } } else { err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) @@ -501,26 +588,30 @@ func (s *DB) Association(column string) *Association { return &Association{Error: err} } +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (s *DB) Preload(column string, conditions ...interface{}) *DB { return s.clone().search.Preload(column, conditions...).db } -// Set set value by name +// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting func (s *DB) Set(name string, value interface{}) *DB { return s.clone().InstantSet(name, value) } +// InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { s.values[name] = value return s } -// Get get value by name +// Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { value, ok = s.values[name] return } +// SetJoinTableHandler set a model's join table handler for a relation func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { @@ -530,7 +621,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { + if table := handler.Table(s); scope.Dialect().HasTable(table) { s.Table(table).AutoMigrate(handler) } } @@ -538,9 +629,10 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join } } +// AddError add error to the db func (s *DB) AddError(err error) error { if err != nil { - if err != RecordNotFound { + if err != ErrRecordNotFound { if s.logMode == 0 { go s.print(fileWithLineNum(), err) } else { @@ -559,6 +651,7 @@ func (s *DB) AddError(err error) error { return err } +// GetErrors get happened errors from the db func (s *DB) GetErrors() (errors []error) { if errs, ok := s.Error.(errorsInterface); ok { return errs.GetErrors() diff --git a/main_private.go b/main_private.go index bd097ce0..a6e5a6a9 100644 --- a/main_private.go +++ b/main_private.go @@ -10,7 +10,7 @@ func (s *DB) clone() *DB { } if s.search == nil { - db.search = &search{} + db.search = &search{limit: -1, offset: -1} } else { db.search = s.search.clone() } diff --git a/main_test.go b/main_test.go index 8722c7c0..2da3de94 100644 --- a/main_test.go +++ b/main_test.go @@ -4,23 +4,23 @@ import ( "database/sql" "database/sql/driver" "fmt" - "strconv" - - _ "github.com/denisenkom/go-mssqldb" - testdb "github.com/erikstmartin/go-testdb" - _ "github.com/go-sql-driver/mysql" - "github.com/jinzhu/gorm" - "github.com/jinzhu/now" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "os" + "reflect" + "strconv" "testing" "time" + + "github.com/erikstmartin/go-testdb" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mssql" + _ "github.com/jinzhu/gorm/dialects/mysql" + "github.com/jinzhu/gorm/dialects/postgres" + _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/now" ) var ( - DB gorm.DB + DB *gorm.DB t1, t2, t3, t4, t5 time.Time ) @@ -42,7 +42,7 @@ func init() { runMigration() } -func OpenTestConnection() (db gorm.DB, err error) { +func OpenTestConnection() (db *gorm.DB, err error) { switch os.Getenv("GORM_DIALECT") { case "mysql": // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; @@ -115,7 +115,7 @@ func TestSetTable(t *testing.T) { DB.Create(getPreparedUser("pluck_user3", "pluck_user")) if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Errorf("No errors should happen if set table for pluck", err.Error()) + t.Error("No errors should happen if set table for pluck", err) } var users []User @@ -376,7 +376,7 @@ func TestRows(t *testing.T) { rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { - t.Errorf("Not error should happen, but got") + t.Errorf("Not error should happen, got %v", err) } count := 0 @@ -386,8 +386,39 @@ func TestRows(t *testing.T) { rows.Scan(&name, &age) count++ } + if count != 2 { - t.Errorf("Should found two records with name 3") + t.Errorf("Should found two records") + } +} + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")} + user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")} + user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") } } @@ -448,7 +479,7 @@ func TestRaw(t *testing.T) { } DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound { + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { t.Error("Raw sql to update records") } } @@ -469,15 +500,34 @@ func TestGroup(t *testing.T) { func TestJoins(t *testing.T) { var user = User{ - Name: "joins", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, + Name: "joins", + CreditCard: CreditCard{Number: "411111111111"}, + Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, } DB.Save(&user) - var result User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").First(&result) - if result.Name != "joins" || result.Id != user.Id { - t.Errorf("Should find all two emails with Join") + var users1 []User + DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) + if len(users1) != 2 { + t.Errorf("should find two users using left join") + } + + var users2 []User + DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions") + } + + var users3 []User + DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions") + } + + var users4 []User + DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card") } } @@ -557,7 +607,7 @@ func TestTimeWithZone(t *testing.T) { DB.First(&findUser, "name = ?", name) foundBirthday = findUser.Birthday.UTC().Format(format) if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) + t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) } if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { @@ -573,7 +623,7 @@ func TestTimeWithZone(t *testing.T) { func TestHstore(t *testing.T) { type Details struct { Id int64 - Bulk gorm.Hstore + Bulk postgres.Hstore } if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { @@ -659,7 +709,7 @@ func TestOpenExistingDB(t *testing.T) { } var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound { + if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { t.Errorf("Should have found existing record") } } diff --git a/migration_test.go b/migration_test.go index 0411872e..de35c1df 100644 --- a/migration_test.go +++ b/migration_test.go @@ -31,7 +31,7 @@ func TestIndexes(t *testing.T) { } scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { t.Errorf("Email should have index idx_email_email") } @@ -39,7 +39,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { t.Errorf("Email's index idx_email_email should be deleted") } @@ -47,7 +47,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -55,7 +55,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -63,7 +63,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to create index: %+v", err) } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email should have index idx_email_email_and_user_id") } @@ -85,7 +85,7 @@ func TestIndexes(t *testing.T) { t.Errorf("Got error when tried to remove index: %+v", err) } - if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") { + if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { t.Errorf("Email's index idx_email_email_and_user_id should be deleted") } @@ -117,11 +117,11 @@ func TestAutoMigration(t *testing.T) { DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()}) scope := DB.NewScope(&BigEmail{}) - if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") { + if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { t.Errorf("Failed to create index") } - if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") { + if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") { t.Errorf("Failed to create index") } diff --git a/model.go b/model.go index ffa68b07..a6650877 100644 --- a/model.go +++ b/model.go @@ -2,6 +2,10 @@ package gorm import "time" +// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embeded in your models +// type User struct { +// gorm.Model +// } type Model struct { ID uint `gorm:"primary_key"` CreatedAt time.Time diff --git a/model_struct.go b/model_struct.go index d80165c8..6df615d1 100644 --- a/model_struct.go +++ b/model_struct.go @@ -3,10 +3,8 @@ package gorm import ( "database/sql" "errors" - "fmt" "go/ast" "reflect" - "strconv" "strings" "sync" "time" @@ -14,6 +12,7 @@ import ( "github.com/jinzhu/inflection" ) +// DefaultTableNameHandler default table name handler var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } @@ -41,6 +40,7 @@ func newModelStructsMap() *safeModelStructsMap { var modelStructsMap = newModelStructsMap() +// ModelStruct model definition type ModelStruct struct { PrimaryFields []*StructField StructFields []*StructField @@ -48,10 +48,12 @@ type ModelStruct struct { defaultTableName string } +// TableName get model's table name func (s *ModelStruct) TableName(db *DB) string { return DefaultTableNameHandler(db, s.defaultTableName) } +// StructField model field's struct definition type StructField struct { DBName string Name string @@ -107,7 +109,7 @@ func getForeignField(column string, fields []*StructField) *StructField { return nil } -// GetModelStruct generate model struct & relationships based on struct and tag definition +// GetModelStruct get value's model struct, relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct // Scope value can't be nil @@ -296,7 +298,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(associationForeignKeys) == 0 { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { - associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { @@ -389,7 +394,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(associationForeignKeys) == 0 { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { - associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, associationType)) + associationForeignKey := strings.TrimPrefix(foreignKey, associationType) + if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { @@ -445,7 +453,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if len(associationForeignKeys) == 0 { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKeys = append(associationForeignKeys, strings.TrimPrefix(foreignKey, field.Name)) + associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) + if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { + associationForeignKeys = append(associationForeignKeys, associationForeignKey) + } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { @@ -508,63 +519,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { return &modelStruct } +// GetStructFields get model's field structs func (scope *Scope) GetStructFields() (fields []*StructField) { return scope.GetModelStruct().StructFields } -func (scope *Scope) generateSqlTag(field *StructField) string { - var sqlType string - structType := field.Struct.Type - if structType.Kind() == reflect.Ptr { - structType = structType.Elem() - } - reflectValue := reflect.Indirect(reflect.New(structType)) - - if value, ok := field.TagSettings["TYPE"]; ok { - sqlType = value - } - - additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - - if field.IsScanner { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - reflectValue = value - if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { - getScannerValue(reflectValue.Field(0)) - } - } - getScannerValue(reflectValue) - } - - if sqlType == "" { - var size = 255 - - if value, ok := field.TagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(value) - } - - v, autoIncrease := field.TagSettings["AUTO_INCREMENT"] - if field.IsPrimaryKey { - autoIncrease = true - } - if v == "FALSE" { - autoIncrease = false - } - - sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } else { - return fmt.Sprintf("%v %v", sqlType, additionalType) - } -} - func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { diff --git a/mssql.go b/mssql.go deleted file mode 100644 index a9bd1e52..00000000 --- a/mssql.go +++ /dev/null @@ -1,80 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "time" -) - -type mssql struct { - commonDialect -} - -func (mssql) HasTop() bool { - return true -} - -func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "int IDENTITY(1,1)" - } - return "int" - case reflect.Int64, reflect.Uint64: - if autoIncrease { - return "bigint IDENTITY(1,1)" - } - return "bigint" - case reflect.Float32, reflect.Float64: - return "float" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("nvarchar(%d)", size) - } - return "text" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "datetime2" - } - default: - if _, ok := value.Interface().([]byte); ok { - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) -} - -func (s mssql) HasTable(scope *Scope, tableName string) bool { - var ( - count int - databaseName = s.CurrentDatabase(scope) - ) - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName) - return count > 0 -} - -func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var ( - count int - databaseName = s.CurrentDatabase(scope) - ) - s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName) - return count > 0 -} - -func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName) - return count > 0 -} - -func (s mssql) CurrentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]") - return -} diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index ea80326e..8b275d18 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -21,7 +21,7 @@ type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string - Blogs []*Blog `gorm:"many2many:"blogs_tags` + Blogs []*Blog `gorm:"many2many:blogs_tags"` } func compareTags(tags []Tag, contents []string) bool { diff --git a/mysql.go b/mysql.go deleted file mode 100644 index 9e1d56d3..00000000 --- a/mysql.go +++ /dev/null @@ -1,70 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "time" -) - -type mysql struct { - commonDialect -} - -func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if autoIncrease { - return "int AUTO_INCREMENT" - } - return "int" - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "int unsigned AUTO_INCREMENT" - } - return "int unsigned" - case reflect.Int64: - if autoIncrease { - return "bigint AUTO_INCREMENT" - } - return "bigint" - case reflect.Uint64: - if autoIncrease { - return "bigint unsigned AUTO_INCREMENT" - } - return "bigint unsigned" - case reflect.Float32, reflect.Float64: - return "double" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "longtext" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "timestamp NULL" - } - default: - if _, ok := value.Interface().([]byte); ok { - if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) - } - return "longblob" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) CurrentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT DATABASE()") - return -} diff --git a/pointer_test.go b/pointer_test.go index b47717f3..2a68a5ab 100644 --- a/pointer_test.go +++ b/pointer_test.go @@ -39,46 +39,46 @@ func TestPointerFields(t *testing.T) { var nilPointerStruct = PointerStruct{} if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Errorf("Failed to save nil pointer struct", err) + t.Error("Failed to save nil pointer struct", err) } var pointerStruct2 PointerStruct if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Errorf("Failed to query saved nil pointer struct", err) + t.Error("Failed to query saved nil pointer struct", err) } var normalStruct2 NormalStruct if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Errorf("Failed to query saved nil pointer struct", err) + t.Error("Failed to query saved nil pointer struct", err) } var partialNilPointerStruct1 = PointerStruct{Num: &num} if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Errorf("Failed to save partial nil pointer struct", err) + t.Error("Failed to save partial nil pointer struct", err) } var pointerStruct3 PointerStruct if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Errorf("Failed to query saved partial nil pointer struct", err) + t.Error("Failed to query saved partial nil pointer struct", err) } var normalStruct3 NormalStruct if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Errorf("Failed to query saved partial pointer struct", err) + t.Error("Failed to query saved partial pointer struct", err) } var partialNilPointerStruct2 = PointerStruct{Name: &name} if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Errorf("Failed to save partial nil pointer struct", err) + t.Error("Failed to save partial nil pointer struct", err) } var pointerStruct4 PointerStruct if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Errorf("Failed to query saved partial nil pointer struct", err) + t.Error("Failed to query saved partial nil pointer struct", err) } var normalStruct4 NormalStruct if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Errorf("Failed to query saved partial pointer struct", err) + t.Error("Failed to query saved partial pointer struct", err) } } diff --git a/postgres.go b/postgres.go deleted file mode 100644 index 3b083dfa..00000000 --- a/postgres.go +++ /dev/null @@ -1,154 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "github.com/lib/pq/hstore" -) - -type postgres struct { - commonDialect -} - -func (postgres) BinVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (postgres) SupportLastInsertId() bool { - return false -} - -func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "serial" - } - return "integer" - case reflect.Int64, reflect.Uint64: - if autoIncrease { - return "bigserial" - } - return "bigint" - case reflect.Float32, reflect.Float64: - return "numeric" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "timestamp with time zone" - } - case reflect.Map: - if value.Type() == hstoreType { - return "hstore" - } - default: - if isByteArrayOrSlice(value) { - return "bytea" - } else if isUUID(value) { - return "uuid" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) -} - -var byteType = reflect.TypeOf(uint8(0)) - -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func (s postgres) ReturningStr(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (s postgres) HasTable(scope *Scope, tableName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName) - return count > 0 -} - -func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName) - return count > 0 -} - -func (postgres) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - -func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName) - return count > 0 -} - -func (s postgres) CurrentDatabase(scope *Scope) (name string) { - s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()") - return -} - -var hstoreType = reflect.TypeOf(Hstore{}) - -type Hstore map[string]*string - -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} diff --git a/preload.go b/preload.go deleted file mode 100644 index e1ca69ff..00000000 --- a/preload.go +++ /dev/null @@ -1,388 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "errors" - "fmt" - "reflect" - "strings" -) - -func getRealValue(value reflect.Value, columns []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if pointedValue := reflect.Indirect(value); pointedValue.IsValid() { - for _, column := range columns { - if pointedValue.FieldByName(column).IsValid() { - result := pointedValue.FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func Preload(scope *Scope) { - if scope.Search.preload == nil || scope.HasError() { - return - } - - preloadMap := map[string]bool{} - fields := scope.Fields() - for _, preload := range scope.Search.preload { - schema, conditions := preload.schema, preload.conditions - keys := strings.Split(schema, ".") - currentScope := scope - currentFields := fields - originalConditions := conditions - conditions = []interface{}{} - for i, key := range keys { - var found bool - if preloadMap[strings.Join(keys[:i+1], ".")] { - goto nextLoop - } - - if i == len(keys)-1 { - conditions = originalConditions - } - - for _, field := range currentFields { - if field.Name != key || field.Relationship == nil { - continue - } - - found = true - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, conditions) - case "has_many": - currentScope.handleHasManyPreload(field, conditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, conditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, conditions) - default: - currentScope.Err(errors.New("not supported relation")) - } - break - } - - if !found { - value := reflect.ValueOf(currentScope.Value) - if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { - value = value.Index(0).Elem() - } - scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) - return - } - - preloadMap[strings.Join(keys[:i+1], ".")] = true - - nextLoop: - if i < len(keys)-1 { - currentScope = currentScope.getColumnsAsScope(key) - currentFields = currentScope.Fields() - } - } - } - -} - -func makeSlice(typ reflect.Type) interface{} { - if typ.Kind() == reflect.Slice { - typ = typ.Elem() - } - sliceType := reflect.SliceOf(typ) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) - if len(primaryKeys) == 0 { - return - } - - results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.ForeignFieldNames) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - if err := scope.SetColumn(field, result); err != nil { - scope.Err(err) - return - } - } - } -} - -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) - if len(primaryKeys) == 0 { - return - } - - results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - - if scope.IndirectValue().Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) - preloadMap[toString(value)] = append(preloadMap[toString(value)], result) - } - - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - objectRealValue := getRealValue(object, relation.AssociationForeignFieldNames) - objectStringValue := toString(objectRealValue) - if results, ok := preloadMap[objectStringValue]; ok { - if object.Kind() == reflect.Ptr { - object = object.Elem() - } - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, results...)) - } - } - } else { - scope.SetColumn(field, resultValues) - } -} - -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) - if len(primaryKeys) == 0 { - return - } - - results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.AssociationForeignFieldNames) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if object.Kind() == reflect.Ptr { - object = reflect.Indirect(objects.Index(j).Elem()) - } - if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.SetColumn(field, result) - } - } -} - -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - joinTableHandler := relation.JoinTableHandler - destType := field.StructField.Struct.Type.Elem() - var isPtr bool - if destType.Kind() == reflect.Ptr { - isPtr = true - destType = destType.Elem() - } - - var sourceKeys []string - var linkHash = make(map[string][]reflect.Value) - - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") - - preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) - - if len(conditions) > 0 { - preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) - } - rows, err := preloadJoinDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - elem := reflect.New(destType).Elem() - var values = make([]interface{}, len(columns)) - - fields := scope.New(elem.Addr().Interface()).Fields() - - var foundFields = map[string]bool{} - for index, column := range columns { - if field, ok := fields[column]; ok && !foundFields[column] { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface() - } - foundFields[column] = true - } else { - var i interface{} - values[index] = &i - } - } - - scope.Err(rows.Scan(values...)) - - var sourceKey []interface{} - - var scannedFields = map[string]bool{} - for index, column := range columns { - value := values[index] - if field, ok := fields[column]; ok && !scannedFields[column] { - if field.Field.Kind() == reflect.Ptr { - field.Field.Set(reflect.ValueOf(value).Elem()) - } else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - scannedFields[column] = true - } else if strInSlice(column, sourceKeys) { - sourceKey = append(sourceKey, *(value.(*interface{}))) - } - } - - if len(sourceKey) != 0 { - if isPtr { - linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem.Addr()) - } else { - linkHash[toString(sourceKey)] = append(linkHash[toString(sourceKey)], elem) - } - } - } - - var foreignFieldNames []string - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if scope.IndirectValue().Kind() == reflect.Slice { - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if object.Kind() == reflect.Ptr { - object = object.Elem() - } - source := getRealValue(object, foreignFieldNames) - field := object.FieldByName(field.Name) - for _, link := range linkHash[toString(source)] { - field.Set(reflect.Append(field, link)) - } - } - } else { - if object := scope.IndirectValue(); object.IsValid() { - source := getRealValue(object, foreignFieldNames) - field := object.FieldByName(field.Name) - for _, link := range linkHash[toString(source)] { - field.Set(reflect.Append(field, link)) - } - } - } -} - -func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { - values := scope.IndirectValue() - switch values.Kind() { - case reflect.Slice: - for i := 0; i < values.Len(); i++ { - var result []interface{} - for _, column := range columns { - value := reflect.Indirect(values.Index(i)) - if value.Kind() == reflect.Ptr { - value = reflect.Indirect(values.Index(i).Elem()) - } - result = append(result, value.FieldByName(column).Interface()) - } - results = append(results, result) - } - case reflect.Struct: - var result []interface{} - for _, column := range columns { - result = append(result, values.FieldByName(column).Interface()) - } - return [][]interface{}{result} - } - return -} - -func (scope *Scope) getColumnsAsScope(column string) *Scope { - values := scope.IndirectValue() - switch values.Kind() { - case reflect.Slice: - modelType := values.Type().Elem() - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - fieldStruct, _ := modelType.FieldByName(column) - var columns reflect.Value - if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() - } else { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() - } - for i := 0; i < values.Len(); i++ { - column := reflect.Indirect(values.Index(i)).FieldByName(column) - if column.Kind() == reflect.Ptr { - column = column.Elem() - } - if column.Kind() == reflect.Slice { - for i := 0; i < column.Len(); i++ { - elem := column.Index(i) - if elem.CanAddr() { - columns = reflect.Append(columns, elem.Addr()) - } - } - } else { - if column.CanAddr() { - columns = reflect.Append(columns, column.Addr()) - } - } - } - return scope.New(columns.Interface()) - case reflect.Struct: - field := values.FieldByName(column) - if !field.CanAddr() { - return nil - } - return scope.New(field.Addr().Interface()) - } - return nil -} diff --git a/preload_test.go b/preload_test.go index c5a3a136..cde8e800 100644 --- a/preload_test.go +++ b/preload_test.go @@ -133,7 +133,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -818,90 +818,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } } -func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } -} - func TestManyToManyPreloadForNestedPointer(t *testing.T) { type ( Level1 struct { @@ -1065,7 +981,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1122,12 +1038,87 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound { + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } func TestNestedManyToManyPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + level1Zh := &Level1{Value: "zh"} + level1Ru := &Level1{Value: "ru"} + level1En := &Level1{Value: "en"} + + level21 := &Level2{ + Value: "Level2-1", + Level1s: []*Level1{level1Zh, level1Ru}, + } + + level22 := &Level2{ + Value: "Level2-2", + Level1s: []*Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + +func TestNestedManyToManyPreload4(t *testing.T) { type ( Level4 struct { ID uint @@ -1185,6 +1176,90 @@ func TestNestedManyToManyPreload3(t *testing.T) { } } +func TestManyToManyPreloadForPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } +} + func TestNilPointerSlice(t *testing.T) { type ( Level3 struct { @@ -1234,7 +1309,7 @@ func TestNilPointerSlice(t *testing.T) { } if len(got) != 2 { - t.Error("got %v items, expected 2", len(got)) + t.Errorf("got %v items, expected 2", len(got)) } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { diff --git a/query_test.go b/query_test.go index ed7a518d..b376dc82 100644 --- a/query_test.go +++ b/query_test.go @@ -629,14 +629,3 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } - -func TestCurrentDatabase(t *testing.T) { - databaseName := DB.CurrentDatabase() - if err := DB.Error; err != nil { - t.Errorf("Problem getting current db name: %s", err) - } - if databaseName == "" { - t.Errorf("Current db name returned empty; this should never happen!") - } - t.Logf("Got current db name: %v", databaseName) -} diff --git a/scope.go b/scope.go index a11d4ec4..6239db7a 100644 --- a/scope.go +++ b/scope.go @@ -1,48 +1,32 @@ package gorm import ( + "database/sql" "errors" "fmt" "regexp" "strings" - "time" "reflect" ) +// Scope contain current operation's information when you perform any operation on the database type Scope struct { Search *search Value interface{} - Sql string - SqlVars []interface{} + SQL string + SQLVars []interface{} db *DB - indirectValue *reflect.Value - instanceId string + instanceID string primaryKeyField *Field skipLeft bool fields map[string]*Field selectAttrs *[]string } +// IndirectValue return scope's reflect value's indirect value func (scope *Scope) IndirectValue() reflect.Value { - if scope.indirectValue == nil { - value := reflect.Indirect(reflect.ValueOf(scope.Value)) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - scope.indirectValue = &value - } - return *scope.indirectValue -} - -func (scope *Scope) NeedPtr() *Scope { - reflectKind := reflect.ValueOf(scope.Value).Kind() - if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) { - err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value") - scope.Err(err) - fmt.Printf(err.Error()) - } - return scope + return indirect(reflect.ValueOf(scope.Value)) } // New create a new Scope without search information @@ -61,12 +45,13 @@ func (scope *Scope) NewDB() *DB { return nil } +// DB return scope's DB connection func (scope *Scope) DB() *DB { return scope.db } -// SqlDB return *sql.DB -func (scope *Scope) SqlDB() sqlCommon { +// SQLDB return *sql.DB +func (scope *Scope) SQLDB() sqlCommon { return scope.db.db } @@ -75,7 +60,7 @@ func (scope *Scope) SkipLeft() { scope.skipLeft = true } -// Quote used to quote database column name according to database dialect +// Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { if strings.Index(str, ".") != -1 { newStrs := []string{} @@ -83,12 +68,12 @@ func (scope *Scope) Quote(str string) string { newStrs = append(newStrs, scope.Dialect().Quote(str)) } return strings.Join(newStrs, ".") - } else { - return scope.Dialect().Quote(str) } + + return scope.Dialect().Quote(str) } -func (scope *Scope) QuoteIfPossible(str string) string { +func (scope *Scope) quoteIfPossible(str string) string { if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { return scope.Quote(str) } @@ -100,7 +85,7 @@ func (scope *Scope) Dialect() Dialect { return scope.db.parent.dialect } -// Err write error +// Err add error to Scope func (scope *Scope) Err(err error) error { if err != nil { scope.db.AddError(err) @@ -118,27 +103,30 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } -func (scope *Scope) PrimaryFields() []*Field { - var fields = []*Field{} - for _, field := range scope.GetModelStruct().PrimaryFields { - fields = append(fields, scope.Fields()[field.DBName]) +// PrimaryFields return scope's primary fields +func (scope *Scope) PrimaryFields() (fields []*Field) { + for _, field := range scope.Fields() { + if field.IsPrimaryKey { + fields = append(fields, field) + } } return fields } +// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { - if field, ok := scope.Fields()["id"]; ok { + if field, ok := scope.FieldByName("id"); ok { return field } } - return scope.Fields()[primaryFields[0].DBName] + return scope.PrimaryFields()[0] } return nil } -// PrimaryKey get the primary key's column name +// PrimaryKey get main primary field's db name func (scope *Scope) PrimaryKey() string { if field := scope.PrimaryField(); field != nil { return field.DBName @@ -146,7 +134,7 @@ func (scope *Scope) PrimaryKey() string { return "" } -// PrimaryKeyZero check the primary key is blank or not +// PrimaryKeyZero check main primary field's value is blank or not func (scope *Scope) PrimaryKeyZero() bool { field := scope.PrimaryField() return field == nil || field.IsBlank @@ -170,80 +158,85 @@ func (scope *Scope) HasColumn(column string) bool { return false } -// SetColumn to set the column's value +// SetColumn to set the column's value, column could be field or field's name/dbname func (scope *Scope) SetColumn(column interface{}, value interface{}) error { + var updateAttrs = map[string]interface{}{} + if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { + updateAttrs = attrs.(map[string]interface{}) + defer scope.InstanceSet("gorm:update_attrs", updateAttrs) + } + if field, ok := column.(*Field); ok { + updateAttrs[field.DBName] = value return field.Set(value) } else if name, ok := column.(string); ok { - - if field, ok := scope.Fields()[name]; ok { - return field.Set(value) + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + for _, field := range scope.Fields() { + if field.DBName == value { + updateAttrs[field.DBName] = value + return field.Set(value) + } + if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + mostMatchedField = field + } } - dbName := ToDBName(name) - if field, ok := scope.Fields()[dbName]; ok { - return field.Set(value) - } - - if field, ok := scope.FieldByName(name); ok { - return field.Set(value) + if mostMatchedField != nil { + updateAttrs[mostMatchedField.DBName] = value + return mostMatchedField.Set(value) } } return errors.New("could not convert column to field") } -func (scope *Scope) CallMethod(name string, checkError bool) { - if scope.Value == nil || (checkError && scope.HasError()) { +func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { + if reflectValue.CanAddr() { + reflectValue = reflectValue.Addr() + } + + if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { + switch method := methodValue.Interface().(type) { + case func(): + method() + case func(*Scope): + method(scope) + case func(*DB): + newDB := scope.NewDB() + method(newDB) + scope.Err(newDB.Error) + case func() error: + scope.Err(method()) + case func(*Scope) error: + scope.Err(method(scope)) + case func(*DB) error: + newDB := scope.NewDB() + scope.Err(method(newDB)) + scope.Err(newDB.Error) + default: + scope.Err(fmt.Errorf("unsupported function %v", methodName)) + } + } +} + +// CallMethod call scope value's method, if it is a slice, will call its element's method one by one +func (scope *Scope) CallMethod(methodName string) { + if scope.Value == nil { return } - call := func(value interface{}) { - if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() { - switch f := fm.Interface().(type) { - case func(): - f() - case func(s *Scope): - f(scope) - case func(s *DB): - newDB := scope.NewDB() - f(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(f()) - case func(s *Scope) error: - scope.Err(f(scope)) - case func(s *DB) error: - newDB := scope.NewDB() - scope.Err(f(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", name)) - } - } - } - - if values := scope.IndirectValue(); values.Kind() == reflect.Slice { - for i := 0; i < values.Len(); i++ { - value := values.Index(i).Addr().Interface() - if values.Index(i).Kind() == reflect.Ptr { - value = values.Index(i).Interface() - } - call(value) + if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { + for i := 0; i < indirectScopeValue.Len(); i++ { + scope.callMethod(methodName, indirectScopeValue.Index(i)) } } else { - if scope.IndirectValue().CanAddr() { - call(scope.IndirectValue().Addr().Interface()) - } else { - call(scope.IndirectValue().Interface()) - } + scope.callMethod(methodName, indirectScopeValue) } } -func (scope *Scope) CallMethodWithErrorCheck(name string) { - scope.CallMethod(name, true) -} - -// AddToVars add value as sql's vars, gorm will escape them +// AddToVars add value as sql's vars, used to prevent SQL injection func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr @@ -251,10 +244,10 @@ func (scope *Scope) AddToVars(value interface{}) string { exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) } return exp - } else { - scope.SqlVars = append(scope.SqlVars, value) - return scope.Dialect().BinVar(len(scope.SqlVars)) } + + scope.SQLVars = append(scope.SQLVars, value) + return scope.Dialect().BindVar(len(scope.SQLVars)) } type tabler interface { @@ -265,7 +258,7 @@ type dbTabler interface { TableName(*DB) string } -// TableName get table name +// TableName return table name func (scope *Scope) TableName() string { if scope.Search != nil && len(scope.Search.tableName) > 0 { return scope.Search.tableName @@ -282,44 +275,54 @@ func (scope *Scope) TableName() string { return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) } +// QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { if strings.Index(scope.Search.tableName, " ") != -1 { return scope.Search.tableName } return scope.Quote(scope.Search.tableName) - } else { - return scope.Quote(scope.TableName()) } + + return scope.Quote(scope.TableName()) } -// CombinedConditionSql get combined condition sql +// CombinedConditionSql return combined condition sql func (scope *Scope) CombinedConditionSql() string { - return scope.joinsSql() + scope.whereSql() + scope.groupSql() + - scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql() + return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + + scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() } +// FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { + var ( + dbName = ToDBName(name) + mostMatchedField *Field + ) + for _, field := range scope.Fields() { if field.Name == name || field.DBName == name { return field, true } + if field.DBName == dbName { + mostMatchedField = field + } } - return nil, false + return mostMatchedField, mostMatchedField != nil } -// Raw set sql +// Raw set raw sql func (scope *Scope) Raw(sql string) *Scope { - scope.Sql = strings.Replace(sql, "$$", "?", -1) + scope.SQL = strings.Replace(sql, "$$", "?", -1) return scope } -// Exec invoke sql +// Exec perform generated SQL func (scope *Scope) Exec() *Scope { - defer scope.Trace(NowFunc()) + defer scope.trace(NowFunc()) if !scope.HasError() { - if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } @@ -334,37 +337,32 @@ func (scope *Scope) Set(name string, value interface{}) *Scope { return scope } -// Get get value by name +// Get get setting by name func (scope *Scope) Get(name string) (interface{}, bool) { return scope.db.Get(name) } -// InstanceId get InstanceId for scope -func (scope *Scope) InstanceId() string { - if scope.instanceId == "" { - scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db) +// InstanceID get InstanceID for scope +func (scope *Scope) InstanceID() string { + if scope.instanceID == "" { + scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) } - return scope.instanceId + return scope.instanceID } +// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceId(), value) + return scope.Set(name+scope.InstanceID(), value) } +// InstanceGet get instance setting from current operation func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceId()) -} - -// Trace print sql log -func (scope *Scope) Trace(t time.Time) { - if len(scope.Sql) > 0 { - scope.db.slog(scope.Sql, t, scope.SqlVars...) - } + return scope.Get(name + scope.InstanceID()) } // Begin start a transaction func (scope *Scope) Begin() *Scope { - if db, ok := scope.SqlDB().(sqlDb); ok { + if db, ok := scope.SQLDB().(sqlDb); ok { if tx, err := db.Begin(); err == nil { scope.db.db = interface{}(tx).(sqlCommon) scope.InstanceSet("gorm:started_transaction", true) @@ -373,7 +371,7 @@ func (scope *Scope) Begin() *Scope { return scope } -// CommitOrRollback commit current transaction if there is no error, otherwise rollback it +// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it func (scope *Scope) CommitOrRollback() *Scope { if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { if db, ok := scope.db.db.(sqlTx); ok { @@ -388,6 +386,7 @@ func (scope *Scope) CommitOrRollback() *Scope { return scope } +// SelectAttrs return selected attributes func (scope *Scope) SelectAttrs() []string { if scope.selectAttrs == nil { attrs := []string{} @@ -407,57 +406,38 @@ func (scope *Scope) SelectAttrs() []string { return *scope.selectAttrs } +// OmitAttrs return omited attributes func (scope *Scope) OmitAttrs() []string { return scope.Search.omits } -func (scope *Scope) changeableDBColumn(column string) bool { - selectAttrs := scope.SelectAttrs() - omitAttrs := scope.OmitAttrs() +func (scope *Scope) scan(rows *sql.Rows, columns []string, fieldsMap map[string]*Field) { + var values = make([]interface{}, len(columns)) + var ignored interface{} - if len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if column == ToDBName(attr) { - return true + for index, column := range columns { + if field, ok := fieldsMap[column]; ok { + if field.Field.Kind() == reflect.Ptr { + values[index] = field.Field.Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) + reflectValue.Elem().Set(field.Field.Addr()) + values[index] = reflectValue.Interface() + } + } else { + values[index] = &ignored + } + } + + scope.Err(rows.Scan(values...)) + + for index, column := range columns { + if field, ok := fieldsMap[column]; ok { + if field.Field.Kind() != reflect.Ptr { + if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { + field.Field.Set(v) + } } } - return false } - - for _, attr := range omitAttrs { - if column == ToDBName(attr) { - return false - } - } - return true -} - -func (scope *Scope) changeableField(field *Field) bool { - selectAttrs := scope.SelectAttrs() - omitAttrs := scope.OmitAttrs() - - if len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range omitAttrs { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return !field.IsIgnored -} - -func (scope *Scope) shouldSaveAssociations() bool { - saveAssociations, ok := scope.Get("gorm:save_associations") - if ok && !saveAssociations.(bool) { - return false - } - return true && !scope.HasError() } diff --git a/scope_private.go b/scope_private.go index 153392c5..c491cc7a 100644 --- a/scope_private.go +++ b/scope_private.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "time" ) func (scope *Scope) primaryCondition(value interface{}) string { @@ -75,7 +76,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSql string + var notEqualSQL string var primaryKey = scope.PrimaryKey() switch value := clause["query"].(type) { @@ -86,10 +87,10 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSql = fmt.Sprintf("NOT (%v)", value) + notEqualSQL = fmt.Sprintf("NOT (%v)", value) } else { str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) @@ -138,7 +139,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string if scanner, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = scanner.Value() } - str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1) + str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) } } return @@ -172,17 +173,20 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) return } -func (scope *Scope) whereSql() (sql string) { - var primaryConditions, andConditions, orConditions []string +func (scope *Scope) whereSQL() (sql string) { + var ( + quotedTableName = scope.QuotedTableName() + primaryConditions, andConditions, orConditions []string + ) - if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil { - sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName()) + if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { + sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) primaryConditions = append(primaryConditions, sql) } if !scope.PrimaryKeyZero() { for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) + sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) primaryConditions = append(primaryConditions, sql) } } @@ -205,30 +209,30 @@ func (scope *Scope) whereSql() (sql string) { } } - orSql := strings.Join(orConditions, " OR ") - combinedSql := strings.Join(andConditions, " AND ") - if len(combinedSql) > 0 { - if len(orSql) > 0 { - combinedSql = combinedSql + " OR " + orSql + orSQL := strings.Join(orConditions, " OR ") + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) > 0 { + if len(orSQL) > 0 { + combinedSQL = combinedSQL + " OR " + orSQL } } else { - combinedSql = orSql + combinedSQL = orSQL } if len(primaryConditions) > 0 { sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSql) > 0 { - sql = sql + " AND (" + combinedSql + ")" + if len(combinedSQL) > 0 { + sql = sql + " AND (" + combinedSQL + ")" } - } else if len(combinedSql) > 0 { - sql = "WHERE " + combinedSql + } else if len(combinedSQL) > 0 { + sql = "WHERE " + combinedSQL } return } -func (scope *Scope) selectSql() string { +func (scope *Scope) selectSQL() string { if len(scope.Search.selects) == 0 { - if scope.Search.joins != "" { + if len(scope.Search.joinConditions) > 0 { return fmt.Sprintf("%v.*", scope.QuotedTableName()) } return "*" @@ -236,87 +240,60 @@ func (scope *Scope) selectSql() string { return scope.buildSelectQuery(scope.Search.selects) } -func (scope *Scope) orderSql() string { +func (scope *Scope) orderSQL() string { if len(scope.Search.orders) == 0 || scope.Search.countingQuery { return "" } return " ORDER BY " + strings.Join(scope.Search.orders, ",") } -func (scope *Scope) limitSql() string { - if !scope.Dialect().HasTop() { - if len(scope.Search.limit) == 0 { - return "" - } - return " LIMIT " + scope.Search.limit - } - - return "" +func (scope *Scope) limitAndOffsetSQL() string { + return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) } -func (scope *Scope) topSql() string { - if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 { - if len(scope.Search.limit) == 0 { - return "" - } - return " TOP(" + scope.Search.limit + ")" - } - - return "" -} - -func (scope *Scope) offsetSql() string { - if len(scope.Search.offset) == 0 { - return "" - } - - if scope.Dialect().HasTop() { - sql := " OFFSET " + scope.Search.offset + " ROW " - if len(scope.Search.limit) > 0 { - sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY" - } - return sql - } - return " OFFSET " + scope.Search.offset -} - -func (scope *Scope) groupSql() string { +func (scope *Scope) groupSQL() string { if len(scope.Search.group) == 0 { return "" } return " GROUP BY " + scope.Search.group } -func (scope *Scope) havingSql() string { - if scope.Search.havingConditions == nil { +func (scope *Scope) havingSQL() string { + if len(scope.Search.havingConditions) == 0 { return "" } var andConditions []string - for _, clause := range scope.Search.havingConditions { if sql := scope.buildWhereCondition(clause); sql != "" { andConditions = append(andConditions, sql) } } - combinedSql := strings.Join(andConditions, " AND ") - if len(combinedSql) == 0 { + combinedSQL := strings.Join(andConditions, " AND ") + if len(combinedSQL) == 0 { return "" } - return " HAVING " + combinedSql + return " HAVING " + combinedSQL } -func (scope *Scope) joinsSql() string { - return scope.Search.joins + " " +func (scope *Scope) joinsSQL() string { + var joinConditions []string + for _, clause := range scope.Search.joinConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) + } + } + + return strings.Join(joinConditions, " ") + " " } -func (scope *Scope) prepareQuerySql() { +func (scope *Scope) prepareQuerySQL() { if scope.Search.raw { scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) } else { - scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) } return } @@ -338,61 +315,53 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) { - if !scope.IndirectValue().CanAddr() { +func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}) (results map[string]interface{}, hasUpdate bool) { + if scope.IndirectValue().Kind() != reflect.Struct { return values, true } - var hasExpr bool + results = map[string]interface{}{} for key, value := range values { - if field, ok := scope.FieldByName(key); ok && field.Field.IsValid() { + if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) { if _, ok := value.(*expr); ok { - hasExpr = true - } else if !equalAsString(field.Field.Interface(), value) { hasUpdate = true + results[field.DBName] = value + } else if !equalAsString(field.Field.Interface(), value) { field.Set(value) + if field.IsNormal { + hasUpdate = true + results[field.DBName] = field.Field.Interface() + } } + } else { + field.Set(value) } } } - - if hasExpr { - var updateMap = map[string]interface{}{} - for key, field := range scope.Fields() { - if field.IsNormal { - if v, ok := values[key]; ok { - updateMap[key] = v - } else { - updateMap[key] = field.Field.Interface() - } - } - } - return updateMap, true - } return } func (scope *Scope) row() *sql.Row { - defer scope.Trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callback.rowQueries) - scope.prepareQuerySql() - return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) } func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.Trace(NowFunc()) - scope.callCallbacks(scope.db.parent.callback.rowQueries) - scope.prepareQuerySql() - return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) + defer scope.trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callbacks.rowQueries) + scope.prepareQuerySQL() + return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) } func (scope *Scope) initialize() *Scope { for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"])) } - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false) - scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs)) + scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs)) return scope } @@ -433,23 +402,45 @@ func (scope *Scope) typeName() string { return typ.Name() } +// trace print sql log +func (scope *Scope) trace(t time.Time) { + if len(scope.SQL) > 0 { + scope.db.slog(scope.SQL, t, scope.SQLVars...) + } +} + +func (scope *Scope) changeableField(field *Field) bool { + if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { + for _, attr := range selectAttrs { + if field.Name == attr || field.DBName == attr { + return true + } + } + return false + } + + for _, attr := range scope.OmitAttrs() { + if field.Name == attr || field.DBName == attr { + return false + } + } + + return true +} + +func (scope *Scope) shouldSaveAssociations() bool { + if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { + return false + } + return true && !scope.HasError() +} + func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) - fromFields := scope.Fields() - toFields := toScope.Fields() for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - var fromField, toField *Field - if field, ok := scope.FieldByName(foreignKey); ok { - fromField = field - } else { - fromField = fromFields[ToDBName(foreignKey)] - } - if field, ok := toScope.FieldByName(foreignKey); ok { - toField = field - } else { - toField = toFields[ToDBName(foreignKey)] - } + fromField, _ := scope.FieldByName(foreignKey) + toField, _ := toScope.FieldByName(foreignKey) if fromField != nil { if relationship := fromField.Relationship; relationship != nil { @@ -508,30 +499,26 @@ func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(scope, joinTable) { + if !scope.Dialect().HasTable(joinTable) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes, primaryKeys []string for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.Fields()[fieldName]; ok { - value := reflect.Indirect(reflect.New(field.Struct.Type)) - primaryKeySqlType := field.TagSettings["TYPE"] - if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) - } - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) + if field, ok := scope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } } for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.Fields()[fieldName]; ok { - value := reflect.Indirect(reflect.New(field.Struct.Type)) - primaryKeySqlType := field.TagSettings["TYPE"] - if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().SqlTag(value, 255, false) - } - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) + if field, ok := toScope.FieldByName(fieldName); ok { + foreignKeyStruct := field.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } } @@ -545,10 +532,10 @@ func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createTable() *Scope { var tags []string var primaryKeys []string - var primaryKeyInColumnType bool = false - for _, field := range scope.GetStructFields() { + var primaryKeyInColumnType = false + for _, field := range scope.GetModelStruct().StructFields { if field.IsNormal { - sqlTag := scope.generateSqlTag(field) + sqlTag := scope.Dialect().DataTypeOf(field) // Check if the primary key constraint was specified as // part of the column type. If so, we can only support @@ -582,13 +569,6 @@ func (scope *Scope) dropTable() *Scope { return scope } -func (scope *Scope) dropTableIfExists() *Scope { - if scope.Dialect().HasTable(scope, scope.TableName()) { - scope.dropTable() - } - return scope -} - func (scope *Scope) modifyColumn(column string, typ string) { scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() } @@ -598,13 +578,13 @@ func (scope *Scope) dropColumn(column string) { } func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) { + if scope.Dialect().HasIndex(scope.TableName(), indexName) { return } var columns []string for _, name := range column { - columns = append(columns, scope.QuoteIfPossible(name)) + columns = append(columns, scope.quoteIfPossible(name)) } sqlCreate := "CREATE INDEX" @@ -612,31 +592,35 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { sqlCreate = "CREATE UNIQUE INDEX" } - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSql())).Exec() + scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + + if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), dest, onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope, indexName) + scope.Dialect().RemoveIndex(scope.TableName(), indexName) } func (scope *Scope) autoMigrate() *Scope { tableName := scope.TableName() quotedTableName := scope.QuotedTableName() - if !scope.Dialect().HasTable(scope, tableName) { + if !scope.Dialect().HasTable(tableName) { scope.createTable() } else { - for _, field := range scope.GetStructFields() { - if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { + for _, field := range scope.GetModelStruct().StructFields { + if !scope.Dialect().HasColumn(tableName, field.DBName) { if field.IsNormal { - sqlTag := scope.generateSqlTag(field) + sqlTag := scope.Dialect().DataTypeOf(field) scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() } } diff --git a/scope_utils.go b/scope_utils.go new file mode 100644 index 00000000..2d914314 --- /dev/null +++ b/scope_utils.go @@ -0,0 +1,67 @@ +package gorm + +import "reflect" + +func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + for _, value := range values { + indirectValue := reflect.ValueOf(value) + for indirectValue.Kind() == reflect.Ptr { + indirectValue = indirectValue.Elem() + } + + switch indirectValue.Kind() { + case reflect.Slice: + for i := 0; i < indirectValue.Len(); i++ { + var result []interface{} + var object = indirect(indirectValue.Index(i)) + for _, column := range columns { + result = append(result, object.FieldByName(column).Interface()) + } + results = append(results, result) + } + case reflect.Struct: + var result []interface{} + for _, column := range columns { + result = append(result, indirectValue.FieldByName(column).Interface()) + } + results = append(results, result) + } + } + return +} + +func (scope *Scope) getColumnAsScope(column string) *Scope { + indirectScopeValue := scope.IndirectValue() + + switch indirectScopeValue.Kind() { + case reflect.Slice: + if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { + fieldType := fieldStruct.Type + if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() + + for i := 0; i < indirectScopeValue.Len(); i++ { + result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) + + if result.Kind() == reflect.Slice { + for j := 0; j < result.Len(); j++ { + if elem := result.Index(j); elem.CanAddr() { + results = reflect.Append(results, elem.Addr()) + } + } + } else if result.CanAddr() { + results = reflect.Append(results, result.Addr()) + } + } + return scope.New(results.Interface()) + } + case reflect.Struct: + if field := indirectScopeValue.FieldByName(column); field.CanAddr() { + return scope.New(field.Addr().Interface()) + } + } + return nil +} diff --git a/search.go b/search.go index 166b9a86..078bd429 100644 --- a/search.go +++ b/search.go @@ -8,15 +8,15 @@ type search struct { orConditions []map[string]interface{} notConditions []map[string]interface{} havingConditions []map[string]interface{} + joinConditions []map[string]interface{} initAttrs []interface{} assignAttrs []interface{} selects map[string]interface{} omits []string orders []string - joins string preload []searchPreload - offset string - limit string + offset int + limit int group string tableName string raw bool @@ -82,18 +82,18 @@ func (s *search) Omit(columns ...string) *search { return s } -func (s *search) Limit(value interface{}) *search { - s.limit = s.getInterfaceAsSql(value) +func (s *search) Limit(limit int) *search { + s.limit = limit return s } -func (s *search) Offset(value interface{}) *search { - s.offset = s.getInterfaceAsSql(value) +func (s *search) Offset(offset int) *search { + s.offset = offset return s } func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSql(query) + s.group = s.getInterfaceAsSQL(query) return s } @@ -102,8 +102,8 @@ func (s *search) Having(query string, values ...interface{}) *search { return s } -func (s *search) Joins(query string) *search { - s.joins = query +func (s *search) Joins(query string, values ...interface{}) *search { + s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) return s } @@ -134,12 +134,12 @@ func (s *search) Table(name string) *search { return s } -func (s *search) getInterfaceAsSql(value interface{}) (str string) { +func (s *search) getInterfaceAsSQL(value interface{}) (str string) { switch value.(type) { case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: str = fmt.Sprintf("%v", value) default: - s.db.AddError(InvalidSql) + s.db.AddError(ErrInvalidSQL) } if str == "-1" { diff --git a/sqlite3.go b/sqlite3.go deleted file mode 100644 index d052d2c1..00000000 --- a/sqlite3.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { - case reflect.Bool: - return "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { - return "integer primary key autoincrement" - } - return "integer" - case reflect.Int64, reflect.Uint64: - if autoIncrease { - return "integer primary key autoincrement" - } - return "bigint" - case reflect.Float32, reflect.Float64: - return "real" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { - return "datetime" - } - default: - if _, ok := value.Interface().([]byte); ok { - return "blob" - } - } - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) -} - -func (s sqlite3) HasTable(scope *Scope, tableName string) bool { - var count int - s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName) - return count > 0 -} - -func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName) - return count > 0 -} - -func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName) - return count > 0 -} - -func (sqlite3) RemoveIndex(scope *Scope, indexName string) { - scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error) -} - -func (sqlite3) CurrentDatabase(scope *Scope) (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/structs_test.go b/structs_test.go index cb9c9260..e7bae25f 100644 --- a/structs_test.go +++ b/structs_test.go @@ -42,9 +42,9 @@ type CreditCard struct { ID int8 Number string UserId sql.NullInt64 - CreatedAt time.Time + CreatedAt time.Time `sql:"not null"` UpdatedAt time.Time - DeletedAt time.Time + DeletedAt *time.Time } type Email struct { @@ -62,7 +62,7 @@ type Address struct { Post string CreatedAt time.Time UpdatedAt time.Time - DeletedAt time.Time + DeletedAt *time.Time } type Language struct { diff --git a/update_test.go b/update_test.go index 75877488..218c5834 100644 --- a/update_test.go +++ b/update_test.go @@ -71,13 +71,14 @@ func TestUpdate(t *testing.T) { } DB.First(&product4, product4.Id) + updatedAt4 := product4.UpdatedAt DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) var product5 Product DB.First(&product5, product4.Id) if product5.Price != product4.Price+100-50 { t.Errorf("Update with expression") } - if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { t.Errorf("Update with expression should update UpdatedAt") } } @@ -170,13 +171,15 @@ func TestUpdates(t *testing.T) { t.Errorf("product2's code should be updated") } + updatedAt4 := product4.UpdatedAt DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) var product5 Product DB.First(&product5, product4.Id) if product5.Price != product4.Price+100 { t.Errorf("Updates with expression") } - if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) { + // product4's UpdatedAt will be reset when updating + if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { t.Errorf("Updates with expression should update UpdatedAt") } } @@ -419,3 +422,32 @@ func TestUpdateColumnsSkipsAssociations(t *testing.T) { t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) } } + +func TestUpdatesWithBlankValues(t *testing.T) { + product := Product{Code: "product1", Price: 10} + DB.Save(&product) + + DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) + + var product1 Product + DB.First(&product1, product.Id) + + if product1.Code != "product1" || product1.Price != 100 { + t.Errorf("product's code should not be updated") + } +} + +func TestUpdateDecodeVirtualAttributes(t *testing.T) { + var user = User{ + Name: "jinzhu", + IgnoreMe: 88, + } + + DB.Save(&user) + + DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) + + if user.IgnoreMe != 100 { + t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") + } +} diff --git a/utils.go b/utils.go index b457f321..af11f5d2 100644 --- a/utils.go +++ b/utils.go @@ -2,10 +2,26 @@ package gorm import ( "bytes" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "runtime" "strings" "sync" + "time" ) +// NowFunc returns current time, this function is exported in order to be able +// to give the flexibility to the developer to customize it according to their +// needs, e.g: +// gorm.NowFunc = func() time.Time { +// return time.Now().UTC() +// } +var NowFunc = func() time.Time { + return time.Now() +} + // Copied from golint var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer @@ -41,30 +57,239 @@ func newSafeMap() *safeMap { var smap = newSafeMap() +type strCase bool + +const ( + lower strCase = false + upper strCase = true +) + +// ToDBName convert string to db name func ToDBName(name string) string { if v := smap.Get(name); v != "" { return v } - value := commonInitialismsReplacer.Replace(name) - buf := bytes.NewBufferString("") - for i, v := range value { - if i > 0 && v >= 'A' && v <= 'Z' { - buf.WriteRune('_') - } - buf.WriteRune(v) + if name == "" { + return "" } + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase strCase + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] >= 'A' && value[i+1] <= 'Z' + if i > 0 { + if currCase == upper { + if lastCase == upper && nextCase == upper { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + s := strings.ToLower(buf.String()) smap.Set(name, s) return s } +// SQL expression type expr struct { expr string args []interface{} } +// Expr generate raw SQL expression, for example: +// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) func Expr(expression string, args ...interface{}) *expr { return &expr{expr: expression, args: args} } + +func indirect(reflectValue reflect.Value) reflect.Value { + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + return reflectValue +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for _ = range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } + return strings.Join(newColumns, ",") +} + +func toQueryValues(values [][]interface{}) (results []interface{}) { + for _, value := range values { + for _, v := range value { + results = append(results, v) + } + } + return +} + +func fileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + return fmt.Sprintf("%v:%v", file, line) + } + } + return "" +} + +func isBlank(value reflect.Value) bool { + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) +} + +func toSearchableMap(attrs ...interface{}) (result interface{}) { + if len(attrs) > 1 { + if str, ok := attrs[0].(string); ok { + result = map[string]interface{}{str: attrs[1]} + } + } else if len(attrs) == 1 { + if attr, ok := attrs[0].(map[string]interface{}); ok { + result = attr + } + + if attr, ok := attrs[0].(interface{}); ok { + result = attr + } + } + return +} + +func convertInterfaceToMap(values interface{}) map[string]interface{} { + attrs := map[string]interface{}{} + + switch value := values.(type) { + case map[string]interface{}: + return value + case []interface{}: + for _, v := range value { + for key, value := range convertInterfaceToMap(v) { + attrs[key] = value + } + } + case interface{}: + reflectValue := reflect.ValueOf(values) + + switch reflectValue.Kind() { + case reflect.Map: + for _, key := range reflectValue.MapKeys() { + attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + } + default: + for _, field := range (&Scope{Value: values}).Fields() { + if !field.IsBlank { + attrs[field.DBName] = field.Field.Interface() + } + } + } + } + return attrs +} + +func equalAsString(a interface{}, b interface{}) bool { + return toString(a) == toString(b) +} + +func toString(str interface{}) string { + if values, ok := str.([]interface{}); ok { + var results []string + for _, value := range values { + results = append(results, toString(value)) + } + return strings.Join(results, "_") + } else if bytes, ok := str.([]byte); ok { + return string(bytes) + } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { + return fmt.Sprintf("%v", reflectValue.Interface()) + } + return "" +} + +func makeSlice(elemType reflect.Type) interface{} { + if elemType.Kind() == reflect.Slice { + elemType = elemType.Elem() + } + sliceType := reflect.SliceOf(elemType) + slice := reflect.New(sliceType) + slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) + return slice.Interface() +} + +func strInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// getValueFromFields return given fields's value +func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { + // If value is a nil pointer, Indirect returns a zero Value! + // Therefor we need to check for a zero value, + // as FieldByName could panic + if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { + for _, fieldName := range fieldNames { + if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + result := fieldValue.Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } + } + } + return +} + +func addExtraSpaceIfExist(str string) string { + if str != "" { + return " " + str + } + return "" +} diff --git a/utils_private.go b/utils_private.go deleted file mode 100644 index f297857b..00000000 --- a/utils_private.go +++ /dev/null @@ -1,98 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "runtime" - "strings" -) - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func convertInterfaceToMap(values interface{}) map[string]interface{} { - attrs := map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - for k, v := range value { - attrs[ToDBName(k)] = v - } - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - scope := Scope{Value: values} - for _, field := range scope.Fields() { - if !field.IsBlank && !field.IsIgnored { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..07f5b17f --- /dev/null +++ b/utils_test.go @@ -0,0 +1,30 @@ +package gorm_test + +import ( + "testing" + + "github.com/jinzhu/gorm" +) + +func TestToDBNameGenerateFriendlyName(t *testing.T) { + var maps = map[string]string{ + "": "", + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "SKU_ID": "sku_id", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + for key, value := range maps { + if gorm.ToDBName(key) != value { + t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) + } + } +}