Add locking to client connect calls
This commit is contained in:
@@ -211,7 +211,7 @@ class AbstractUser(ABC):
|
||||
return self
|
||||
|
||||
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
|
||||
if not self.puppet_whitelisted or self.connected:
|
||||
if self.connected:
|
||||
return self
|
||||
if even_if_no_session or self.session_container.has_session(self.mxid):
|
||||
self.log.debug("Starting client due to ensure_started"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any,
|
||||
from typing import (Awaitable, Dict, List, Iterable, Match, NewType, Optional, Tuple, Any, cast,
|
||||
TYPE_CHECKING)
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -54,6 +54,7 @@ class User(AbstractUser):
|
||||
command_status: Optional[Dict[str, Any]]
|
||||
|
||||
_db_instance: Optional[DBUser]
|
||||
_ensure_started_lock: asyncio.Lock
|
||||
|
||||
def __init__(self, mxid: UserID, tgid: Optional[TelegramID] = None,
|
||||
username: Optional[str] = None, phone: Optional[str] = None,
|
||||
@@ -73,6 +74,7 @@ class User(AbstractUser):
|
||||
self.portals = {}
|
||||
self.db_portals = db_portals or []
|
||||
self._db_instance = db_instance
|
||||
self._ensure_started_lock = asyncio.Lock()
|
||||
|
||||
self.command_status = None
|
||||
|
||||
@@ -172,8 +174,11 @@ class User(AbstractUser):
|
||||
# endregion
|
||||
# region Telegram connection management
|
||||
|
||||
def ensure_started(self, even_if_no_session=False) -> Awaitable['User']:
|
||||
return super().ensure_started(even_if_no_session)
|
||||
async def ensure_started(self, even_if_no_session=False) -> 'User':
|
||||
if not self.puppet_whitelisted or self.connected:
|
||||
return self
|
||||
async with self._ensure_started_lock:
|
||||
return cast(User, await super().ensure_started(even_if_no_session))
|
||||
|
||||
async def start(self, delete_unless_authenticated: bool = False) -> 'User':
|
||||
await super().start()
|
||||
|
||||
Reference in New Issue
Block a user