2019-05-16 17:14:32 +00:00
|
|
|
package upgrades
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
log "maunium.net/go/maulogger/v2"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Dialect identifies the SQL database engine an upgrade run targets. Run
// derives it from its dialectName argument and hands it to every upgrade
// step, so that steps can tell which engine they are running against.
type Dialect int

// Supported dialects. Postgres is the zero value of Dialect.
const (
	// Postgres is selected by the dialect name "postgres".
	Postgres Dialect = 0
	// SQLite is selected by the dialect name "sqlite3".
	SQLite Dialect = 1
)
|
|
|
|
|
2019-08-25 14:25:19 +00:00
|
|
|
// upgradeFunc performs one schema migration. tx is the transaction Run opens
// for the step, and ctx carries the dialect, database handle and logger.
// Returning a non-nil error makes Run stop and return that error.
type upgradeFunc func(tx *sql.Tx, ctx context) error
|
|
|
|
|
|
|
|
// context is the state handed to every upgrade step alongside its
// transaction: the targeted SQL dialect, the database handle and the logger
// that were given to Run.
//
// Because this package-level type is named "context", the standard library
// context package can only be imported in this package under an alias.
// NOTE(review): composite literals of this type may be positional, so keep
// the field order stable unless every constructor is checked.
type context struct {
	// dialect is the engine selected from Run's dialectName argument.
	dialect Dialect
	// db is the database being upgraded (each step also gets its own *sql.Tx).
	db *sql.DB
	// log is the logger passed to Run.
	log log.Logger
}
|
2019-05-16 17:14:32 +00:00
|
|
|
|
|
|
|
// upgrade is a single schema migration step.
// NOTE(review): entries may be built with positional literals elsewhere in
// this package, so keep the field order (message, fn) unchanged.
type upgrade struct {
	// message is a human-readable description that Run logs before the step runs.
	message string
	// fn performs the migration inside the transaction Run opens for this step.
	fn upgradeFunc
}
|
|
|
|
|
2020-05-08 23:03:59 +00:00
|
|
|
// NumberOfUpgrades is the number of schema upgrade steps known to this build,
// and therefore the newest schema version Run can bring a database to. It is
// also the length of the upgrades array; Run rejects stored versions above it.
const NumberOfUpgrades = 13
|
2019-05-23 23:33:26 +00:00
|
|
|
|
|
|
|
// upgrades holds every schema upgrade step in order: Run applies
// upgrades[version:] and records version i+1 after upgrades[i] succeeds, so
// upgrades[i] migrates the schema from version i to version i+1.
// NOTE(review): entries are not populated here; presumably other files in this
// package fill them in. An unset entry has a nil fn, which Run would call.
var upgrades [NumberOfUpgrades]upgrade
|
2019-05-16 17:14:32 +00:00
|
|
|
|
2019-05-28 18:29:43 +00:00
|
|
|
// UnsupportedDatabaseVersion is returned by Run when the schema version stored
// in the database is one this build cannot upgrade from, such as a version
// newer than NumberOfUpgrades. Callers can compare against it directly.
var UnsupportedDatabaseVersion error = fmt.Errorf("unsupported database version")
|
|
|
|
|
2019-08-25 14:25:19 +00:00
|
|
|
// GetVersion returns the schema version recorded in the version table,
// creating that table first if it does not exist yet.
//
// An empty version table (a fresh database) reports version 0. Any other
// failure while creating or reading the table is returned together with -1,
// so callers never mistake a read error for an empty database and re-run
// every migration.
func GetVersion(db *sql.DB) (int, error) {
	if _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)"); err != nil {
		return -1, err
	}

	version := 0
	// QueryRow never returns nil; any query error surfaces from Scan.
	err := db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
	switch {
	case err == sql.ErrNoRows:
		// No row yet: no upgrade has been recorded.
		return 0, nil
	case err != nil:
		return -1, err
	}
	return version, nil
}
|
|
|
|
|
2019-08-25 14:25:19 +00:00
|
|
|
// SetVersion replaces whatever is stored in the version table with version.
// It works on the caller's transaction tx, so the new value becomes visible
// only when the caller commits. The INSERT is skipped if clearing the table
// fails.
func SetVersion(tx *sql.Tx, version int) error {
	if _, clearErr := tx.Exec("DELETE FROM version"); clearErr != nil {
		return clearErr
	}
	_, insertErr := tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
	return insertErr
}
|
|
|
|
|
|
|
|
func Run(log log.Logger, dialectName string, db *sql.DB) error {
|
|
|
|
var dialect Dialect
|
|
|
|
switch strings.ToLower(dialectName) {
|
|
|
|
case "postgres":
|
|
|
|
dialect = Postgres
|
|
|
|
case "sqlite3":
|
|
|
|
dialect = SQLite
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("unknown dialect %s", dialectName)
|
|
|
|
}
|
|
|
|
|
2019-08-25 14:25:19 +00:00
|
|
|
version, err := GetVersion(db)
|
2019-05-16 17:14:32 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-05-28 18:29:43 +00:00
|
|
|
if version > NumberOfUpgrades {
|
|
|
|
return UnsupportedDatabaseVersion
|
|
|
|
}
|
|
|
|
|
2019-05-23 23:33:26 +00:00
|
|
|
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
|
2019-05-16 17:14:32 +00:00
|
|
|
for i, upgrade := range upgrades[version:] {
|
2019-05-22 13:46:18 +00:00
|
|
|
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
|
2019-05-16 17:14:32 +00:00
|
|
|
tx, err := db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2019-08-25 14:25:19 +00:00
|
|
|
err = upgrade.fn(tx, context{dialect, db, log})
|
2019-05-16 17:14:32 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2019-08-25 14:25:19 +00:00
|
|
|
err = SetVersion(tx, version+i+1)
|
2019-05-16 17:14:32 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = tx.Commit()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|