Add locking to client connect calls

This commit is contained in:
Tulir Asokan
2019-08-08 00:15:58 +03:00
parent 9cbe6b73fc
commit 8889105d5a
2 changed files with 9 additions and 4 deletions
+1 -1
View File
@@ -211,7 +211,7 @@ class AbstractUser(ABC):
return self
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted or self.connected:
if self.connected:
return self
if even_if_no_session or self.session_container.has_session(self.mxid):
self.log.debug("Starting client due to ensure_started"
+8 -3
View File
@@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any,
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any, cast,
TYPE_CHECKING)
import logging
import asyncio
@@ -54,6 +54,7 @@ class User(AbstractUser):
command_status: Optional[Dict[str, Any]]
_db_instance: Optional[DBUser]
_ensure_started_lock: asyncio.Lock
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
username: Optional[str] = None, phone: Optional[str] = None,
@@ -73,6 +74,7 @@ class User(AbstractUser):
self.portals = {}
self.db_portals = db_portals or []
self._db_instance = db_instance
self._ensure_started_lock = asyncio.Lock()
self.command_status = None
@@ -172,8 +174,11 @@ class User(AbstractUser):
# endregion
# region Telegram connection management
def ensure_started(self, even_if_no_session=False) -> Awaitable['User']:
return super().ensure_started(even_if_no_session)
async def ensure_started(self, even_if_no_session=False) -> 'User':
if not self.puppet_whitelisted or self.connected:
return self
async with self._ensure_started_lock:
return cast(User, await super().ensure_started(even_if_no_session))
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
await super().start()