Add option to set private chat portal rooms' name/avatar explicitly

This commit is contained in:
Tulir Asokan 2019-06-01 20:03:29 +03:00
parent 9fa0ad923d
commit e124641107
8 changed files with 158 additions and 51 deletions

View File

@ -47,6 +47,9 @@ type BridgeConfig struct {
SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"` SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"`
InviteOwnPuppetForBackfilling bool `yaml:"invite_own_puppet_for_backfilling"`
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
CommandPrefix string `yaml:"command_prefix"` CommandPrefix string `yaml:"command_prefix"`
Permissions PermissionConfig `yaml:"permissions"` Permissions PermissionConfig `yaml:"permissions"`
@ -69,6 +72,9 @@ func (bc *BridgeConfig) setDefaults() {
bc.SyncChatMaxAge = 259200 bc.SyncChatMaxAge = 259200
bc.SyncWithCustomPuppets = true bc.SyncWithCustomPuppets = true
bc.InviteOwnPuppetForBackfilling = true
bc.PrivateChatPortalMeta = false
} }
type umBridgeConfig BridgeConfig type umBridgeConfig BridgeConfig

View File

@ -66,16 +66,8 @@ func (pq *PortalQuery) New() *Portal {
} }
} }
func (pq *PortalQuery) GetAll() (portals []*Portal) { func (pq *PortalQuery) GetAll() []*Portal {
rows, err := pq.db.Query("SELECT * FROM portal") return pq.getAll("SELECT * FROM portal")
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return
} }
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
@ -86,6 +78,22 @@ func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid) return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
} }
func (pq *PortalQuery) GetAllByJID(jid types.WhatsAppID) []*Portal {
return pq.getAll("SELECT * FROM portal WHERE jid=$1", jid)
}
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...) row := pq.db.QueryRow(query, args...)
if row == nil { if row == nil {
@ -104,11 +112,12 @@ type Portal struct {
Name string Name string
Topic string Topic string
Avatar string Avatar string
AvatarURL string
} }
func (portal *Portal) Scan(row Scannable) *Portal { func (portal *Portal) Scan(row Scannable) *Portal {
var mxid sql.NullString var mxid, avatarURL sql.NullString
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar) err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
portal.log.Errorln("Database scan failed:", err) portal.log.Errorln("Database scan failed:", err)
@ -116,6 +125,7 @@ func (portal *Portal) Scan(row Scannable) *Portal {
return nil return nil
} }
portal.MXID = mxid.String portal.MXID = mxid.String
portal.AvatarURL = avatarURL.String
return portal return portal
} }
@ -127,8 +137,8 @@ func (portal *Portal) mxidPtr() *string {
} }
func (portal *Portal) Insert() { func (portal *Portal) Insert() {
_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)", _, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6, $7)",
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar) portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
} }
@ -139,8 +149,8 @@ func (portal *Portal) Update() {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
mxid = &portal.MXID mxid = &portal.MXID
} }
_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6", _, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5 WHERE jid=$6 AND receiver=$7",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver) mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL, portal.Key.JID, portal.Key.Receiver)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
} }

View File

@ -37,7 +37,7 @@ func (pq *PuppetQuery) New() *Puppet {
} }
func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT * FROM puppet") rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet")
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -49,7 +49,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
} }
func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet { func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid) row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE jid=$1", jid)
if row == nil { if row == nil {
return nil return nil
} }
@ -57,7 +57,7 @@ func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
} }
func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet { func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE custom_mxid=$1", mxid) row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid=$1", mxid)
if row == nil { if row == nil {
return nil return nil
} }
@ -65,7 +65,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet {
} }
func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) { func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT * FROM puppet WHERE custom_mxid<>''") rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid<>''")
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -82,6 +82,7 @@ type Puppet struct {
JID types.WhatsAppID JID types.WhatsAppID
Avatar string Avatar string
AvatarURL string
Displayname string Displayname string
NameQuality int8 NameQuality int8
@ -91,9 +92,9 @@ type Puppet struct {
} }
func (puppet *Puppet) Scan(row Scannable) *Puppet { func (puppet *Puppet) Scan(row Scannable) *Puppet {
var displayname, avatar, customMXID, accessToken, nextBatch sql.NullString var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality sql.NullInt64 var quality sql.NullInt64
err := row.Scan(&puppet.JID, &avatar, &displayname, &quality, &customMXID, &accessToken, &nextBatch) err := row.Scan(&puppet.JID, &avatar, &avatarURL, &displayname, &quality, &customMXID, &accessToken, &nextBatch)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
puppet.log.Errorln("Database scan failed:", err) puppet.log.Errorln("Database scan failed:", err)
@ -102,6 +103,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
} }
puppet.Displayname = displayname.String puppet.Displayname = displayname.String
puppet.Avatar = avatar.String puppet.Avatar = avatar.String
puppet.AvatarURL = avatarURL.String
puppet.NameQuality = int8(quality.Int64) puppet.NameQuality = int8(quality.Int64)
puppet.CustomMXID = customMXID.String puppet.CustomMXID = customMXID.String
puppet.AccessToken = accessToken.String puppet.AccessToken = accessToken.String
@ -110,16 +112,16 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
} }
func (puppet *Puppet) Insert() { func (puppet *Puppet) Insert() {
_, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4, $5, $6, $7)", _, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch) puppet.JID, puppet.Avatar, puppet.AvatarURL, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
} }
} }
func (puppet *Puppet) Update() { func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, custom_mxid=$4, access_token=$5, next_batch=$6 WHERE jid=$7", _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7 WHERE jid=$8",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID) puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
} }

View File

@ -0,0 +1,19 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -22,7 +22,7 @@ type upgrade struct {
fn upgradeFunc fn upgradeFunc
} }
const NumberOfUpgrades = 7 const NumberOfUpgrades = 8
var upgrades [NumberOfUpgrades]upgrade var upgrades [NumberOfUpgrades]upgrade

View File

@ -91,6 +91,16 @@ bridge:
# are not normally sent to appservices. # are not normally sent to appservices.
sync_with_custom_puppets: true sync_with_custom_puppets: true
# Whether or not to invite own WhatsApp user's Matrix puppet into private
# chat portals when backfilling if needed.
# This always uses the default puppet instead of custom puppets due to
# rate limits and timestamp massaging.
invite_own_puppet_for_backfilling: true
# Whether or not to explicitly set the avatar and room name for private
# chat portal rooms. This can be useful if the previous field works fine,
# but causes room avatar/name bugs.
private_chat_portal_meta: false
# The prefix for commands. Only required in non-management rooms. # The prefix for commands. Only required in non-management rooms.
command_prefix: "!wa" command_prefix: "!wa"

View File

@ -67,9 +67,16 @@ func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
} }
func (bridge *Bridge) GetAllPortals() []*Portal { func (bridge *Bridge) GetAllPortals() []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll())
}
func (bridge *Bridge) GetAllPortalsByJID(jid types.WhatsAppID) []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid))
}
func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
bridge.portalsLock.Lock() bridge.portalsLock.Lock()
defer bridge.portalsLock.Unlock() defer bridge.portalsLock.Unlock()
dbPortals := bridge.DB.Portal.GetAll()
output := make([]*Portal, len(dbPortals)) output := make([]*Portal, len(dbPortals))
for index, dbPortal := range dbPortals { for index, dbPortal := range dbPortals {
portal, ok := bridge.portalsByJID[dbPortal.Key] portal, ok := bridge.portalsByJID[dbPortal.Key]
@ -131,8 +138,6 @@ type Portal struct {
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
avatarURL string
roomCreateLock sync.Mutex roomCreateLock sync.Mutex
recentlyHandled [recentlyHandledLength]types.WhatsAppMessageID recentlyHandled [recentlyHandledLength]types.WhatsAppMessageID
@ -333,7 +338,7 @@ func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInf
return false return false
} }
portal.avatarURL = resp.ContentURI portal.AvatarURL = resp.ContentURI
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
_, err = portal.MainIntent().SetRoomAvatar(portal.MXID, resp.ContentURI) _, err = portal.MainIntent().SetRoomAvatar(portal.MXID, resp.ContentURI)
if err != nil { if err != nil {
@ -582,7 +587,7 @@ func (portal *Portal) beginBackfill() func() {
portal.backfilling = true portal.backfilling = true
var privateChatPuppetInvited bool var privateChatPuppetInvited bool
var privateChatPuppet *Puppet var privateChatPuppet *Puppet
if portal.IsPrivateChat() { if portal.IsPrivateChat() && portal.bridge.Config.Bridge.InviteOwnPuppetForBackfilling {
privateChatPuppet = portal.bridge.GetPuppetByJID(portal.Key.Receiver) privateChatPuppet = portal.bridge.GetPuppetByJID(portal.Key.Receiver)
portal.privateChatBackfillInvitePuppet = func() { portal.privateChatBackfillInvitePuppet = func() {
if privateChatPuppetInvited { if privateChatPuppetInvited {
@ -686,7 +691,14 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
var metadata *whatsappExt.GroupInfo var metadata *whatsappExt.GroupInfo
isPrivateChat := false isPrivateChat := false
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
puppet := portal.bridge.GetPuppetByJID(portal.Key.JID)
if portal.bridge.Config.Bridge.PrivateChatPortalMeta {
portal.Name = puppet.Displayname
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
} else {
portal.Name = "" portal.Name = ""
}
portal.Topic = "WhatsApp private chat" portal.Topic = "WhatsApp private chat"
isPrivateChat = true isPrivateChat = true
} else if portal.IsStatusBroadcastRoom() { } else if portal.IsStatusBroadcastRoom() {
@ -708,11 +720,11 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
PowerLevels: portal.GetBasePowerLevels(), PowerLevels: portal.GetBasePowerLevels(),
}, },
}} }}
if len(portal.avatarURL) > 0 { if len(portal.AvatarURL) > 0 {
initialState = append(initialState, &mautrix.Event{ initialState = append(initialState, &mautrix.Event{
Type: mautrix.StateRoomAvatar, Type: mautrix.StateRoomAvatar,
Content: mautrix.Content{ Content: mautrix.Content{
URL: portal.avatarURL, URL: portal.AvatarURL,
}, },
}) })
} }

View File

@ -193,7 +193,9 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI
if err != nil { if err != nil {
puppet.log.Warnln("Failed to remove avatar:", err) puppet.log.Warnln("Failed to remove avatar:", err)
} }
puppet.AvatarURL = ""
puppet.Avatar = avatar.Tag puppet.Avatar = avatar.Tag
go puppet.updatePortalAvatar()
return true return true
} }
@ -210,14 +212,68 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI
return false return false
} }
err = puppet.DefaultIntent().SetAvatarURL(resp.ContentURI) puppet.AvatarURL = resp.ContentURI
err = puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
if err != nil { if err != nil {
puppet.log.Warnln("Failed to set avatar:", err) puppet.log.Warnln("Failed to set avatar:", err)
} }
puppet.Avatar = avatar.Tag puppet.Avatar = avatar.Tag
go puppet.updatePortalAvatar()
return true return true
} }
func (puppet *Puppet) UpdateName(source *User, contact whatsapp.Contact) bool {
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
if puppet.Displayname != newName && quality >= puppet.NameQuality {
err := puppet.DefaultIntent().SetDisplayName(newName)
if err == nil {
puppet.Displayname = newName
puppet.NameQuality = quality
go puppet.updatePortalName()
puppet.Update()
} else {
puppet.log.Warnln("Failed to set display name:", err)
}
return true
}
return false
}
func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
if puppet.bridge.Config.Bridge.PrivateChatPortalMeta {
for _, portal := range puppet.bridge.GetAllPortalsByJID(puppet.JID) {
meta(portal)
}
}
}
func (puppet *Puppet) updatePortalAvatar() {
puppet.updatePortalMeta(func(portal *Portal) {
if len(portal.MXID) > 0 {
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
if err != nil {
portal.log.Warnln("Failed to set avatar:", err)
}
}
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.Update()
})
}
func (puppet *Puppet) updatePortalName() {
puppet.updatePortalMeta(func(portal *Portal) {
if len(portal.MXID) > 0 {
_, err := portal.MainIntent().SetRoomName(portal.MXID, puppet.Displayname)
if err != nil {
portal.log.Warnln("Failed to set name:", err)
}
}
portal.Name = puppet.Displayname
portal.Update()
})
}
func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) { func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) {
err := puppet.DefaultIntent().EnsureRegistered() err := puppet.DefaultIntent().EnsureRegistered()
if err != nil { if err != nil {
@ -227,19 +283,11 @@ func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) {
if contact.Jid == source.JID { if contact.Jid == source.JID {
contact.Notify = source.Conn.Info.Pushname contact.Notify = source.Conn.Info.Pushname
} }
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
if puppet.Displayname != newName && quality >= puppet.NameQuality {
err := puppet.DefaultIntent().SetDisplayName(newName)
if err == nil {
puppet.Displayname = newName
puppet.NameQuality = quality
puppet.Update()
} else {
puppet.log.Warnln("Failed to set display name:", err)
}
}
if puppet.UpdateAvatar(source, nil) { update := false
update = puppet.UpdateName(source, contact) || update
update = puppet.UpdateAvatar(source, nil) || update
if update {
puppet.Update() puppet.Update()
} }
} }