Add basic relaybot support. Fixes #20
This commit is contained in:
@ -70,23 +70,23 @@ func (store *SQLStateStore) MarkRegistered(userID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMemberships(roomID string) map[string]mautrix.Membership {
|
||||
memberships := make(map[string]mautrix.Membership)
|
||||
rows, err := store.db.Query("SELECT user_id, membership FROM mx_user_profile WHERE room_id=$1", roomID)
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID string) map[string]mautrix.Member {
|
||||
members := make(map[string]mautrix.Member)
|
||||
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
|
||||
if err != nil {
|
||||
return memberships
|
||||
return members
|
||||
}
|
||||
var userID string
|
||||
var membership mautrix.Membership
|
||||
var member mautrix.Member
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&userID, &membership)
|
||||
err := rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan membership in %s: %v", roomID, err)
|
||||
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
memberships[userID] = membership
|
||||
members[userID] = member
|
||||
}
|
||||
}
|
||||
return memberships
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Membership {
|
||||
@ -99,6 +99,24 @@ func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Members
|
||||
return membership
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID, userID string) mautrix.Member {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = mautrix.MembershipLeave
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID, userID string) (mautrix.Member, bool) {
|
||||
row := store.db.QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
|
||||
var member mautrix.Member
|
||||
err := row.Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return member, err == nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID, userID string) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
@ -116,6 +134,7 @@ func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMembershi
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID, userID string, membership mautrix.Membership) {
|
||||
var err error
|
||||
if store.db.dialect == "postgres" {
|
||||
@ -131,6 +150,22 @@ func (store *SQLStateStore) SetMembership(roomID, userID string, membership maut
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Member) {
|
||||
var err error
|
||||
if store.db.dialect == "postgres" {
|
||||
_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=$3`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
} else if store.db.dialect == "sqlite3" {
|
||||
_, err = store.db.Exec("INSERT OR REPLACE INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)",
|
||||
roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
} else {
|
||||
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
||||
}
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerLevels) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
|
@ -47,7 +47,7 @@ func init() {
|
||||
return executeBatch(tx, valueStrings, values...)
|
||||
}
|
||||
|
||||
migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Membership) error {
|
||||
migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Member) error {
|
||||
for roomID, members := range rooms {
|
||||
if len(members) == 0 {
|
||||
continue
|
||||
@ -125,7 +125,7 @@ func init() {
|
||||
return err
|
||||
} else if err = migrateRegistrations(tx, store.Registrations); err != nil {
|
||||
return err
|
||||
} else if err = migrateMemberships(tx, store.Memberships); err != nil {
|
||||
} else if err = migrateMemberships(tx, store.Members); err != nil {
|
||||
return err
|
||||
} else if err = migratePowerLevels(tx, store.PowerLevels); err != nil {
|
||||
return err
|
||||
|
16
database/upgrades/2019-11-10-full-member-state-store.go
Normal file
16
database/upgrades/2019-11-10-full-member-state-store.go
Normal file
@ -0,0 +1,16 @@
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
|
||||
return err
|
||||
}}
|
||||
}
|
@ -28,7 +28,7 @@ type upgrade struct {
|
||||
fn upgradeFunc
|
||||
}
|
||||
|
||||
const NumberOfUpgrades = 10
|
||||
const NumberOfUpgrades = 11
|
||||
|
||||
var upgrades [NumberOfUpgrades]upgrade
|
||||
|
||||
|
@ -201,6 +201,13 @@ func (user *User) SetPortalKeys(newKeys []PortalKeyWithMeta) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (user *User) IsInPortal(jid types.WhatsAppID) bool {
|
||||
row := user.db.QueryRow(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1 AND portal_jid=$2 AND (portal_receiver=$1 OR portal_receiver=$2)`, user.jidPtr(), &jid)
|
||||
var scanJid, scanReceiver types.WhatsAppID
|
||||
_ = row.Scan(&scanJid, &scanReceiver)
|
||||
return scanJid == jid && (scanReceiver == jid || scanReceiver == user.JID)
|
||||
}
|
||||
|
||||
func (user *User) GetPortalKeys() []PortalKey {
|
||||
rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr())
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user