Implement syncing with custom puppets

This commit is contained in:
Tulir Asokan
2018-07-21 10:45:29 -04:00
parent ecdca21e32
commit 54287c344f
2 changed files with 103 additions and 14 deletions
+15 -6
View File
@@ -289,17 +289,26 @@ class MatrixHandler:
await portal.name_change_matrix(user, displayname, prev_displayname, event_id) await portal.name_change_matrix(user, displayname, prev_displayname, event_id)
def filter_matrix_event(self, event): def filter_matrix_event(self, event):
return (event["sender"] == self.az.bot_mxid sender = event.get("sender", None)
or Puppet.get_id_from_mxid(event["sender"]) is not None) if not sender:
return False
return (sender == self.az.bot_mxid
or Puppet.get_id_from_mxid(sender) is not None)
async def try_handle_event(self, evt):
try:
await self.handle_event(evt)
except Exception:
self.log.exception("Error handling manually received Matrix event")
async def handle_event(self, evt): async def handle_event(self, evt):
if self.filter_matrix_event(evt): if self.filter_matrix_event(evt):
return return
self.log.debug("Received event: %s", evt) self.log.debug("Received event: %s", evt)
type = evt["type"] type = evt.get("type", "m.unknown")
room_id = evt["room_id"] room_id = evt.get("room_id", None)
event_id = evt["event_id"] event_id = evt.get("event_id", None)
sender = evt["sender"] sender = evt.get("sender", None)
content = evt.get("content", {}) content = evt.get("content", {})
if type == "m.room.member": if type == "m.room.member":
state_key = evt["state_key"] state_key = evt["state_key"]
+88 -8
View File
@@ -15,15 +15,16 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from difflib import SequenceMatcher from difflib import SequenceMatcher
from typing import Optional from typing import Optional, Awaitable
import re import re
import logging import logging
import asyncio
from telethon.tl.types import UserProfilePhoto from telethon.tl.types import UserProfilePhoto
from mautrix_appservice import MatrixError, IntentAPI from mautrix_appservice import AppService, IntentAPI, MatrixRequestError
from .db import Puppet as DBPuppet from .db import Puppet as DBPuppet
from . import util from . import util, matrix
config = None config = None
@@ -31,7 +32,9 @@ config = None
class Puppet: class Puppet:
log = logging.getLogger("mau.puppet") log = logging.getLogger("mau.puppet")
db = None db = None
az = None az = None # type: AppService
mx = None # type: matrix.MatrixHandler
loop = None # type: asyncio.AbstractEventLoop
mxid_regex = None mxid_regex = None
username_template = None username_template = None
hs_domain = None hs_domain = None
@@ -79,7 +82,7 @@ class Puppet:
self.access_token = access_token self.access_token = access_token
self.refresh_intents() self.refresh_intents()
err = await self.test_custom_mxid() err = await self.init_custom_mxid()
if err != 0: if err != 0:
return err return err
@@ -92,7 +95,7 @@ class Puppet:
self.save() self.save()
return 0 return 0
async def test_custom_mxid(self): async def init_custom_mxid(self):
if not self.is_real_user: if not self.is_real_user:
return 0 return 0
@@ -104,8 +107,84 @@ class Puppet:
if mxid != self.custom_mxid: if mxid != self.custom_mxid:
return 2 return 2
return 1 return 1
asyncio.ensure_future(self.sync(), loop=self.loop)
return 0 return 0
def create_sync_filter(self) -> Awaitable[str]:
return self.intent.client.create_filter(self.custom_mxid, {
"room": {
"include_leave": False,
"state": {
"types": []
},
"timeline": {
"types": [],
},
"ephemeral": {
"types": ["m.typing", "m.receipt"]
},
"account_data": {
"types": []
}
},
"account_data": {
"types": [],
},
"presence": {
"types": ["m.presence"]
},
})
def handle_sync(self, presence, ephemeral):
presence = [self.mx.try_handle_event(event) for event in presence]
for room_id, events in ephemeral.items():
for event in events:
event["room_id"] = room_id
ephemeral = [self.mx.try_handle_event(event)
for events in ephemeral.values()
for event in events]
events = ephemeral + presence
coro = asyncio.gather(*events, loop=self.loop)
asyncio.ensure_future(coro, loop=self.loop)
async def sync(self):
try:
await self._sync()
except Exception:
self.log.exception("Fatal error syncing")
async def _sync(self):
if not self.is_real_user:
self.log.warning("Called sync() for non-custom puppet.")
return
custom_mxid = self.custom_mxid
access_token_at_start = self.access_token
errors = 0
next_batch = None
filter_id = await self.create_sync_filter()
self.log.debug(f"Starting syncer for {custom_mxid} with sync filter {filter_id}.")
while access_token_at_start == self.access_token:
try:
sync_resp = await self.intent.client.sync(filter=filter_id, since=next_batch)
errors = 0
if next_batch is not None:
presence = sync_resp.get("presence", {}).get("events", [])
ephemeral = {room: data.get("ephemeral", {}).get("events", [])
for room, data
in sync_resp.get("rooms", {}).get("join", {}).items()}
self.handle_sync(presence, ephemeral)
next_batch = sync_resp.get("next_batch", None)
except MatrixRequestError as e:
wait = min(errors, 11) ** 2
self.log.warning(f"Syncer for {custom_mxid} errored: {e}. "
f"Waiting for {wait} seconds...")
errors += 1
await asyncio.sleep(wait)
self.log.debug(f"Syncer for custom puppet {custom_mxid} stopped.")
async def is_logged_in(self): async def is_logged_in(self):
return True return True
@@ -310,9 +389,10 @@ class Puppet:
def init(context): def init(context):
global config global config
Puppet.az, Puppet.db, config, _, _ = context Puppet.az, Puppet.db, config, Puppet.loop, _ = context
Puppet.mx = context.mx
Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}") Puppet.username_template = config.get("bridge.username_template", "telegram_{userid}")
Puppet.hs_domain = config["homeserver"]["domain"] Puppet.hs_domain = config["homeserver"]["domain"]
localpart = Puppet.username_template.format(userid="(.+)") localpart = Puppet.username_template.format(userid="(.+)")
Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}") Puppet.mxid_regex = re.compile(f"@{localpart}:{Puppet.hs_domain}")
return [puppet.test_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()] return [puppet.init_custom_mxid() for puppet in Puppet.get_all_with_custom_mxid()]