From 8889105d5a60546571d32746dd2d25e751d3f09f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2019 00:15:58 +0300 Subject: [PATCH] Add locking to client connect calls --- mautrix_telegram/abstract_user.py | 2 +- mautrix_telegram/user.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index d572b7c3..c8e995df 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -211,7 +211,7 @@ class AbstractUser(ABC): return self async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser': - if not self.puppet_whitelisted or self.connected: + if self.connected: return self if even_if_no_session or self.session_container.has_session(self.mxid): self.log.debug("Starting client due to ensure_started" diff --git a/mautrix_telegram/user.py b/mautrix_telegram/user.py index 324dc3e1..91e2381c 100644 --- a/mautrix_telegram/user.py +++ b/mautrix_telegram/user.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any, +from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any, cast, TYPE_CHECKING) import logging import asyncio @@ -54,6 +54,7 @@ class User(AbstractUser): command_status: Optional[Dict[str, Any]] _db_instance: Optional[DBUser] + _ensure_started_lock: asyncio.Lock def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None, username: Optional[str] = None, phone: Optional[str] = None, @@ -73,6 +74,7 @@ class User(AbstractUser): self.portals = {} self.db_portals = db_portals or [] self._db_instance = db_instance + self._ensure_started_lock = asyncio.Lock() self.command_status = None @@ -172,8 +174,11 @@ class User(AbstractUser): # endregion # region Telegram connection management - def ensure_started(self, even_if_no_session=False) -> Awaitable['User']: - return super().ensure_started(even_if_no_session) + async def ensure_started(self, even_if_no_session=False) -> 'User': + if not self.puppet_whitelisted or self.connected: + return self + async with self._ensure_started_lock: + return cast(User, await super().ensure_started(even_if_no_session)) async def start(self, delete_unless_authenticated: bool = False) -> 'User': await super().start()