package upgrades

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"os"
	"strings"

	"maunium.net/go/mautrix/event"
)

// init registers database upgrade 9, which creates the Matrix state store
// tables in the main database and, when a legacy mx-state.json file exists in
// the working directory, imports its contents and renames the file to
// mx-state.json.bak.
func init() {
	// batchSize is the maximum number of rows sent in one INSERT statement.
	// The widest table migrated here has three columns, so a statement never
	// carries more than 300 bind parameters, which stays well clear of
	// per-statement parameter limits (e.g. SQLite's SQLITE_MAX_VARIABLE_NUMBER).
	const batchSize = 100

	// insertRows executes a single multi-row INSERT.
	//
	// query must end in "VALUES "; values holds the bind arguments row after
	// row, with `columns` arguments per row. An empty values slice is a no-op
	// so callers can flush unconditionally without producing a statement that
	// has no value tuples (which would be a SQL syntax error).
	insertRows := func(tx *sql.Tx, query string, columns int, values []interface{}) error {
		if len(values) == 0 {
			return nil
		}
		rows := make([]string, 0, len(values)/columns)
		placeholders := make([]string, columns)
		for first := 1; first <= len(values); first += columns {
			for col := range placeholders {
				placeholders[col] = fmt.Sprintf("$%d", first+col)
			}
			rows = append(rows, "("+strings.Join(placeholders, ", ")+")")
		}
		_, err := tx.Exec(query+strings.Join(rows, ","), values...)
		return err
	}

	// migrateRegistrations inserts every user ID whose map value is true into
	// mx_registrations. Entries mapped to false are skipped.
	migrateRegistrations := func(tx *sql.Tx, registrations map[string]bool) error {
		const query = "INSERT INTO mx_registrations (user_id) VALUES "
		values := make([]interface{}, 0, batchSize)
		for userID, registered := range registrations {
			if !registered {
				continue
			}
			values = append(values, userID)
			if len(values) == batchSize {
				if err := insertRows(tx, query, 1, values); err != nil {
					return err
				}
				// The previous Exec has returned, so the backing array can be reused.
				values = values[:0]
			}
		}
		// Flush whatever is left; insertRows ignores an empty remainder.
		return insertRows(tx, query, 1, values)
	}

	// migrateMemberships inserts one (room_id, user_id, membership) row into
	// mx_user_profile for every member of every room.
	migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]event.Membership) error {
		const (
			query   = "INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES "
			columns = 3
		)
		values := make([]interface{}, 0, batchSize*columns)
		for roomID, members := range rooms {
			for userID, membership := range members {
				values = append(values, roomID, userID, membership)
				if len(values) == batchSize*columns {
					if err := insertRows(tx, query, columns, values); err != nil {
						return err
					}
					values = values[:0]
				}
			}
		}
		return insertRows(tx, query, columns, values)
	}

	// migratePowerLevels inserts one (room_id, power_levels) row into
	// mx_room_state per room, storing the power levels as their JSON encoding.
	migratePowerLevels := func(tx *sql.Tx, rooms map[string]*event.PowerLevelsEventContent) error {
		const (
			query   = "INSERT INTO mx_room_state (room_id, power_levels) VALUES "
			columns = 2
		)
		values := make([]interface{}, 0, batchSize*columns)
		for roomID, powerLevels := range rooms {
			powerLevelBytes, err := json.Marshal(powerLevels)
			if err != nil {
				return err
			}
			values = append(values, roomID, powerLevelBytes)
			if len(values) == batchSize*columns {
				if err = insertRows(tx, query, columns, values); err != nil {
					return err
				}
				values = values[:0]
			}
		}
		return insertRows(tx, query, columns, values)
	}

	userProfileTable := `CREATE TABLE mx_user_profile (
		room_id     VARCHAR(255),
		user_id     VARCHAR(255),
		membership  VARCHAR(15) NOT NULL,
		PRIMARY KEY (room_id, user_id)
	)`

	roomStateTable := `CREATE TABLE mx_room_state (
		room_id      VARCHAR(255) PRIMARY KEY,
		power_levels TEXT
	)`

	registrationsTable := `CREATE TABLE mx_registrations (
		user_id VARCHAR(255) PRIMARY KEY
	)`

	// TempStateStore mirrors the JSON layout of the legacy mx-state.json file.
	type TempStateStore struct {
		// Registrations maps a user ID to whether that user is registered.
		Registrations map[string]bool `json:"registrations"`
		// Members maps room ID -> user ID -> membership state.
		Members map[string]map[string]event.Membership `json:"memberships"`
		// PowerLevels maps room ID -> power level event content.
		PowerLevels map[string]*event.PowerLevelsEventContent `json:"power_levels"`
	}

	upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {
		// Adjust a local copy so the Postgres-specific column type never leaks
		// into the schema string captured from init.
		roomStateSchema := roomStateTable
		if ctx.dialect == Postgres {
			roomStateSchema = strings.Replace(roomStateSchema, "TEXT", "JSONB", 1)
		}

		for _, schema := range []string{userProfileTable, roomStateSchema, registrationsTable} {
			if _, err := tx.Exec(schema); err != nil {
				return err
			}
		}

		data, err := ioutil.ReadFile("mx-state.json")
		if os.IsNotExist(err) {
			// No legacy state file: the freshly created tables simply stay empty.
			ctx.log.Debugln("mx-state.json not found, not migrating state store")
			return nil
		} else if err != nil {
			// The file exists but could not be read. Fail the upgrade rather
			// than silently marking it done without importing the state.
			return fmt.Errorf("reading mx-state.json: %w", err)
		}

		var store TempStateStore
		if err = json.Unmarshal(data, &store); err != nil {
			return fmt.Errorf("parsing mx-state.json: %w", err)
		}
		if err = migrateRegistrations(tx, store.Registrations); err != nil {
			return err
		}
		if err = migrateMemberships(tx, store.Members); err != nil {
			return err
		}
		if err = migratePowerLevels(tx, store.PowerLevels); err != nil {
			return err
		}
		// Keep the old file around as a backup instead of deleting it.
		if err = os.Rename("mx-state.json", "mx-state.json.bak"); err != nil {
			return fmt.Errorf("renaming mx-state.json: %w", err)
		}
		return nil
	}}
}