Only call ensure_started for logged in users at startup. Fixes #247

This commit is contained in:
Tulir Asokan
2018-12-20 14:25:06 +02:00
parent 1ae4a63d4e
commit f519ea0193
2 changed files with 9 additions and 15 deletions
+6 -9
View File
@@ -168,16 +168,13 @@ class AbstractUser(ABC):
return self
async def ensure_started(self, even_if_no_session=False) -> 'AbstractUser':
if not self.puppet_whitelisted:
if not self.puppet_whitelisted or self.connected:
return self
self.log.debug("ensure_started(%s, connected=%s, even_if_no_session=%s, session_count=%s)",
self.mxid, self.connected, even_if_no_session,
self.session_container.Session.query.filter(
self.session_container.Session.session_id == self.mxid).count())
should_connect = (even_if_no_session or
self.session_container.Session.query.filter(
self.session_container.Session.session_id == self.mxid).count() > 0)
if not self.connected and should_connect:
session_count = self.session_container.Session.query.filter(
self.session_container.Session.session_id == self.mxid).count()
self.log.debug("ensure_started(%s, even_if_no_session=%s, session_count=%s)",
self.mxid, even_if_no_session, session_count)
if even_if_no_session or session_count > 0:
await self.start(delete_unless_authenticated=not even_if_no_session)
return self
+3 -6
View File
@@ -14,7 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Coroutine, Dict, List, Match, NewType, Optional, Tuple, cast, TYPE_CHECKING
from typing import Awaitable, Dict, List, Match, NewType, Optional, Tuple, TYPE_CHECKING
import logging
import asyncio
import re
@@ -207,9 +207,6 @@ class User(AbstractUser):
# endregion
# region Telegram actions that need custom methods
def ensure_started(self, even_if_no_session: bool = False) -> Coroutine[None, None, 'User']:
return cast(Coroutine[None, None, 'User'], super().ensure_started(even_if_no_session))
async def set_presence(self, online: bool = True) -> None:
if not self.is_bot:
await self.client(UpdateStatusRequest(offline=not online))
@@ -399,9 +396,9 @@ class User(AbstractUser):
# endregion
def init(context: 'Context') -> List[Coroutine]: # [None, None, AbstractUser]
def init(context: 'Context') -> List[Awaitable['AbstractUser']]:
global config
config = context.config
users = [User.from_db(user) for user in DBUser.query.all()]
return [user.ensure_started() for user in users]
return [user.ensure_started() for user in users if user.tgid]