Save contacts to db and allow local contact search. Fixes #35

This commit is contained in:
Tulir Asokan
2018-02-11 17:19:17 +02:00
parent 3a14f8d245
commit 04714a2975
4 changed files with 123 additions and 28 deletions
+22 -19
View File
@@ -22,7 +22,6 @@ from mautrix_appservice import MatrixRequestError
from telethon.errors import *
from telethon.tl.types import *
from telethon.tl.functions.contacts import SearchRequest
from telethon.tl.functions.messages import ImportChatInviteRequest, CheckChatInviteRequest
from telethon.tl.functions.channels import JoinChannelRequest
@@ -224,25 +223,29 @@ class CommandHandler:
return await evt.reply("**Usage:** `$cmdprefix+sp search [-r|--remote] <query>`")
elif not evt.sender.logged_in:
return await evt.reply("This command requires you to be logged in.")
# force_remote = False
if evt.args[0] in {"-r", "--remote"}:
# force_remote = True
evt.args.pop(0)
query = " ".join(evt.args)
if len(query) < 5:
return await evt.reply("Minimum length of query for remote search is 5 characters.")
found = await evt.sender.client(SearchRequest(q=query, limit=10))
# reply = ["**People:**", ""]
reply = ["**Results from Telegram server:**", ""]
for result in found.users:
puppet = pu.Puppet.get(result.id)
await puppet.update_info(evt.sender, result)
reply.append(
f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): {puppet.id}")
# reply.extend(("", "**Chats:**", ""))
# for result in found.chats:
# reply.append(f"* {result.title}")
force_remote = False
if evt.args[0] in {"-r", "--remote"}:
force_remote = True
evt.args.pop(0)
query = " ".join(evt.args)
if force_remote and len(query) < 5:
return await evt.reply("Minimum length of query for remote search is 5 characters.")
results, remote = await evt.sender.search(query, force_remote)
reply = []
if remote:
reply += ["**Results from Telegram server:**", ""]
else:
reply += ["**Results in contacts:**", ""]
reply += [(f"* [{puppet.displayname}](https://matrix.to/#/{puppet.mxid}): "
+ f"{puppet.id} ({similarity}% match)")
for puppet, similarity in results]
# TODO somehow show remote channel results when joining by alias is possible?
return await evt.reply("\n".join(reply))
@command_handler
+12 -1
View File
@@ -14,7 +14,8 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Column, UniqueConstraint, Integer, String
from sqlalchemy import Column, UniqueConstraint, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from .base import Base
@@ -57,6 +58,16 @@ class User(Base):
mxid = Column(String, primary_key=True)
tgid = Column(Integer, nullable=True)
tg_username = Column(String, nullable=True)
saved_contacts = Column(Integer, default=0)
contacts = relationship("Contact", uselist=True)
class Contact(Base):
query = None
__tablename__ = "contact"
user = Column("user", Integer, ForeignKey("user.tgid"), primary_key=True)
contact = Column("contact", Integer, ForeignKey("puppet.id"), primary_key=True)
class Puppet(Base):
+12 -1
View File
@@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from difflib import SequenceMatcher
import re
import logging
@@ -66,6 +67,16 @@ class Puppet:
self.to_db()
self.db.commit()
def similarity(self, query):
username_similarity = (SequenceMatcher(None, self.username, query).ratio()
if self.username else 0)
displayname_similarity = (SequenceMatcher(None, self.displayname, query).ratio()
if self.displayname else 0)
#phone_number_similarity = (SequenceMatcher(None, self.phone_number, query).ratio()
# if self.phone_number else 0)
similarity = max(username_similarity, displayname_similarity)
return round(similarity * 1000) / 10
@staticmethod
def get_displayname(info, format=True):
data = {
@@ -99,7 +110,7 @@ class Puppet:
changed = await self.update_displayname(source, info) or changed
if isinstance(info.photo, UserProfilePhoto):
changed = await self.update_avatar(source, info.photo.photo_big)
changed = await self.update_avatar(source, info.photo.photo_big) or changed
if changed:
self.save()
+77 -7
View File
@@ -19,9 +19,11 @@ import asyncio
import platform
from telethon.tl.types import *
from telethon.tl.types.contacts import ContactsNotModified
from telethon.tl.types import User as TLUser
from telethon.tl.functions.contacts import GetContactsRequest, SearchRequest
from .db import User as DBUser, Message as DBMessage
from .db import User as DBUser, Message as DBMessage, Contact as DBContact
from .tgclient import MautrixTelegramClient
from . import portal as po, puppet as pu, __version__
@@ -36,10 +38,13 @@ class User:
by_mxid = {}
by_tgid = {}
def __init__(self, mxid, tgid=None, username=None):
def __init__(self, mxid, tgid=None, username=None, db_contacts=None, saved_contacts=0):
self.mxid = mxid
self.tgid = tgid
self.username = username
self.contacts = []
self.saved_contacts = saved_contacts
self.db_contacts = db_contacts
self.command_status = None
self.connected = False
@@ -65,13 +70,27 @@ class User:
def has_full_access(self):
return self.logged_in and self.whitelisted
@property
def db_contacts(self):
return [self.db.merge(DBContact(user=self.tgid, contact=puppet.id))
for puppet in self.contacts]
@db_contacts.setter
def db_contacts(self, contacts):
if contacts:
self.contacts = [pu.Puppet.get(entry.contact) for entry in contacts]
else:
self.contacts = []
def get_input_entity(self, user):
return user.client.get_input_entity(InputUser(user_id=self.tgid, access_hash=0))
# region Database conversion
def to_db(self):
return self.db.merge(DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username))
return self.db.merge(
DBUser(mxid=self.mxid, tgid=self.tgid, tg_username=self.username,
contacts=self.db_contacts, saved_contacts=self.saved_contacts))
def save(self):
self.to_db()
@@ -79,7 +98,8 @@ class User:
@classmethod
def from_db(cls, db_user):
return User(db_user.mxid, db_user.tgid, db_user.tg_username)
return User(db_user.mxid, db_user.tgid, db_user.tg_username, db_user.contacts,
db_user.saved_contacts)
# endregion
# region Telegram connection management
@@ -102,8 +122,9 @@ class User:
async def post_login(self, info=None):
try:
await self.sync_dialogs()
await self.update_info(info)
await self.sync_dialogs()
await self.sync_contacts()
except Exception:
self.log.exception("Failed to run post-login functions")
@@ -139,18 +160,67 @@ class User:
await self.client.log_out()
# TODO kick user from portals
def _search_local(self, query, max_results=5, min_similarity=45):
results = []
for contact in self.contacts:
similarity = contact.similarity(query)
if similarity >= min_similarity:
results.append((contact, similarity))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def _search_remote(self, query, max_results=5):
server_results = await self.client(SearchRequest(q=query, limit=max_results))
results = []
for user in server_results.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
results.append((puppet, puppet.similarity(query)))
results.sort(key=lambda tup: tup[1], reverse=True)
return results[0:max_results]
async def search(self, query, force_remote=False):
if force_remote:
return await self._search_remote(query), True
results = self._search_local(query)
if results:
return results, False
return await self._search_remote(query), True
async def sync_dialogs(self):
dialogs = await self.client.get_dialogs(limit=30)
creators = []
for dialog in dialogs:
entity = dialog.entity
if (isinstance(entity, (TLUser, ChatForbidden, ChannelForbidden)) or (
isinstance(entity, Chat) and (entity.deactivated or entity.left))):
invalid = (isinstance(entity, (TLUser, ChatForbidden, ChannelForbidden))
or (isinstance(entity, Chat) and (entity.deactivated or entity.left)))
if invalid:
continue
portal = po.Portal.get_by_entity(entity)
creators.append(portal.create_matrix_room(self, entity, invites=[self.mxid]))
await asyncio.gather(*creators, loop=self.loop)
def _hash_contacts(self):
acc = 0
for id in sorted([self.saved_contacts] + [contact.id for contact in self.contacts]):
acc = (acc * 20261 + id) & 0xffffffff
return acc & 0x7fffffff
async def sync_contacts(self):
response = await self.client(GetContactsRequest(hash=self._hash_contacts()))
if isinstance(response, ContactsNotModified):
return
self.log.debug("Updating contacts...")
self.contacts = []
self.saved_contacts = response.saved_count
for user in response.users:
puppet = pu.Puppet.get(user.id)
await puppet.update_info(self, user)
self.contacts.append(puppet)
self.save()
# endregion
# region Telegram update handling