BIG ASS COMMIT

This commit is contained in:
Karmanyaah Malhotra
2021-02-13 00:53:35 -05:00
parent 95f6487912
commit eafc18099d
47 changed files with 3412 additions and 3240 deletions

View File

@ -1,10 +1,11 @@
package upgrades
import (
"database/sql"
"errors"
"fmt"
"strings"
"gorm.io/gorm"
log "maunium.net/go/maulogger/v2"
)
@ -26,11 +27,11 @@ func (dialect Dialect) String() string {
}
}
type upgradeFunc func(*sql.Tx, context) error
type upgradeFunc func(*gorm.DB, context) error
type context struct {
dialect Dialect
db *sql.DB
db *gorm.DB
log log.Logger
}
@ -39,36 +40,45 @@ type upgrade struct {
fn upgradeFunc
}
const NumberOfUpgrades = 20
type version struct {
gorm.Model
V int
}
const NumberOfUpgrades = 1
var upgrades [NumberOfUpgrades]upgrade
var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database version")
func GetVersion(db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
if err != nil {
return -1, err
}
func GetVersion(db *gorm.DB) (int, error) {
var ver = version{V: 0}
result := db.FirstOrCreate(&ver, &ver)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) ||
errors.Is(result.Error, gorm.ErrInvalidField) {
db.Create(&ver)
print("create version")
version := 0
row := db.QueryRow("SELECT version FROM version LIMIT 1")
if row != nil {
_ = row.Scan(&version)
} else {
return 0, result.Error
}
}
return version, nil
return int(ver.V), nil
}
func SetVersion(tx *sql.Tx, version int) error {
_, err := tx.Exec("DELETE FROM version")
if err != nil {
return err
func SetVersion(tx *gorm.DB, newVersion int) error {
err := tx.Where("v IS NOT NULL").Delete(&version{})
if err.Error != nil {
return err.Error
}
_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
return err
val := version{V: newVersion}
tx = tx.Create(&val)
return tx.Error
}
func Run(log log.Logger, dialectName string, db *sql.DB) error {
func Run(log log.Logger, dialectName string, db *gorm.DB) error {
var dialect Dialect
switch strings.ToLower(dialectName) {
case "postgres":
@ -79,7 +89,9 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
return fmt.Errorf("unknown dialect %s", dialectName)
}
db.AutoMigrate(&version{})
version, err := GetVersion(db)
if err != nil {
return err
}
@ -91,22 +103,21 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
for i, upgrade := range upgrades[version:] {
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
tx, err := db.Begin()
if err != nil {
return err
}
err = upgrade.fn(tx, context{dialect, db, log})
if err != nil {
return err
}
err = SetVersion(tx, version+i+1)
if err != nil {
return err
}
err = tx.Commit()
err = db.Transaction(func(tx *gorm.DB) error {
err = upgrade.fn(tx, context{dialect, db, log})
if err != nil {
return err
}
err = SetVersion(tx, version+i+1)
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
}
return nil
}