package gormx

import (
	"fmt"
	"strings"

	"github.com/glebarez/sqlite"
	"github.com/spf13/cobra"
	"gitlab.xaotos.cn/qtt/acmin/pkg/util/env"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
)

// Base names of the flags registered by AddGormDBFlagsToCmd, before any
// Options.FlagPrefix is prepended. Each flag's default is taken from the
// environment variable formed by upper-casing the name and replacing '-'
// with '_' (e.g. GORM_DSN), itself prefixed with Options.GetEnvPrefix().
const (
	// FlagAutoMigrate is the boolean flag passed to NewGorm as autoMigrate.
	FlagAutoMigrate = "gorm-auto-migrate"
	// FlagDebug is the boolean flag passed to NewGorm as debug.
	FlagDebug = "gorm-debug"
	// FlagDriver selects the database driver name handed to NewGorm.
	FlagDriver = "gorm-driver"
	// FlagDsn is the data source name handed to the selected driver.
	FlagDsn = "gorm-dsn"
	// FlagPrepareStmt is the boolean flag passed to NewGorm as prepareStmt.
	FlagPrepareStmt = "gorm-prepare-stmt"
	// FlagTablePrefix is the table-name prefix used by gorm's naming strategy.
	FlagTablePrefix = "gorm-table-prefix"
)

// Options customizes how AddGormDBFlagsToCmd names its flags.
type Options struct {
	// FlagPrefix is prepended to every flag name (for example "reader-"
	// yields "reader-gorm-dsn"). Upper-cased with '-' turned into '_', it
	// also prefixes the environment variables consulted for flag defaults.
	// The empty string means no prefix.
	FlagPrefix string
}

// getFlagVal returns a lazy accessor for the flag with base name name. The
// flag is resolved against cmd's flag set when the accessor is invoked, so it
// must only be called after the command line has been parsed.
//
// Resolution order:
//   - opt.FlagPrefix+name, if a prefix is set and that flag was explicitly
//     changed on the command line;
//   - otherwise the unprefixed flag name, if it is defined on cmd (for
//     example registered by a separate call that used no prefix);
//   - otherwise opt.FlagPrefix+name, i.e. its registered default.
//
// getVal is a typed pflag getter such as cmd.Flags().GetString. The accessor
// panics if getVal reports an error (e.g. a type mismatch).
func getFlagVal[T any](cmd *cobra.Command, opt Options, name string, getVal func(name string) (T, error)) func() T {
	return func() T {
		key := opt.FlagPrefix + name
		// Fall back to the shared unprefixed flag only when it exists;
		// reading an undefined flag would make getVal fail and panic.
		if opt.FlagPrefix != "" && !cmd.Flags().Changed(key) && cmd.Flags().Lookup(name) != nil {
			key = name
		}
		res, err := getVal(key)
		if err != nil {
			panic(fmt.Errorf("reading flag %q: %w", key, err))
		}
		return res
	}
}

// GetEnvPrefix derives the environment-variable prefix from FlagPrefix: every
// '-' becomes '_' and the result is upper-cased (e.g. "read-replica-" becomes
// "READ_REPLICA_"). An empty FlagPrefix yields "".
func (o Options) GetEnvPrefix() string {
	underscored := strings.ReplaceAll(o.FlagPrefix, "-", "_")
	return strings.ToUpper(underscored)
}

// AddGormDBFlagsToCmd registers the gorm connection flags on cmd and returns a
// constructor that opens a *gorm.DB from their values.
//
// Each flag name is opt.FlagPrefix followed by one of the Flag* constants. Its
// default is read at registration time from the environment variable
// opt.GetEnvPrefix()+"GORM_<NAME>" (for example GORM_DSN). Flag values are
// resolved lazily, so the returned function must be called after cmd's flags
// have been parsed.
func AddGormDBFlagsToCmd(cmd *cobra.Command, opt Options) func() (*gorm.DB, error) {
	flags := cmd.Flags()
	envPrefix := opt.GetEnvPrefix()

	// Every flag is bound to its own variable. pflag stores a flag's value in
	// its bound variable, so flags sharing one variable would overwrite each
	// other's values and defaults. The values themselves are read back
	// through getFlagVal, not through these variables.
	var (
		driver      string
		dsn         string
		tablePrefix string
		debug       bool
		prepareStmt bool
		autoMigrate bool
	)

	flags.StringVar(&driver, opt.FlagPrefix+FlagDriver, env.StringFromEnv(envPrefix+"GORM_DRIVER", ""), "")
	driverFn := getFlagVal(cmd, opt, FlagDriver, flags.GetString)

	flags.StringVar(&dsn, opt.FlagPrefix+FlagDsn, env.StringFromEnv(envPrefix+"GORM_DSN", ""), "")
	dsnFn := getFlagVal(cmd, opt, FlagDsn, flags.GetString)

	flags.StringVar(&tablePrefix, opt.FlagPrefix+FlagTablePrefix, env.StringFromEnv(envPrefix+"GORM_TABLE_PREFIX", ""), "")
	tablePrefixFn := getFlagVal(cmd, opt, FlagTablePrefix, flags.GetString)

	flags.BoolVar(&debug, opt.FlagPrefix+FlagDebug, env.ParseBoolFromEnv(envPrefix+"GORM_DEBUG", false), "")
	debugFn := getFlagVal(cmd, opt, FlagDebug, flags.GetBool)

	flags.BoolVar(&prepareStmt, opt.FlagPrefix+FlagPrepareStmt, env.ParseBoolFromEnv(envPrefix+"GORM_PREPARE_STMT", false), "")
	prepareStmtFn := getFlagVal(cmd, opt, FlagPrepareStmt, flags.GetBool)

	flags.BoolVar(&autoMigrate, opt.FlagPrefix+FlagAutoMigrate, env.ParseBoolFromEnv(envPrefix+"GORM_AUTO_MIGRATE", false), "")
	autoMigrateFn := getFlagVal(cmd, opt, FlagAutoMigrate, flags.GetBool)

	return func() (*gorm.DB, error) {
		return NewGorm(driverFn(), dsnFn(), tablePrefixFn(), prepareStmtFn(), autoMigrateFn(), debugFn())
	}
}

// NewGorm opens a gorm connection.
//
// Parameters:
//   - driver selects the dialector. "mysql" uses gorm.io/driver/mysql, and
//     "sqlite3" (or the alias "sqlite") uses github.com/glebarez/sqlite. Any
//     other value returns an error.
//   - dsn is passed unchanged to the dialector's Open.
//   - tablePrefix is prepended to table names, which are kept singular.
//   - prepareStmt is forwarded to gorm.Config.PrepareStmt.
//   - autoMigrate is currently unused by this function.
//   - debug enables SQL logging. When true, every statement is logged at Info
//     level, matching what gorm's own DB.Debug() enables. When false, all
//     gorm logging is discarded.
func NewGorm(driver, dsn, tablePrefix string, prepareStmt, autoMigrate, debug bool) (*gorm.DB, error) {
	var dialector gorm.Dialector
	switch driver {
	case "mysql":
		dialector = mysql.Open(dsn)
	case "sqlite3", "sqlite":
		dialector = sqlite.Open(dsn)
	default:
		// %q makes an empty or whitespace-only driver visible in the message.
		return nil, fmt.Errorf("unsupported driver type: %q", driver)
	}

	gormLogger := logger.Discard
	if debug {
		// logger.Default only reports warnings/errors/slow queries; Info is
		// required for individual SQL statements to be printed.
		gormLogger = logger.Default.LogMode(logger.Info)
	}

	cfg := &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   tablePrefix,
			SingularTable: true,
		},
		PrepareStmt: prepareStmt,
		Logger:      gormLogger,
	}

	return gorm.Open(dialector, cfg)
}
