Merge pull request #43 from RennerDev/master

Implemented postgres
This commit is contained in:
Tulir Asokan 2019-03-14 00:37:00 +02:00 committed by GitHub
commit 67a041c06d
7 changed files with 90 additions and 56 deletions

View File

@ -19,6 +19,7 @@ package database
import (
"database/sql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2"
@ -34,8 +35,8 @@ type Database struct {
Message *MessageQuery
}
func New(file string) (*Database, error) {
conn, err := sql.Open("sqlite3", file)
func New(dbType string, uri string) (*Database, error) {
conn, err := sql.Open(dbType, uri)
if err != nil {
return nil, err
}
@ -63,20 +64,20 @@ func New(file string) (*Database, error) {
return db, nil
}
func (db *Database) CreateTables() error {
err := db.User.CreateTable()
func (db *Database) CreateTables(dbType string) error {
err := db.User.CreateTable(dbType)
if err != nil {
return err
}
err = db.Portal.CreateTable()
err = db.Portal.CreateTable(dbType)
if err != nil {
return err
}
err = db.Puppet.CreateTable()
err = db.Puppet.CreateTable(dbType)
if err != nil {
return err
}
err = db.Message.CreateTable()
err = db.Message.CreateTable(dbType)
if err != nil {
return err
}

View File

@ -18,6 +18,7 @@ package database
import (
"bytes"
"strings"
"database/sql"
"encoding/json"
@ -33,19 +34,34 @@ type MessageQuery struct {
log log.Logger
}
func (mq *MessageQuery) CreateTable() error {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(25),
chat_receiver VARCHAR(25),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(25) NOT NULL,
content BLOB NOT NULL,
func (mq *MessageQuery) CreateTable(dbType string) error {
if strings.ToLower(dbType) == "postgres" {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content bytea NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
return err
} else {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content BLOB NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
return err
}
}
func (mq *MessageQuery) New() *Message {
@ -56,7 +72,7 @@ func (mq *MessageQuery) New() *Message {
}
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
@ -68,11 +84,11 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
}
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
}
func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
return mq.get("SELECT * FROM message WHERE mxid=?", mxid)
return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
}
func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@ -130,7 +146,7 @@ func (msg *Message) encodeBinaryContent() []byte {
}
func (msg *Message) Insert() {
_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?, ?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
_, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}

View File

@ -59,18 +59,17 @@ type PortalQuery struct {
log log.Logger
}
func (pq *PortalQuery) CreateTable() error {
func (pq *PortalQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(25),
receiver VARCHAR(25),
jid VARCHAR(255),
receiver VARCHAR(255),
mxid VARCHAR(255) UNIQUE,
name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL,
PRIMARY KEY (jid, receiver),
FOREIGN KEY (receiver) REFERENCES user(mxid)
PRIMARY KEY (jid, receiver)
)`)
return err
}
@ -95,11 +94,11 @@ func (pq *PortalQuery) GetAll() (portals []*Portal) {
}
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
@ -143,7 +142,7 @@ func (portal *Portal) mxidPtr() *string {
}
func (portal *Portal) Insert() {
_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)",
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
@ -155,7 +154,7 @@ func (portal *Portal) Update() {
if len(portal.MXID) > 0 {
mxid = &portal.MXID
}
_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)

View File

@ -29,12 +29,12 @@ type PuppetQuery struct {
log log.Logger
}
func (pq *PuppetQuery) CreateTable() error {
func (pq *PuppetQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(25) PRIMARY KEY,
jid VARCHAR(255) PRIMARY KEY,
avatar VARCHAR(255),
displayname VARCHAR(255),
name_quality TINYINT
name_quality SMALLINT
)`)
return err
}
@ -59,7 +59,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
}
func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid)
if row == nil {
return nil
}
@ -93,7 +93,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
}
func (puppet *Puppet) Insert() {
_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
_, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4)",
puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
@ -101,7 +101,7 @@ func (puppet *Puppet) Insert() {
}
func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, name_quality=?, avatar=? WHERE jid=?",
_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3 WHERE jid=$4",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)

View File

@ -33,20 +33,36 @@ type UserQuery struct {
log log.Logger
}
func (uq *UserQuery) CreateTable() error {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(25) UNIQUE,
func (uq *UserQuery) CreateTable(dbType string) error {
if strings.ToLower(dbType) == "postgres" {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255),
management_room VARCHAR(255),
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key BLOB,
mac_key BLOB
)`)
return err
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key bytea,
mac_key bytea
)`)
return err
} else {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255),
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key BLOB,
mac_key BLOB
)`)
return err
}
}
func (uq *UserQuery) New() *User {
@ -57,7 +73,7 @@ func (uq *UserQuery) New() *User {
}
func (uq *UserQuery) GetAll() (users []*User) {
rows, err := uq.db.Query("SELECT * FROM user")
rows, err := uq.db.Query(`SELECT * FROM "user"`)
if err != nil || rows == nil {
return nil
}
@ -69,7 +85,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
}
func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID)
if row == nil {
return nil
}
@ -77,7 +93,7 @@ func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
}
func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", stripSuffix(userID))
row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID))
if row == nil {
return nil
}
@ -150,7 +166,7 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
func (user *User) Insert() {
sess := user.sessionUnptr()
_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(),
_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
if err != nil {
@ -160,7 +176,7 @@ func (user *User) Insert() {
func (user *User) Update() {
sess := user.sessionUnptr()
_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=? WHERE mxid=?",
_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
user.jidPtr(), user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
user.MXID)

View File

@ -20,6 +20,7 @@ appservice:
# The database type. Only "sqlite3" is supported.
type: sqlite3
# The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
# postres example: postgres://synapse:changeme@db/whatsapp?sslmode=disable
uri: mautrix-whatsapp.db
# Path to the Matrix room state store.
state_store_path: ./mx-state.json

View File

@ -133,7 +133,7 @@ func (bridge *Bridge) Init() {
bridge.AS.StateStore = bridge.StateStore
bridge.Log.Debugln("Initializing database")
bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
bridge.DB, err = database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
if err != nil {
bridge.Log.Fatalln("Failed to initialize database:", err)
os.Exit(14)
@ -147,7 +147,7 @@ func (bridge *Bridge) Init() {
}
func (bridge *Bridge) Start() {
err := bridge.DB.CreateTables()
err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
if err != nil {
bridge.Log.Fatalln("Failed to create database tables:", err)
os.Exit(15)
@ -185,6 +185,7 @@ func (bridge *Bridge) UpdateBotProfile() {
}
func (bridge *Bridge) StartUsers() {
bridge.Log.Debugln("Starting users")
for _, user := range bridge.GetAllUsers() {
go user.Connect(false)
}