diff --git a/callbacks/create.go b/callbacks/create.go index 3f6a81e4..97a2832c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -194,13 +194,13 @@ func AfterCreate(db *gorm.DB) { func ConvertToCreateValues(stmt *gorm.Statement) clause.Values { switch value := stmt.Dest.(type) { case map[string]interface{}: - return ConvertMapToValues(stmt, value) + return ConvertMapToValuesForCreate(stmt, value) case []map[string]interface{}: - return ConvertSliceOfMapToValues(stmt, value) + return ConvertSliceOfMapToValuesForCreate(stmt, value) default: var ( values = clause.Values{} - selectColumns, restricted = SelectAndOmitColumns(stmt) + selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) curTime = stmt.DB.NowFunc() isZero = false ) diff --git a/callbacks/helper.go b/callbacks/helper.go index 0dd6ff43..8a69fbd1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -8,7 +8,7 @@ import ( ) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { +func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} // select columns @@ -36,13 +36,23 @@ func SelectAndOmitColumns(stmt *gorm.Statement) (map[string]bool, bool) { } } + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByDBName { + if requireCreate && !field.Creatable { + results[field.DBName] = false + } else if requireUpdate && !field.Updatable { + results[field.DBName] = false + } + } + } + return results, len(stmt.Selects) > 0 } -// ConvertMapToValues convert map to values -func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValue)) - selectColumns, restricted := SelectAndOmitColumns(stmt) + selectColumns, restricted := SelectAndOmitColumns(stmt, true, false) var keys []string for k, _ := range mapValue { @@ -64,12 +74,12 @@ func ConvertMapToValues(stmt *gorm.Statement, mapValue map[string]interface{}) ( return } -// ConvertSliceOfMapToValues convert slice of map to values -func ConvertSliceOfMapToValues(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { var ( columns = []string{} result = map[string][]interface{}{} - selectColumns, restricted = SelectAndOmitColumns(stmt) + selectColumns, restricted = SelectAndOmitColumns(stmt, true, false) ) for idx, mapValue := range mapValues { diff --git a/callbacks/scan.go b/callbacks/scan.go index f8f1ef54..2bd0143c 100644 --- a/callbacks/scan.go +++ b/callbacks/scan.go @@ -56,7 +56,11 @@ func Scan(rows *sql.Rows, db *gorm.DB) { fields := make([]*schema.Field, len(columns)) for idx, column := range columns { - fields[idx] = db.Statement.Schema.LookUpField(column) + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else { + values[idx] = sql.RawBytes{} + } } for rows.Next() { @@ -80,7 +84,7 @@ func Scan(rows *sql.Rows, db *gorm.DB) { } case reflect.Struct: for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } else { values[idx] = sql.RawBytes{} diff --git a/callbacks/update.go b/callbacks/update.go index eab9f929..53c646e9 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -91,7 +91,7 @@ func AfterUpdate(db *gorm.DB) { // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { - selectColumns, restricted := SelectAndOmitColumns(stmt) + selectColumns, restricted := SelectAndOmitColumns(stmt, false, true) reflectModelValue := reflect.ValueOf(stmt.Model) switch value := stmt.Dest.(type) { diff --git a/schema/field.go b/schema/field.go index ee1baf3c..a8e55acd 100644 --- a/schema/field.go +++ b/schema/field.go @@ -42,6 +42,7 @@ type Field struct { AutoIncrement bool Creatable bool Updatable bool + Readable bool HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType @@ -73,6 +74,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { StructField: fieldStruct, Creatable: true, Updatable: true, + Readable: true, Tag: fieldStruct.Tag, TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), Schema: schema, @@ -117,6 +119,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if _, ok := field.TagSettings["-"]; ok { field.Creatable = false field.Updatable = false + field.Readable = false + } + + if v, ok := field.TagSettings["<-"]; ok { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + + if _, ok := field.TagSettings["->"]; ok { + field.Readable = false } if dbName, ok := field.TagSettings["COLUMN"]; ok { @@ -235,6 +252,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var err error field.Creatable = false field.Updatable = false + field.Readable = false if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { schema.err = err }