Add user-portal mapping to database
This commit is contained in:
@ -42,7 +42,7 @@ func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
|
||||
receiver = jid
|
||||
}
|
||||
return PortalKey{
|
||||
JID: jid,
|
||||
JID: jid,
|
||||
Receiver: receiver,
|
||||
}
|
||||
}
|
||||
@ -152,3 +152,26 @@ func (portal *Portal) Delete() {
|
||||
portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) GetUserIDs() []types.MatrixUserID {
|
||||
rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
|
||||
WHERE "user".jid=user_portal.user_jid
|
||||
AND user_portal.portal_jid=$1
|
||||
AND user_portal.portal_receiver=$2`,
|
||||
portal.Key.JID, portal.Key.Receiver)
|
||||
if err != nil {
|
||||
portal.log.Debugln("Failed to get portal user ids:", err)
|
||||
return nil
|
||||
}
|
||||
var userIDs []types.MatrixUserID
|
||||
for rows.Next() {
|
||||
var userID types.MatrixUserID
|
||||
err = rows.Scan(&userID)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to scan row:", err)
|
||||
continue
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs
|
||||
}
|
||||
|
19
database/upgrades/2019-05-28-user-portal-table.go
Normal file
19
database/upgrades/2019-05-28-user-portal-table.go
Normal file
@ -0,0 +1,19 @@
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[6] = upgrade{"Add user-portal mapping table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
|
||||
_, err := tx.Exec(`CREATE TABLE user_portal (
|
||||
user_jid VARCHAR(255),
|
||||
portal_jid VARCHAR(255),
|
||||
portal_receiver VARCHAR(255),
|
||||
PRIMARY KEY (user_jid, portal_jid, portal_receiver),
|
||||
FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE,
|
||||
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}}
|
||||
}
|
@ -22,7 +22,7 @@ type upgrade struct {
|
||||
fn upgradeFunc
|
||||
}
|
||||
|
||||
const NumberOfUpgrades = 6
|
||||
const NumberOfUpgrades = 7
|
||||
|
||||
var upgrades [NumberOfUpgrades]upgrade
|
||||
|
||||
|
@ -18,6 +18,7 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -165,3 +166,50 @@ func (user *User) Update() {
|
||||
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) SetPortalKeys(newKeys []PortalKey) error {
|
||||
tx, err := user.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec("DELETE FROM user_portal WHERE user_jid=$1", user.jidPtr())
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
valueStrings := make([]string, len(newKeys))
|
||||
values := make([]interface{}, len(newKeys)*3)
|
||||
for i, key := range newKeys {
|
||||
valueStrings[i] = fmt.Sprintf("($%d, $%d, $%d)", i*3+1, i*3+2, i*3+3)
|
||||
values[i*3] = user.jidPtr()
|
||||
values[i*3+1] = key.JID
|
||||
values[i*3+2] = key.Receiver
|
||||
}
|
||||
query := fmt.Sprintf("INSERT INTO user_portal (user_jid, portal_jid, portal_receiver) VALUES %s",
|
||||
strings.Join(valueStrings, ", "))
|
||||
_, err = tx.Exec(query, values...)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (user *User) GetPortalKeys() []PortalKey {
|
||||
rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr())
|
||||
if err != nil {
|
||||
user.log.Warnln("Failed to get user portal keys:", err)
|
||||
return nil
|
||||
}
|
||||
var keys []PortalKey
|
||||
for rows.Next() {
|
||||
var key PortalKey
|
||||
err = rows.Scan(&key.JID, &key.Receiver)
|
||||
if err != nil {
|
||||
user.log.Warnln("Failed to scan row:", err)
|
||||
continue
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
Reference in New Issue
Block a user