package gorm import ( "errors" "fmt" "reflect" "strings" ) func (association *Association) setErr(err error) *Association { if err != nil { association.Error = err } return association } func (association *Association) saveAssociations(values ...interface{}) *Association { scope := association.Scope field := association.Field relationship := association.Field.Relationship saveAssociation := func(reflectValue reflect.Value) { // value has to been pointer if reflectValue.Kind() != reflect.Ptr { reflectPtr := reflect.New(reflectValue.Type()) reflectPtr.Elem().Set(reflectValue) reflectValue = reflectPtr } // value has to been saved for many2many if relationship.Kind == "many_to_many" { if scope.New(reflectValue.Interface()).PrimaryKeyZero() { association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) } } // Assign Fields var fieldType = field.Field.Type() var setFieldBackToValue, setSliceFieldBackToValue bool if reflectValue.Type().AssignableTo(fieldType) { field.Set(reflectValue) } else if reflectValue.Type().Elem().AssignableTo(fieldType) { // if field's type is struct, then need to set value back to argument after save setFieldBackToValue = true field.Set(reflectValue.Elem()) } else if fieldType.Kind() == reflect.Slice { if reflectValue.Type().AssignableTo(fieldType.Elem()) { field.Set(reflect.Append(field.Field, reflectValue)) } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { // if field's type is slice of struct, then need to set value back to argument after save setSliceFieldBackToValue = true field.Set(reflect.Append(field.Field, reflectValue.Elem())) } } if relationship.Kind == "many_to_many" { association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) } else { association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) if setFieldBackToValue { reflectValue.Elem().Set(field.Field) } else if setSliceFieldBackToValue { reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) } } } for _, value := range values { reflectValue := reflect.ValueOf(value) indirectReflectValue := reflect.Indirect(reflectValue) if indirectReflectValue.Kind() == reflect.Struct { saveAssociation(reflectValue) } else if indirectReflectValue.Kind() == reflect.Slice { for i := 0; i < indirectReflectValue.Len(); i++ { saveAssociation(indirectReflectValue.Index(i)) } } else { association.setErr(errors.New("invalid value type")) } } return association } func toQueryMarks(primaryValues [][]interface{}) string { var results []string for _, primaryValue := range primaryValues { var marks []string for _ = range primaryValue { marks = append(marks, "?") } if len(marks) > 1 { results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) } else { results = append(results, strings.Join(marks, "")) } } return strings.Join(results, ",") } func toQueryCondition(scope *Scope, columns []string) string { var newColumns []string for _, column := range columns { newColumns = append(newColumns, scope.Quote(column)) } if len(columns) > 1 { return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) } return strings.Join(newColumns, ",") } func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { for _, primaryValue := range primaryValues { for _, value := range primaryValue { values = append(values, value) } } return values }