package callbacks

import (
	"reflect"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// SaveBeforeAssociations saves the belongs-to associations of the current
// statement's value(s) before the owning record itself is written, then copies
// each saved association's primary key into the owner's foreign key field(s).
// When the statement's Dest is a map[string]interface{}, the foreign key value
// (and the association entry, if the map already contains one) is mirrored into
// that map as well.
//
// Relationships excluded via Select/Omit are skipped. The callback does nothing
// if db already holds an error or the statement has no schema.
func SaveBeforeAssociations(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}

	selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)

	// Save Belongs To associations
	for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
		if selected, listed := selectColumns[rel.Name]; (listed && !selected) || (!listed && restricted) {
			continue
		}

		// linkOwner copies the primary key of the saved association `assoc`
		// into the foreign key field(s) of `owner`.
		linkOwner := func(owner, assoc reflect.Value) {
			for _, ref := range rel.References {
				if ref.OwnPrimaryKey {
					continue
				}

				pk, _ := ref.PrimaryKey.ValueOf(assoc)
				db.AddError(ref.ForeignKey.Set(owner, pk))

				destMap, isMap := db.Statement.Dest.(map[string]interface{})
				if !isMap {
					continue
				}
				destMap[ref.ForeignKey.DBName] = pk
				if _, present := destMap[rel.Name]; present {
					destMap[rel.Name] = assoc.Interface()
				}
			}
		}

		current := db.Statement.ReflectValue
		switch current.Kind() {
		case reflect.Slice, reflect.Array:
			assocType := rel.Field.FieldType
			fieldIsPtr := assocType.Kind() == reflect.Ptr
			if !fieldIsPtr {
				assocType = reflect.PtrTo(assocType)
			}

			// owners[i] is the record whose association is assocs[i].
			var owners []reflect.Value
			assocs := reflect.MakeSlice(reflect.SliceOf(assocType), 0, 10)
			for idx := 0; idx < current.Len(); idx++ {
				owner := current.Index(idx)
				// Stop scanning at the first element that is not a struct (or pointer to one).
				if reflect.Indirect(owner).Kind() != reflect.Struct {
					break
				}
				// Only non-zero belongs-to values are saved.
				if _, isZero := rel.Field.ValueOf(owner); isZero {
					continue
				}

				assoc := rel.Field.ReflectValueOf(owner)
				if !fieldIsPtr {
					assoc = assoc.Addr()
				}
				owners = append(owners, owner)
				assocs = reflect.Append(assocs, assoc)
			}

			// Foreign keys are filled in only when saveAssociations reports no error.
			if assocs.Len() > 0 && saveAssociations(db, rel, assocs.Interface(), selectColumns, restricted, nil) == nil {
				for idx, owner := range owners {
					linkOwner(owner, assocs.Index(idx))
				}
			}
		case reflect.Struct:
			if _, isZero := rel.Field.ValueOf(current); !isZero {
				assoc := rel.Field.ReflectValueOf(current)
				if assoc.Kind() != reflect.Ptr {
					assoc = assoc.Addr()
				}

				if saveAssociations(db, rel, assoc.Interface(), selectColumns, restricted, nil) == nil {
					linkOwner(current, assoc)
				}
			}
		}
	}
}

// SaveAfterAssociations saves the has-one, has-many and many-to-many
// associations of the current statement's value(s) after the owning record has
// been written.
//
//   - has-one / has-many: the owner's primary key (or the relationship's fixed
//     PrimaryValue) is copied into each associated record's foreign key, then the
//     associated records are inserted, updating the foreign key columns on
//     conflict.
//   - many-to-many: the associated records are inserted (unless "<Rel>.*" is
//     omitted), then one join-table row per (owner, associated record) pair is
//     inserted with ON CONFLICT DO NOTHING.
//
// Relationships excluded via Select/Omit are skipped. Every error, including a
// failed foreign key assignment, is recorded on db via AddError. The callback
// does nothing if db already holds an error or the statement has no schema.
func SaveAfterAssociations(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}

	selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)

	// Save Has One associations
	for _, rel := range db.Statement.Schema.Relationships.HasOne {
		if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
			continue
		}

		switch db.Statement.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			var (
				fieldType = rel.Field.FieldType
				isPtr     = fieldType.Kind() == reflect.Ptr
			)

			if !isPtr {
				fieldType = reflect.PtrTo(fieldType)
			}

			elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)

			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
				obj := db.Statement.ReflectValue.Index(i)

				if reflect.Indirect(obj).Kind() == reflect.Struct {
					if _, zero := rel.Field.ValueOf(obj); !zero {
						rv := rel.Field.ReflectValueOf(obj)
						if rv.Kind() != reflect.Ptr {
							rv = rv.Addr()
						}

						for _, ref := range rel.References {
							if ref.OwnPrimaryKey {
								fv, _ := ref.PrimaryKey.ValueOf(obj)
								db.AddError(ref.ForeignKey.Set(rv, fv))
							} else if ref.PrimaryValue != "" {
								db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
							}
						}

						elems = reflect.Append(elems, rv)
					}
				}
			}

			if elems.Len() > 0 {
				assignmentColumns := make([]string, 0, len(rel.References))
				for _, ref := range rel.References {
					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
				}

				saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
			}
		case reflect.Struct:
			if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
				f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
				if f.Kind() != reflect.Ptr {
					f = f.Addr()
				}

				assignmentColumns := make([]string, 0, len(rel.References))
				for _, ref := range rel.References {
					// Record assignment failures, matching the slice branch above,
					// instead of silently saving with a stale foreign key.
					if ref.OwnPrimaryKey {
						fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
						db.AddError(ref.ForeignKey.Set(f, fv))
					} else if ref.PrimaryValue != "" {
						db.AddError(ref.ForeignKey.Set(f, ref.PrimaryValue))
					}
					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
				}

				saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
			}
		}
	}

	// Save Has Many associations
	for _, rel := range db.Statement.Schema.Relationships.HasMany {
		if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
			continue
		}

		fieldType := rel.Field.IndirectFieldType.Elem()
		isPtr := fieldType.Kind() == reflect.Ptr
		if !isPtr {
			fieldType = reflect.PtrTo(fieldType)
		}
		elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)

		// appendToElems sets the foreign key(s) of every associated record held
		// by owner v and collects pointers to those records into elems.
		appendToElems := func(v reflect.Value) {
			if _, zero := rel.Field.ValueOf(v); !zero {
				f := reflect.Indirect(rel.Field.ReflectValueOf(v))

				for i := 0; i < f.Len(); i++ {
					elem := f.Index(i)
					for _, ref := range rel.References {
						if ref.OwnPrimaryKey {
							pv, _ := ref.PrimaryKey.ValueOf(v)
							db.AddError(ref.ForeignKey.Set(elem, pv))
						} else if ref.PrimaryValue != "" {
							db.AddError(ref.ForeignKey.Set(elem, ref.PrimaryValue))
						}
					}

					if isPtr {
						elems = reflect.Append(elems, elem)
					} else {
						elems = reflect.Append(elems, elem.Addr())
					}
				}
			}
		}

		switch db.Statement.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
				obj := db.Statement.ReflectValue.Index(i)
				if reflect.Indirect(obj).Kind() == reflect.Struct {
					appendToElems(obj)
				}
			}
		case reflect.Struct:
			appendToElems(db.Statement.ReflectValue)
		}

		if elems.Len() > 0 {
			assignmentColumns := make([]string, 0, len(rel.References))
			for _, ref := range rel.References {
				assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
			}

			saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
		}
	}

	// Save Many2Many associations
	for _, rel := range db.Statement.Schema.Relationships.Many2Many {
		if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
			continue
		}

		fieldType := rel.Field.IndirectFieldType.Elem()
		isPtr := fieldType.Kind() == reflect.Ptr
		if !isPtr {
			fieldType = reflect.PtrTo(fieldType)
		}
		elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
		joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
		// objs[i] is the owner of elems[i]; used to pair them into join rows.
		objs := []reflect.Value{}

		// appendToJoins builds one join-table row linking owner obj to the
		// associated record elem and appends it to joins.
		appendToJoins := func(obj reflect.Value, elem reflect.Value) {
			joinValue := reflect.New(rel.JoinTable.ModelType)
			for _, ref := range rel.References {
				if ref.OwnPrimaryKey {
					fv, _ := ref.PrimaryKey.ValueOf(obj)
					db.AddError(ref.ForeignKey.Set(joinValue, fv))
				} else if ref.PrimaryValue != "" {
					db.AddError(ref.ForeignKey.Set(joinValue, ref.PrimaryValue))
				} else {
					fv, _ := ref.PrimaryKey.ValueOf(elem)
					db.AddError(ref.ForeignKey.Set(joinValue, fv))
				}
			}
			joins = reflect.Append(joins, joinValue)
		}

		appendToElems := func(v reflect.Value) {
			if _, zero := rel.Field.ValueOf(v); !zero {
				f := reflect.Indirect(rel.Field.ReflectValueOf(v))

				for i := 0; i < f.Len(); i++ {
					elem := f.Index(i)

					objs = append(objs, v)
					if isPtr {
						elems = reflect.Append(elems, elem)
					} else {
						elems = reflect.Append(elems, elem.Addr())
					}
				}
			}
		}

		switch db.Statement.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
				obj := db.Statement.ReflectValue.Index(i)
				if reflect.Indirect(obj).Kind() == reflect.Struct {
					appendToElems(obj)
				}
			}
		case reflect.Struct:
			appendToElems(db.Statement.ReflectValue)
		}

		if elems.Len() > 0 {
			// Omitting "<Rel>.*" skips saving the associated records themselves;
			// join rows are still created for them.
			if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
				saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
			}

			for i := 0; i < elems.Len(); i++ {
				appendToJoins(objs[i], elems.Index(i))
			}
		}

		if joins.Len() > 0 {
			db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
				SkipHooks:                db.Statement.SkipHooks,
				DisableNestedTransaction: true,
			}).Create(joins.Interface()).Error)
		}
	}
}

// onConflictOption builds the ON CONFLICT clause used when inserting the
// associated records described by schema s.
//
// When FullSaveAssociations is enabled on stmt.DB, the update column list is
// recomputed. It becomes every non-primary-key column of s that is not excluded
// by selectColumns/restricted, and defaultUpdatingColumns is ignored. Otherwise
// defaultUpdatingColumns is used as given.
//
// If the resulting list is non-empty, the clause targets s's primary key columns
// and uses clause.AssignmentColumns for the update list; otherwise it is a plain
// DO NOTHING.
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
	updating := defaultUpdatingColumns
	if stmt.DB.FullSaveAssociations {
		updating = make([]string, 0, len(s.DBNames))
		for _, dbName := range s.DBNames {
			selected, listed := selectColumns[dbName]
			if excluded := (listed && !selected) || (!listed && restricted); excluded {
				continue
			}
			// Primary keys identify the conflicting row; never overwrite them.
			if s.LookUpField(dbName).PrimaryKey {
				continue
			}
			updating = append(updating, dbName)
		}
	}

	if len(updating) == 0 {
		return clause.OnConflict{DoNothing: true}
	}

	var conflictTarget []clause.Column
	for _, pkName := range s.PrimaryFieldDBNames {
		conflictTarget = append(conflictTarget, clause.Column{Name: pkName})
	}

	return clause.OnConflict{
		Columns:   conflictTarget,
		DoUpdates: clause.AssignmentColumns(updating),
	}
}

// saveAssociations inserts values (the associated record(s) of rel) through a
// new session that carries the ON CONFLICT clause built by onConflictOption.
//
// selectColumns entries are forwarded to the nested insert as follows:
//   - Names prefixed with "<rel.Name>." are forwarded with that prefix stripped.
//   - Names starting with clause.Associations are forwarded unchanged.
//   - True entries become Select columns and false entries become Omit columns.
//
// The nested session inherits db's SkipHooks flag and statement settings, and
// nested transactions are disabled for it.
//
// The insert's error is passed to db.AddError and that call's result is
// returned; callers treat a nil result as success.
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
	conflict := onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
	prefix := rel.Name + "."

	var selects, omits []string
	for name, selected := range selectColumns {
		var column string
		switch {
		case strings.HasPrefix(name, prefix):
			column = strings.TrimPrefix(name, prefix)
		case strings.HasPrefix(name, clause.Associations):
			column = name
		}
		if column == "" {
			continue
		}

		if selected {
			selects = append(selects, column)
		} else {
			omits = append(omits, column)
		}
	}

	nested := db.Session(&gorm.Session{NewDB: true}).Clauses(conflict).Session(&gorm.Session{
		SkipHooks:                db.Statement.SkipHooks,
		DisableNestedTransaction: true,
	})

	// Copy every setting stored on db's statement onto the nested statement.
	db.Statement.Settings.Range(func(key, value interface{}) bool {
		nested.Statement.Settings.Store(key, value)
		return true
	})

	if len(selects) > 0 {
		nested = nested.Select(selects)
	}
	if len(omits) > 0 {
		nested = nested.Omit(omits...)
	}

	return db.AddError(nested.Create(values).Error)
}