Stop handling events from custom puppets

This commit is contained in:
Tulir Asokan
2018-07-20 14:13:13 -04:00
parent 2b92483c50
commit ecdca21e32
7 changed files with 94 additions and 34 deletions
+4 -2
View File
@@ -110,8 +110,10 @@ with appserv.run(config["appservice.hostname"], config["appservice.port"]) as st
context.mx = MatrixHandler(context)
init_formatter(context)
init_portal(context)
init_puppet(context)
startup_actions = init_user(context) + [start, context.mx.init_as_bot()]
startup_actions = (init_puppet(context) +
init_user(context) +
[start,
context.mx.init_as_bot()])
if context.bot:
startup_actions.append(context.bot.start())
+1 -1
View File
@@ -124,7 +124,7 @@ class AbstractUser:
self.log.debug("%s connected: %s", self.mxid, self.connected)
return self
async def ensure_started(self, even_if_no_session=False):
async def ensure_started(self, even_if_no_session=False) -> "AbstractUser":
if not self.puppet_whitelisted:
return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
+4 -8
View File
@@ -56,15 +56,11 @@ async def ping_bot(evt: CommandEvent):
"account")
async def login_matrix(evt: CommandEvent):
puppet = pu.Puppet.get(evt.sender.tgid)
prev_info = puppet.custom_mxid, puppet.access_token
puppet.custom_mxid = evt.sender.mxid
puppet.access_token = " ".join(evt.args)
puppet.refresh_intents()
if not await puppet.get_profile():
puppet.custom_mxid, puppet.access_token = prev_info
puppet.refresh_intents()
resp = puppet.switch_mxid(" ".join(evt.args), evt.sender.mxid)
if resp == 2:
return await evt.reply("You can only log in as your own Matrix user.")
elif resp == 1:
return await evt.reply("Failed to verify access token.")
puppet.save()
return await evt.reply(
f"Replaced your Telegram account's Matrix puppet with {puppet.custom_mxid}.")
+11 -12
View File
@@ -17,14 +17,14 @@
from sqlalchemy import (Column, UniqueConstraint, ForeignKey, ForeignKeyConstraint, Integer,
BigInteger, String, Boolean, Text)
from sqlalchemy.sql import expression
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, Query
import json
from .base import Base
class Portal(Base):
query = None
query = None # type: Query
__tablename__ = "portal"
# Telegram chat information
@@ -42,9 +42,8 @@ class Portal(Base):
about = Column(String, nullable=True)
photo_id = Column(String, nullable=True)
class Message(Base):
query = None
query = None # type: Query
__tablename__ = "message"
mxid = Column(String)
@@ -56,7 +55,7 @@ class Message(Base):
class UserPortal(Base):
query = None
query = None # type: Query
__tablename__ = "user_portal"
user = Column(Integer, ForeignKey("user.tgid", onupdate="CASCADE", ondelete="CASCADE"),
@@ -70,7 +69,7 @@ class UserPortal(Base):
class User(Base):
query = None
query = None # type: Query
__tablename__ = "user"
mxid = Column(String, primary_key=True)
@@ -83,7 +82,7 @@ class User(Base):
class RoomState(Base):
query = None
query = None # type: Query
__tablename__ = "mx_room_state"
room_id = Column(String, primary_key=True)
@@ -107,7 +106,7 @@ class RoomState(Base):
class UserProfile(Base):
query = None
query = None # type: Query
__tablename__ = "mx_user_profile"
room_id = Column(String, primary_key=True)
@@ -125,7 +124,7 @@ class UserProfile(Base):
class Contact(Base):
query = None
query = None # type: Query
__tablename__ = "contact"
user = Column(Integer, ForeignKey("user.tgid"), primary_key=True)
@@ -133,7 +132,7 @@ class Contact(Base):
class Puppet(Base):
query = None
query = None # type: Query
__tablename__ = "puppet"
id = Column(Integer, primary_key=True)
@@ -149,14 +148,14 @@ class Puppet(Base):
# Fucking Telegram not telling bots what chats they are in 3:<
class BotChat(Base):
query = None
query = None # type: Query
__tablename__ = "bot_chat"
id = Column(Integer, primary_key=True)
type = Column(String, nullable=False)
class TelegramFile(Base):
query = None
query = None # type: Query
__tablename__ = "telegram_file"
id = Column(String, primary_key=True)
+6 -1
View File
@@ -824,7 +824,12 @@ class Portal:
mxid=event_id))
self.db.commit()
async def handle_matrix_message(self, sender, message, event_id):
async def handle_matrix_message(self, sender: u.User, message: dict, event_id: str):
puppet = p.Puppet.get_by_custom_mxid(sender.mxid)
if puppet and message.get("net.maunium.telegram.puppet", False):
self.log.debug("Ignoring puppet-sent message by confirmed puppet user %s", sender.mxid)
return
logged_in = not await sender.needs_relaybot(self)
client = sender.client if logged_in else self.bot.client
sender_id = sender.tgid if logged_in else self.bot.tgid
+61 -6
View File
@@ -36,6 +36,7 @@ class Puppet:
username_template = None
hs_domain = None
cache = {}
by_custom_mxid = {}
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
@@ -60,22 +61,51 @@ class Puppet:
self.refresh_intents()
self.cache[id] = self
if self.custom_mxid:
self.by_custom_mxid[self.custom_mxid] = self
def refresh_intents(self):
self.is_real_user = self.custom_mxid and self.access_token
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
if self.is_real_user else self.default_mxid_intent)
async def get_profile(self):
try:
return await self.intent.get_profile(self.custom_mxid)
except MatrixError:
return None
@property
def tgid(self):
return self.id
async def switch_mxid(self, access_token, mxid):
prev_mxid = self.custom_mxid
self.custom_mxid = mxid
self.access_token = access_token
self.refresh_intents()
err = await self.test_custom_mxid()
if err != 0:
return err
try:
del self.by_custom_mxid[prev_mxid]
except KeyError:
pass
self.mxid = self.custom_mxid or self.default_mxid
self.by_custom_mxid[self.mxid] = self
self.save()
return 0
async def test_custom_mxid(self):
if not self.is_real_user:
return 0
mxid = await self.intent.whoami()
if not mxid or mxid != self.custom_mxid:
self.custom_mxid = None
self.access_token = None
self.refresh_intents()
if mxid != self.custom_mxid:
return 2
return 1
return 0
async def is_logged_in(self):
return True
@@ -212,6 +242,30 @@ class Puppet:
tgid = cls.get_id_from_mxid(mxid)
return cls.get(tgid, create) if tgid else None
@classmethod
def get_by_custom_mxid(cls, mxid):
if not mxid:
raise ValueError("Matrix ID can't be empty")
try:
return cls.by_custom_mxid[mxid]
except KeyError:
pass
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
if puppet:
puppet = cls.from_db(puppet)
return puppet
return None
@classmethod
def get_all_with_custom_mxid(cls):
return [cls.by_custom_mxid[puppet.mxid]
if puppet.custom_mxid in cls.by_custom_mxid
else cls.from_db(puppet)
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
@classmethod
def get_id_from_mxid(cls, mxid):
match = cls.mxid_regex.match(mxid)
@@ -261,3 +315,4 @@ def init(context):
Puppet.hs_domain = config["homeserver"]["domain"]
localpart = Puppet.username_template.format(userid="(.+)")
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
+7 -4
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict
from typing import Dict, Awaitable, Optional
import logging
import asyncio
import re
@@ -185,6 +185,9 @@ class User(AbstractUser):
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session=False) -> "Awaitable[User]":
return super().ensure_started(even_if_no_session)
async def update_info(self, info: User = None):
info = info or await self.client.get_me()
changed = False
@@ -309,7 +312,7 @@ class User(AbstractUser):
# region Class instance lookup
@classmethod
def get_by_mxid(cls, mxid, create=True):
def get_by_mxid(cls, mxid, create=True) -> "Optional[User]":
if not mxid:
raise ValueError("Matrix ID can't be empty")
@@ -332,7 +335,7 @@ class User(AbstractUser):
return None
@classmethod
def get_by_tgid(cls, tgid):
def get_by_tgid(cls, tgid) -> "Optional[User]":
try:
return cls.by_tgid[tgid]
except KeyError:
@@ -346,7 +349,7 @@ class User(AbstractUser):
return None
@classmethod
def find_by_username(cls, username):
def find_by_username(cls, username) -> "Optional[User]":
if not username:
return None