Add basic relaybot support. Fixes #20

This commit is contained in:
Tulir Asokan
2019-11-10 21:22:11 +02:00
parent e2d9e2fc57
commit 03d42640fe
14 changed files with 356 additions and 89 deletions

View File

@ -70,23 +70,23 @@ func (store *SQLStateStore) MarkRegistered(userID string) {
}
}
func (store *SQLStateStore) GetRoomMemberships(roomID string) map[string]mautrix.Membership {
memberships := make(map[string]mautrix.Membership)
rows, err := store.db.Query("SELECT user_id, membership FROM mx_user_profile WHERE room_id=$1", roomID)
func (store *SQLStateStore) GetRoomMembers(roomID string) map[string]mautrix.Member {
members := make(map[string]mautrix.Member)
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
if err != nil {
return memberships
return members
}
var userID string
var membership mautrix.Membership
var member mautrix.Member
for rows.Next() {
err := rows.Scan(&userID, &membership)
err := rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to scan membership in %s: %v", roomID, err)
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else {
memberships[userID] = membership
members[userID] = member
}
}
return memberships
return members
}
func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Membership {
@ -99,6 +99,24 @@ func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Members
return membership
}
func (store *SQLStateStore) GetMember(roomID, userID string) mautrix.Member {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = mautrix.MembershipLeave
}
return member
}
func (store *SQLStateStore) TryGetMember(roomID, userID string) (mautrix.Member, bool) {
row := store.db.QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
var member mautrix.Member
err := row.Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
}
return member, err == nil
}
func (store *SQLStateStore) IsInRoom(roomID, userID string) bool {
return store.IsMembership(roomID, userID, "join")
}
@ -116,6 +134,7 @@ func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMembershi
}
return false
}
func (store *SQLStateStore) SetMembership(roomID, userID string, membership mautrix.Membership) {
var err error
if store.db.dialect == "postgres" {
@ -131,6 +150,22 @@ func (store *SQLStateStore) SetMembership(roomID, userID string, membership maut
}
}
func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Member) {
var err error
if store.db.dialect == "postgres" {
_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
} else if store.db.dialect == "sqlite3" {
_, err = store.db.Exec("INSERT OR REPLACE INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)",
roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
} else {
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
}
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
}
func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerLevels) {
levelsBytes, err := json.Marshal(levels)
if err != nil {

View File

@ -47,7 +47,7 @@ func init() {
return executeBatch(tx, valueStrings, values...)
}
migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Membership) error {
migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Member) error {
for roomID, members := range rooms {
if len(members) == 0 {
continue
@ -125,7 +125,7 @@ func init() {
return err
} else if err = migrateRegistrations(tx, store.Registrations); err != nil {
return err
} else if err = migrateMemberships(tx, store.Memberships); err != nil {
} else if err = migrateMemberships(tx, store.Members); err != nil {
return err
} else if err = migratePowerLevels(tx, store.PowerLevels); err != nil {
return err

View File

@ -0,0 +1,16 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
return err
}}
}

View File

@ -28,7 +28,7 @@ type upgrade struct {
fn upgradeFunc
}
const NumberOfUpgrades = 10
const NumberOfUpgrades = 11
var upgrades [NumberOfUpgrades]upgrade

View File

@ -201,6 +201,13 @@ func (user *User) SetPortalKeys(newKeys []PortalKeyWithMeta) error {
return tx.Commit()
}
func (user *User) IsInPortal(jid types.WhatsAppID) bool {
row := user.db.QueryRow(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1 AND portal_jid=$2 AND (portal_receiver=$1 OR portal_receiver=$2)`, user.jidPtr(), &jid)
var scanJid, scanReceiver types.WhatsAppID
_ = row.Scan(&scanJid, &scanReceiver)
return scanJid == jid && (scanReceiver == jid || scanReceiver == user.JID)
}
func (user *User) GetPortalKeys() []PortalKey {
rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr())
if err != nil {