Stop handling events from custom puppets
This commit is contained in:
@@ -36,6 +36,7 @@ class Puppet:
|
||||
username_template = None
|
||||
hs_domain = None
|
||||
cache = {}
|
||||
by_custom_mxid = {}
|
||||
|
||||
def __init__(self, id=None, access_token=None, custom_mxid=None, username=None,
|
||||
displayname=None, displayname_source=None, photo_id=None, is_bot=None,
|
||||
@@ -60,22 +61,51 @@ class Puppet:
|
||||
self.refresh_intents()
|
||||
|
||||
self.cache[id] = self
|
||||
if self.custom_mxid:
|
||||
self.by_custom_mxid[self.custom_mxid] = self
|
||||
|
||||
def refresh_intents(self):
|
||||
self.is_real_user = self.custom_mxid and self.access_token
|
||||
self.intent = (self.az.intent.user(self.custom_mxid, self.access_token)
|
||||
if self.is_real_user else self.default_mxid_intent)
|
||||
|
||||
async def get_profile(self):
|
||||
try:
|
||||
return await self.intent.get_profile(self.custom_mxid)
|
||||
except MatrixError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def tgid(self):
|
||||
return self.id
|
||||
|
||||
async def switch_mxid(self, access_token, mxid):
|
||||
prev_mxid = self.custom_mxid
|
||||
self.custom_mxid = mxid
|
||||
self.access_token = access_token
|
||||
self.refresh_intents()
|
||||
|
||||
err = await self.test_custom_mxid()
|
||||
if err != 0:
|
||||
return err
|
||||
|
||||
try:
|
||||
del self.by_custom_mxid[prev_mxid]
|
||||
except KeyError:
|
||||
pass
|
||||
self.mxid = self.custom_mxid or self.default_mxid
|
||||
self.by_custom_mxid[self.mxid] = self
|
||||
self.save()
|
||||
return 0
|
||||
|
||||
async def test_custom_mxid(self):
|
||||
if not self.is_real_user:
|
||||
return 0
|
||||
|
||||
mxid = await self.intent.whoami()
|
||||
if not mxid or mxid != self.custom_mxid:
|
||||
self.custom_mxid = None
|
||||
self.access_token = None
|
||||
self.refresh_intents()
|
||||
if mxid != self.custom_mxid:
|
||||
return 2
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def is_logged_in(self):
|
||||
return True
|
||||
|
||||
@@ -212,6 +242,30 @@ class Puppet:
|
||||
tgid = cls.get_id_from_mxid(mxid)
|
||||
return cls.get(tgid, create) if tgid else None
|
||||
|
||||
@classmethod
|
||||
def get_by_custom_mxid(cls, mxid):
|
||||
if not mxid:
|
||||
raise ValueError("Matrix ID can't be empty")
|
||||
|
||||
try:
|
||||
return cls.by_custom_mxid[mxid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
puppet = DBPuppet.query.filter(DBPuppet.custom_mxid == mxid).one_or_none()
|
||||
if puppet:
|
||||
puppet = cls.from_db(puppet)
|
||||
return puppet
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_all_with_custom_mxid(cls):
|
||||
return [cls.by_custom_mxid[puppet.mxid]
|
||||
if puppet.custom_mxid in cls.by_custom_mxid
|
||||
else cls.from_db(puppet)
|
||||
for puppet in DBPuppet.query.filter(DBPuppet.custom_mxid is not None).all()]
|
||||
|
||||
@classmethod
|
||||
def get_id_from_mxid(cls, mxid):
|
||||
match = cls.mxid_regex.match(mxid)
|
||||
@@ -261,3 +315,4 @@ def init(context):
|
||||
Puppet.hs_domain = config["homeserver"]["domain"]
|
||||
localpart = Puppet.username_template.format(userid="(.+)")
|
||||
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
|
||||
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]
|
||||
|
||||
Reference in New Issue
Block a user