Fix private chats when multiple users are using the bridge

This commit is contained in:
Tulir Asokan
2018-01-28 21:21:44 +02:00
parent f7ac86ee3b
commit 28593ea50c
6 changed files with 92 additions and 61 deletions
+4 -2
View File
@@ -218,7 +218,7 @@ class CommandHandler:
@command_handler @command_handler
def pm(self, sender, args): def pm(self, sender, args):
if len(args) == 0: if len(args) == 0:
return self.reply("**Usage:** `$cmdprefix+sp pm <user identifier>") return self.reply("**Usage:** `$cmdprefix+sp pm <user identifier>`")
elif not sender.tgid: elif not sender.tgid:
return self.reply("This command requires you to be logged in.") return self.reply("This command requires you to be logged in.")
@@ -227,7 +227,9 @@ class CommandHandler:
return self.reply("User not found.") return self.reply("User not found.")
elif not isinstance(user, User): elif not isinstance(user, User):
return self.reply("That doesn't seem to be a user.") return self.reply("That doesn't seem to be a user.")
print(user) portal = po.Portal.get_by_entity(user, sender.tgid)
portal.create_matrix_room(sender, user, [sender.mxid])
self.reply(f"Created private chat room with {pu.Puppet.get_displayname(user, False)}")
def _strip_prefix(self, value, prefixes): def _strip_prefix(self, value, prefixes):
for prefix in prefixes: for prefix in prefixes:
+5
View File
@@ -18,10 +18,12 @@ from .base import Base
class Portal(Base): class Portal(Base):
query = None
__tablename__ = "portal" __tablename__ = "portal"
# Telegram chat information # Telegram chat information
tgid = Column(Integer, primary_key=True) tgid = Column(Integer, primary_key=True)
tg_receiver = Column(Integer, primary_key=True)
peer_type = Column(String) peer_type = Column(String)
# Matrix portal information # Matrix portal information
@@ -34,6 +36,7 @@ class Portal(Base):
class Message(Base): class Message(Base):
query = None
__tablename__ = "message" __tablename__ = "message"
mxid = Column(String) mxid = Column(String)
@@ -45,6 +48,7 @@ class Message(Base):
class User(Base): class User(Base):
query = None
__tablename__ = "user" __tablename__ = "user"
mxid = Column(String, primary_key=True) mxid = Column(String, primary_key=True)
@@ -53,6 +57,7 @@ class User(Base):
class Puppet(Base): class Puppet(Base):
query = None
__tablename__ = "puppet" __tablename__ = "puppet"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
+2 -2
View File
@@ -69,7 +69,7 @@ class MatrixHandler:
return return
puppet.intent.join_room(room) puppet.intent.join_room(room)
existing_portal = Portal.get_by_tgid(puppet.tgid, "user") existing_portal = Portal.get_by_tgid(puppet.tgid, inviter.tgid, "user")
if existing_portal: if existing_portal:
try: try:
puppet.intent.invite(existing_portal.mxid, inviter.mxid) puppet.intent.invite(existing_portal.mxid, inviter.mxid)
@@ -83,7 +83,7 @@ class MatrixHandler:
except MatrixRequestError: except MatrixRequestError:
existing_portal.delete() existing_portal.delete()
portal = Portal(tgid=puppet.tgid, peer_type="user", mxid=room) portal = Portal(tgid=puppet.tgid, tg_receiver=inviter.tgid, peer_type="user", mxid=room)
portal.save() portal.save()
puppet.intent.send_notice(room, "Portal to private chat created.") puppet.intent.send_notice(room, "Portal to private chat created.")
else: else:
+55 -36
View File
@@ -36,9 +36,11 @@ class Portal:
by_mxid = {} by_mxid = {}
by_tgid = {} by_tgid = {}
def __init__(self, tgid, peer_type, mxid=None, username=None, title=None, photo_id=None): def __init__(self, tgid, peer_type, tg_receiver=None, mxid=None, username=None, title=None,
photo_id=None):
self.mxid = mxid self.mxid = mxid
self.tgid = tgid self.tgid = tgid
self.tg_receiver = tg_receiver or tgid
self.peer_type = peer_type self.peer_type = peer_type
self.username = username self.username = username
self.title = title self.title = title
@@ -46,10 +48,20 @@ class Portal:
self._main_intent = None self._main_intent = None
if tgid: if tgid:
self.by_tgid[tgid] = self self.by_tgid[self.tgid_full] = self
if mxid: if mxid:
self.by_mxid[mxid] = self self.by_mxid[mxid] = self
@property
def tgid_full(self):
return self.tgid, self.tg_receiver
@property
def tgid_log(self):
if self.tgid == self.tg_receiver:
return self.tgid
return f"{self.tg_receiver}<->{self.tgid}"
@property @property
def peer(self): def peer(self):
if self.peer_type == "user": if self.peer_type == "user":
@@ -76,35 +88,44 @@ class Portal:
for user in users: for user in users:
self.main_intent.invite(self.mxid, user) self.main_intent.invite(self.mxid, user)
def update_after_create(self, user, entity, direct, puppet=None):
if not direct:
self.update_info(user, entity)
users, participants = self.get_users(user, entity)
self.sync_telegram_users(user, users)
self.update_telegram_participants(participants)
else:
if not puppet:
puppet = p.Puppet.get(self.tgid)
puppet.update_info(user, entity)
puppet.intent.join_room(self.mxid)
def create_matrix_room(self, user, entity=None, invites=[], update_if_exists=True): def create_matrix_room(self, user, entity=None, invites=[], update_if_exists=True):
if not entity: if not entity:
entity = user.client.get_entity(self.peer) entity = user.client.get_entity(self.peer)
self.log.debug("Fetched data: %s", entity) self.log.debug("Fetched data: %s", entity)
direct = self.peer_type == "user"
if self.mxid: if self.mxid:
if update_if_exists: if update_if_exists:
self.update_info(user, entity) self.update_after_create(user, entity, direct)
users, participants = self.get_users(user, entity)
self.sync_telegram_users(user, users)
self.update_telegram_participants(participants)
self.invite_matrix(invites) self.invite_matrix(invites)
return self.mxid return self.mxid
self.log.debug("Creating room for %d", self.tgid) self.log.debug(f"Creating room for {self.tgid_log}")
try: try:
title = entity.title title = entity.title
except AttributeError: except AttributeError:
title = None title = None
direct = self.peer_type == "user"
puppet = p.Puppet.get(self.tgid) if direct else None puppet = p.Puppet.get(self.tgid) if direct else None
intent = puppet.intent if direct else self.az.intent intent = puppet.intent if direct else self.az.intent
# TODO set room alias if public channel. # TODO set room alias if public channel.
room = intent.create_room(invitees=invites, name=title, is_direct=direct) room = intent.create_room(invitees=invites, name=title, is_direct=direct)
if not room: if not room:
raise Exception(f"Failed to create room for {self.tgid}") raise Exception(f"Failed to create room for {self.tgid_log}")
self.mxid = room["room_id"] self.mxid = room["room_id"]
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
@@ -119,15 +140,7 @@ class Portal:
levels["events"]["m.room.topic"] = 50 if self.peer_type == "channel" else 100 levels["events"]["m.room.topic"] = 50 if self.peer_type == "channel" else 100
levels["events"]["m.room.power_levels"] = 95 levels["events"]["m.room.power_levels"] = 95
self.main_intent.set_power_levels(self.mxid, levels) self.main_intent.set_power_levels(self.mxid, levels)
self.update_after_create(user, entity, direct, puppet)
if not direct:
self.update_info(user, entity)
users, participants = self.get_users(user, entity)
self.sync_telegram_users(user, users)
self.update_telegram_participants(participants)
else:
puppet.update_info(user, entity)
puppet.intent.join_room(self.mxid)
def sync_telegram_users(self, source, users=[]): def sync_telegram_users(self, source, users=[]):
for entity in users: for entity in users:
@@ -158,10 +171,10 @@ class Portal:
def update_info(self, user, entity=None): def update_info(self, user, entity=None):
if self.peer_type == "user": if self.peer_type == "user":
self.log.warn("Called update_info() for direct chat portal %d", self.tgid) self.log.warn(f"Called update_info() for direct chat portal {self.tgid_log}")
return return
self.log.debug("Updating info of %d", self.tgid) self.log.debug(f"Updating info of {self.tgid_log}")
if not entity: if not entity:
entity = user.client.get_entity(self.peer) entity = user.client.get_entity(self.peer)
self.log.debug("Fetched data: %s", entity) self.log.debug("Fetched data: %s", entity)
@@ -213,7 +226,7 @@ class Portal:
)) ))
return participants.users, participants.participants return participants.users, participants.participants
except ChatAdminRequiredError: except ChatAdminRequiredError:
return [] return [], []
elif self.peer_type == "user": elif self.peer_type == "user":
return [entity], [] return [entity], []
@@ -320,6 +333,7 @@ class Portal:
entity = updates.chats[0] entity = updates.chats[0]
self.tgid = entity.id self.tgid = entity.id
self.tg_receiver = self.tgid
self.update_info(source, entity) self.update_info(source, entity)
self.save() self.save()
@@ -425,7 +439,7 @@ class Portal:
}) })
def handle_telegram_text(self, source, sender, evt): def handle_telegram_text(self, source, sender, evt):
self.log.debug("Sending %s to %s by %d", evt.message, self.mxid, sender.id) self.log.debug(f"Sending {evt.message} to {self.mxid} by {sender.id}")
text, html = formatter.telegram_event_to_matrix(evt, source) text, html = formatter.telegram_event_to_matrix(evt, source)
sender.intent.set_typing(self.mxid, is_typing=False) sender.intent.set_typing(self.mxid, is_typing=False)
return sender.intent.send_text(self.mxid, text, html=html) return sender.intent.send_text(self.mxid, text, html=html)
@@ -525,17 +539,18 @@ class Portal:
# region Database conversion # region Database conversion
def to_db(self): def to_db(self):
return self.db.merge(DBPortal(tgid=self.tgid, peer_type=self.peer_type, mxid=self.mxid, return self.db.merge(
username=self.username, title=self.title, DBPortal(tgid=self.tgid, tg_receiver=self.tg_receiver, peer_type=self.peer_type,
photo_id=self.photo_id)) mxid=self.mxid, username=self.username, title=self.title,
photo_id=self.photo_id))
def migrate_and_save(self, new_id): def migrate_and_save(self, new_id):
existing = DBPortal.query.get(self.tgid) existing = DBPortal.query.get(self.tgid_full)
if existing: if existing:
self.db.object_session(existing).delete(existing) self.db.object_session(existing).delete(existing)
self.by_tgid[self.tgid] = None self.by_tgid[self.tgid_full] = None
self.tgid = new_id self.tgid = new_id
self.by_tgid[self.tgid] = self self.by_tgid[self.tgid_full] = self
self.save() self.save()
def save(self): def save(self):
@@ -547,8 +562,10 @@ class Portal:
@classmethod @classmethod
def from_db(cls, db_portal): def from_db(cls, db_portal):
return Portal(db_portal.tgid, db_portal.peer_type, db_portal.mxid, db_portal.username, return Portal(tgid=db_portal.tgid, tg_receiver=db_portal.tg_receiver,
db_portal.title, db_portal.photo_id) peer_type=db_portal.peer_type, mxid=db_portal.mxid,
username=db_portal.username, title=db_portal.title,
photo_id=db_portal.photo_id)
# endregion # endregion
# region Class instance lookup # region Class instance lookup
@@ -567,18 +584,20 @@ class Portal:
return None return None
@classmethod @classmethod
def get_by_tgid(cls, tgid, peer_type=None): def get_by_tgid(cls, tgid, tg_receiver=None, peer_type=None):
tg_receiver = tg_receiver or tgid
tgid_full = (tgid, tg_receiver)
try: try:
return cls.by_tgid[tgid] return cls.by_tgid[tgid_full]
except KeyError: except KeyError:
pass pass
portal = DBPortal.query.get(tgid) portal = DBPortal.query.get(tgid_full)
if portal: if portal:
return cls.from_db(portal) return cls.from_db(portal)
if peer_type: if peer_type:
portal = Portal(tgid, peer_type) portal = Portal(tgid, peer_type=peer_type, tg_receiver=tg_receiver)
cls.db.add(portal.to_db()) cls.db.add(portal.to_db())
portal.save() portal.save()
return portal return portal
@@ -586,7 +605,7 @@ class Portal:
return None return None
@classmethod @classmethod
def get_by_entity(cls, entity): def get_by_entity(cls, entity, receiver_id=None):
entity_type = type(entity) entity_type = type(entity)
if entity_type in {Chat, ChatFull}: if entity_type in {Chat, ChatFull}:
type_name = "chat" type_name = "chat"
@@ -608,7 +627,7 @@ class Portal:
id = entity.user_id id = entity.user_id
else: else:
raise ValueError(f"Unknown entity type {entity_type.__name__}") raise ValueError(f"Unknown entity type {entity_type.__name__}")
return cls.get_by_tgid(id, type_name) return cls.get_by_tgid(id, receiver_id if type_name == "user" else id, type_name)
# endregion # endregion
+8 -6
View File
@@ -86,18 +86,20 @@ class Puppet:
self.username = info.username self.username = info.username
changed = True changed = True
displayname = self.get_displayname(info) changed = self.update_displayname(source, info) or changed
if displayname != self.displayname:
self.intent.set_display_name(displayname)
self.displayname = displayname
changed = True
if isinstance(info.photo, UserProfilePhoto): if isinstance(info.photo, UserProfilePhoto):
changed = self.update_avatar(source, info.photo.photo_big) changed = self.update_avatar(source, info.photo.photo_big)
if changed: if changed:
self.save() self.save()
def update_displayname(self, source, info):
displayname = self.get_displayname(info)
if displayname != self.displayname:
self.intent.set_display_name(displayname)
self.displayname = displayname
return True
def update_avatar(self, source, photo): def update_avatar(self, source, photo):
photo_id = f"{photo.volume_id}-{photo.local_id}" photo_id = f"{photo.volume_id}-{photo.local_id}"
if self.photo_id != photo_id: if self.photo_id != photo_id:
+18 -15
View File
@@ -16,6 +16,7 @@
from io import BytesIO from io import BytesIO
from telethon import TelegramClient from telethon import TelegramClient
from telethon.tl.types import * from telethon.tl.types import *
from telethon.tl.types import User as TLUser
from telethon.tl.functions.messages import SendMessageRequest, SendMediaRequest from telethon.tl.functions.messages import SendMessageRequest, SendMediaRequest
from .db import User as DBUser from .db import User as DBUser
from . import portal as po, puppet as pu from . import portal as po, puppet as pu
@@ -41,7 +42,7 @@ class User:
whitelist = config.get("bridge", {}).get("whitelist", [self.mxid]) whitelist = config.get("bridge", {}).get("whitelist", [self.mxid])
self.whitelisted = self.mxid in whitelist self.whitelisted = self.mxid in whitelist
if not self.whitelisted: if not self.whitelisted:
homeserver = self.mxid[self.mxid.index(":")+1:] homeserver = self.mxid[self.mxid.index(":") + 1:]
self.whitelisted = homeserver in whitelist self.whitelisted = homeserver in whitelist
self.by_mxid[mxid] = self self.by_mxid[mxid] = self
@@ -95,7 +96,9 @@ class User:
def update_info(self, info=None): def update_info(self, info=None):
info = info or self.client.get_me() info = info or self.client.get_me()
changed = False changed = False
self.username = info.username if self.username != info.username:
self.username = info.username
changed = True
if self.tgid != info.id: if self.tgid != info.id:
self.tgid = info.id self.tgid = info.id
self.by_tgid[self.tgid] = self self.by_tgid[self.tgid] = self
@@ -177,9 +180,8 @@ class User:
dialogs = self.client.get_dialogs(limit=30) dialogs = self.client.get_dialogs(limit=30)
for dialog in dialogs: for dialog in dialogs:
entity = dialog.entity entity = dialog.entity
if (isinstance(entity, User) if (isinstance(entity, (TLUser, ChatForbidden, ChannelForbidden)) or (
or (isinstance(entity, Chat) and entity.deactivated) isinstance(entity, Chat) and entity.deactivated)):
or isinstance(entity, (ChannelForbidden, ChatForbidden))):
continue continue
portal = po.Portal.get_by_entity(entity) portal = po.Portal.get_by_entity(entity)
portal.create_matrix_room(self, entity, invites=[self.mxid]) portal.create_matrix_room(self, entity, invites=[self.mxid])
@@ -204,13 +206,13 @@ class User:
elif isinstance(update, (UpdateChatAdmins, UpdateChatParticipantAdmin)): elif isinstance(update, (UpdateChatAdmins, UpdateChatParticipantAdmin)):
self.update_admin(update) self.update_admin(update)
elif isinstance(update, UpdateChatParticipants): elif isinstance(update, UpdateChatParticipants):
portal = po.Portal.get_by_tgid(update.participants.chat_id, "chat") portal = po.Portal.get_by_tgid(update.participants.chat_id, peer_type="chat")
portal.update_telegram_participants(update.participants.participants) portal.update_telegram_participants(update.participants.participants)
else: else:
self.log.debug("Unhandled update: %s", update) self.log.debug("Unhandled update: %s", update)
def update_admin(self, update): def update_admin(self, update):
portal = po.Portal.get_by_tgid(update.chat_id, "chat") portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
if isinstance(update, UpdateChatAdmins): if isinstance(update, UpdateChatAdmins):
portal.set_telegram_admins_enabled(update.enabled) portal.set_telegram_admins_enabled(update.enabled)
elif isinstance(update, UpdateChatParticipantAdmin): elif isinstance(update, UpdateChatParticipantAdmin):
@@ -220,9 +222,9 @@ class User:
def update_typing(self, update): def update_typing(self, update):
if isinstance(update, UpdateUserTyping): if isinstance(update, UpdateUserTyping):
portal = po.Portal.get_by_tgid(update.user_id, "user") portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
else: else:
portal = po.Portal.get_by_tgid(update.chat_id, "chat") portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.user_id) sender = pu.Puppet.get(update.user_id)
return portal.handle_telegram_typing(sender, update) return portal.handle_telegram_typing(sender, update)
@@ -236,15 +238,15 @@ class User:
def get_message_details(self, update): def get_message_details(self, update):
if isinstance(update, UpdateShortChatMessage): if isinstance(update, UpdateShortChatMessage):
portal = po.Portal.get_by_tgid(update.chat_id, "chat") portal = po.Portal.get_by_tgid(update.chat_id, peer_type="chat")
sender = pu.Puppet.get(update.from_id) sender = pu.Puppet.get(update.from_id)
elif isinstance(update, UpdateShortMessage): elif isinstance(update, UpdateShortMessage):
portal = po.Portal.get_by_tgid(update.user_id, "user") portal = po.Portal.get_by_tgid(update.user_id, self.tgid, "user")
sender = pu.Puppet.get(self.tgid if update.out else update.user_id) sender = pu.Puppet.get(self.tgid if update.out else update.user_id)
elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)): elif isinstance(update, (UpdateNewMessage, UpdateNewChannelMessage)):
update = update.message update = update.message
sender = pu.Puppet.get(update.from_id) sender = pu.Puppet.get(update.from_id)
portal = po.Portal.get_by_entity(update.to_id) portal = po.Portal.get_by_entity(update.to_id, receiver_id=self.tgid)
return update, sender, portal return update, sender, portal
def update_message(self, update): def update_message(self, update):
@@ -252,13 +254,14 @@ class User:
if isinstance(update, MessageService): if isinstance(update, MessageService):
if isinstance(update.action, MessageActionChannelMigrateFrom): if isinstance(update.action, MessageActionChannelMigrateFrom):
self.log.debug("Ignoring action %s to %d by %d", update.action, portal.tgid, self.log.debug(f"Ignoring action %s to %s by %d", update.action, portal.tgid_log,
sender.id) sender.id)
return return
self.log.debug("Handling action %s to %d by %d", update.action, portal.tgid, sender.id) self.log.debug("Handling action %s to %s by %d", update.action, portal.tgid_log,
sender.id)
portal.handle_telegram_action(self, sender, update.action) portal.handle_telegram_action(self, sender, update.action)
else: else:
self.log.debug("Handling message %s to %d by %d", update, portal.tgid, sender.tgid) self.log.debug("Handling message %s to %s by %d", update, portal.tgid_log, sender.tgid)
portal.handle_telegram_message(self, sender, update) portal.handle_telegram_message(self, sender, update)
# endregion # endregion