1+ from __future__ import annotations
12import asyncio
23import logging
34import random
45import ssl
56import time
67from collections import defaultdict
78from ssl import SSLContext , PROTOCOL_TLS
8- from typing import Any
9+ from typing import Any , TYPE_CHECKING
910from concurrent .futures import CancelledError
1011from .buildin import Buildin
1112from .protocol import Proto , Protocol , ProtocolWS
1213from ..exceptions import NodeError , AuthError
1314from ..util import strip_code
15+ if TYPE_CHECKING :
16+ from ..room .roombase import RoomBase
17+ from ..client .package import Package
1418
1519
1620class Client (Buildin ):
@@ -54,7 +58,7 @@ def __init__(
5458 self ._scope = '@t' # default to thingsdb scope
5559 self ._pool_idx = 0
5660 self ._reconnecting = False
57- self ._rooms = dict ()
61+ self ._rooms : dict [ int , RoomBase ] = dict ()
5862 self ._rooms_lock = asyncio .Lock ()
5963
6064 if ssl is True :
@@ -64,7 +68,7 @@ def __init__(
6468 else :
6569 self ._ssl = ssl
6670
67- def get_rooms (self ):
71+ def get_rooms (self ) -> tuple [ RoomBase , ...] :
6872 """Can be used to get the rooms which are joined.
6973
7074 Returns:
@@ -185,12 +189,18 @@ def connect_pool(
185189 assert self ._reconnecting is False
186190 assert len (pool ), 'pool must contain at least one node'
187191 if len (auth ) == 1 :
188- auth = auth [0 ] # type: ignore
192+ _auth = auth [0 ] # token or tuple[str, str]
193+ elif len (auth ) == 2 and \
194+ isinstance (auth [0 ], str ) and \
195+ isinstance (auth [1 ], str ):
196+ _auth = (auth [0 ], auth [1 ]) # username/password
197+ else :
198+ raise TypeError ('wrong or missing authentication arguments' )
189199
190200 self ._pool = tuple ((
191201 (address , 9200 ) if isinstance (address , str ) else address
192202 for address in pool ))
193- self ._auth = self ._auth_check (auth )
203+ self ._auth = self ._auth_check (_auth )
194204 self ._pool_idx = random .randint (0 , len (pool ) - 1 )
195205 fut = self .reconnect ()
196206 if fut is None :
@@ -291,8 +301,15 @@ async def authenticate(
291301 wait forever on a response. Defaults to 5.
292302 """
293303 if len (auth ) == 1 :
294- auth = auth [0 ] # type: ignore
295- self ._auth = self ._auth_check (auth )
304+ _auth = auth [0 ] # token or tuple[str, str]
305+ elif len (auth ) == 2 and \
306+ isinstance (auth [0 ], str ) and \
307+ isinstance (auth [1 ], str ):
308+ _auth = (auth [0 ], auth [1 ]) # username/password
309+ else :
310+ raise TypeError ('wrong or missing authentication arguments' )
311+
312+ self ._auth = self ._auth_check (_auth )
296313 await self ._authenticate (timeout )
297314
298315 def query (
@@ -568,7 +585,7 @@ def _leave(self, *ids: int | str,
568585 return self ._write_pkg (Proto .REQ_LEAVE , [scope , * ids ]) # type: ignore
569586
570587 @staticmethod
571- def _auth_check (auth ) :
588+ def _auth_check (auth : str | tuple [ str , str ]) -> str | tuple [ str , str ] :
572589 assert ((
573590 isinstance (auth , (list , tuple )) and
574591 len (auth ) == 2 and
@@ -583,7 +600,7 @@ def _auth_check(auth):
583600 return auth
584601
585602 @staticmethod
586- def _is_websocket_host (host ) :
603+ def _is_websocket_host (host : str ) -> bool :
587604 return host .startswith ('ws://' ) or host .startswith ('wss://' )
588605
589606 async def _connect (self , timeout : int | None = 5 ):
@@ -615,7 +632,7 @@ async def _connect(self, timeout: int | None = 5):
615632 self ._pool_idx += 1
616633 self ._pool_idx %= len (self ._pool )
617634
618- async def _on_room (self , room_id , pkg ):
635+ async def _on_room (self , room_id : int , pkg : Package ):
619636 async with self ._rooms_lock :
620637 try :
621638 room = self ._rooms [room_id ]
@@ -628,8 +645,9 @@ async def _on_room(self, room_id, pkg):
628645 if isinstance (task , asyncio .Task ):
629646 await task
630647
631- def _on_event (self , pkg ):
648+ def _on_event (self , pkg : Package ):
632649 if pkg .tp == Proto .ON_NODE_STATUS :
650+ assert pkg .data is not None
633651 status , node_id = pkg .data ['status' ], pkg .data ['id' ]
634652
635653 if self ._reconnect and status == 'SHUTTING_DOWN' :
@@ -654,7 +672,7 @@ def _on_event(self, pkg):
654672 asyncio .ensure_future (self ._on_room (room_id , pkg ),
655673 loop = self .get_event_loop ())
656674
657- def _on_connection_lost (self , protocol , exc ):
675+ def _on_connection_lost (self , protocol : asyncio . Protocol , exc : Exception ):
658676 if self ._protocol is not protocol :
659677 return
660678 self ._protocol = None
@@ -695,20 +713,23 @@ async def _reconnect_loop(self):
695713 finally :
696714 self ._reconnecting = False
697715
698- def _ping (self , timeout ):
699- return self ._write (Proto .REQ_PING , timeout = timeout )
716+ async def _ping (self , timeout : int | None ):
717+ return await self ._write (Proto .REQ_PING , timeout = timeout )
700718
701- def _authenticate (self , timeout ):
702- return self ._write (Proto .REQ_AUTH , data = self ._auth , timeout = timeout )
719+ async def _authenticate (self , timeout : int | None ) -> asyncio .Future [Any ]:
720+ return await self ._write (
721+ Proto .REQ_AUTH ,
722+ data = self ._auth ,
723+ timeout = timeout )
703724
704725 async def _rejoin (self ):
705726 if not self ._rooms :
706727 return # do nothig if no rooms are used
707728
708729 # re-arrange the rooms per scope to combine joins in a less requests
709- scopes = defaultdict (list )
730+ scopes : dict [ str , list [ int ]] = defaultdict (list )
710731 for room in self ._rooms .values ():
711- if room .id :
732+ if room .id and room . scope :
712733 scopes [room .scope ].append (room .id )
713734
714735 # join request per scope, each for one or more rooms
0 commit comments