import asyncio
import functools
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast
from urllib.parse import parse_qs, urlparse
from uuid import uuid4
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.retry import Retry
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError
from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .utils import timestamp_ms, to_ms, to_unix_ms
logger = logging.getLogger('arq.connections')
@dataclass
class RedisSettings:
"""
No-Op class used to hold redis connection settings.
Used by :func:`arq.connections.create_pool` and :class:`arq.worker.Worker`.
"""
host: Union[str, List[Tuple[str, int]]] = 'localhost'
port: int = 6379
unix_socket_path: Optional[str] = None
database: int = 0
username: Optional[str] = None
password: Optional[str] = None
ssl: bool = False
ssl_keyfile: Optional[str] = None
ssl_certfile: Optional[str] = None
ssl_cert_reqs: str = 'required'
ssl_ca_certs: Optional[str] = None
ssl_ca_data: Optional[str] = None
ssl_check_hostname: bool = False
conn_timeout: int = 1
conn_retries: int = 5
conn_retry_delay: int = 1
max_connections: Optional[int] = None
sentinel: bool = False
sentinel_master: str = 'mymaster'
retry_on_timeout: bool = False
retry_on_error: Optional[List[Exception]] = None
retry: Optional[Retry] = None
@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
if conf.scheme not in {'redis', 'rediss', 'unix'}:
raise RuntimeError('invalid DSN scheme')
query_db = parse_qs(conf.query).get('db')
if query_db:
# e.g. redis://localhost:6379?db=1
database = int(query_db[0])
elif conf.scheme != 'unix':
database = int(conf.path.lstrip('/')) if conf.path else 0
else:
database = 0
return RedisSettings(
host=conf.hostname or 'localhost',
port=conf.port or 6379,
ssl=conf.scheme == 'rediss',
username=conf.username,
password=conf.password,
database=database,
unix_socket_path=conf.path if conf.scheme == 'unix' else None,
)
def __repr__(self) -> str:
return 'RedisSettings({})'.format(', '.join(f'{k}={v!r}' for k, v in self.__dict__.items()))
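# A minimal sketch of building settings from a DSN; the host, credentials and
# database below are illustrative assumptions, not values used by arq itself.
# The path component selects the database and the 'rediss' scheme enables TLS.
def _example_settings_from_dsn() -> 'RedisSettings':
    # roughly equivalent to RedisSettings(host='redis.example.com', port=6380,
    # database=2, username='user', password='secret', ssl=True)
    return RedisSettings.from_dsn('rediss://user:secret@redis.example.com:6380/2')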
if TYPE_CHECKING:
BaseRedis = Redis[bytes]
else:
BaseRedis = Redis
class ArqRedis(BaseRedis):
"""
Thin subclass of ``redis.asyncio.Redis`` which adds :func:`arq.connections.enqueue_job`.
:param pool_or_conn: an instance of ``redis.asyncio.ConnectionPool`` to use for connections, or ``None``.
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
:param default_queue_name: the default queue name to use, defaults to ``arq:queue``.
:param expires_extra_ms: the default length of time from when a job is expected to start
after which the job expires, defaults to 1 day in ms.
:param kwargs: keyword arguments directly passed to ``redis.asyncio.Redis``.
"""
def __init__(
self,
pool_or_conn: Optional[ConnectionPool] = None,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
expires_extra_ms: int = expires_extra_ms,
**kwargs: Any,
) -> None:
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
self.default_queue_name = default_queue_name
if pool_or_conn:
kwargs['connection_pool'] = pool_or_conn
self.expires_extra_ms = expires_extra_ms
super().__init__(**kwargs)
async def enqueue_job(
self,
function: str,
*args: Any,
_job_id: Optional[str] = None,
_queue_name: Optional[str] = None,
_defer_until: Optional[datetime] = None,
_defer_by: Union[None, int, float, timedelta] = None,
_expires: Union[None, int, float, timedelta] = None,
_job_try: Optional[int] = None,
**kwargs: Any,
) -> Optional[Job]:
"""
Enqueue a job.
:param function: Name of the function to call
:param args: args to pass to the function
:param _job_id: ID of the job, can be used to enforce job uniqueness
:param _queue_name: queue of the job, can be used to create job in different queue
:param _defer_until: datetime at which to run the job
:param _defer_by: duration to wait before running the job
:param _expires: do not start or retry a job after this duration;
defaults to 24 hours plus deferring time, if any
:param _job_try: useful when re-enqueueing jobs within a job
:param kwargs: any keyword arguments to pass to the function
:return: :class:`arq.jobs.Job` instance or ``None`` if a job with this ID already exists
"""
if _queue_name is None:
_queue_name = self.default_queue_name
job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
if _defer_until and _defer_by:
raise RuntimeError("use either 'defer_until' or 'defer_by' or neither, not both")
defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)
async with self.pipeline(transaction=True) as pipe:
await pipe.watch(job_key)
if await pipe.exists(job_key, result_key_prefix + job_id):
await pipe.reset()
return None
enqueue_time_ms = timestamp_ms()
if _defer_until is not None:
score = to_unix_ms(_defer_until)
elif defer_by_ms:
score = enqueue_time_ms + defer_by_ms
else:
score = enqueue_time_ms
expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms
job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})
try:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
job = Job(job_id, self, _deserializer=self.job_deserializer)
r = await job.result_info()
if r is None:
raise KeyError(f'job "{key.decode()}" not found')
r.job_id = job_id
return r
async def all_job_results(self) -> List[JobResult]:
"""
Get results for all jobs in redis.
"""
keys = await self.keys(result_key_prefix + '*')
results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
return sorted(results, key=attrgetter('enqueue_time'))
async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
key = job_key_prefix + job_id.decode()
v = await self.get(key)
if v is None:
raise RuntimeError(f'job "{key}" not found')
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
jd.job_id = job_id.decode()
return jd
async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]:
"""
Get information about queued jobs, mostly useful when testing.
"""
if queue_name is None:
queue_name = self.default_queue_name
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
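# A minimal usage sketch for ArqRedis.enqueue_job above. The 'download_content'
# function name, URL and job id are illustrative assumptions; enqueueing a
# second job with the same _job_id returns None while the first is pending.
async def _example_enqueue(redis: ArqRedis) -> None:
    job = await redis.enqueue_job(
        'download_content',
        'https://example.com',
        _job_id='download-1',
        _defer_by=timedelta(seconds=10),
    )
    if job is None:
        logger.info('a job with this id is already queued or has a result')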
async def create_pool(
settings_: Optional[RedisSettings] = None,
*,
retry: int = 0,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
expires_extra_ms: int = expires_extra_ms,
) -> ArqRedis:
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_
if isinstance(settings.host, str) and settings.sentinel:
raise RuntimeError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")
if settings.sentinel:
def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
client = Sentinel( # type: ignore[misc]
*args,
sentinels=settings.host,
ssl=settings.ssl,
**kwargs,
)
redis = client.master_for(settings.sentinel_master, redis_class=ArqRedis)
return cast(ArqRedis, redis)
else:
pool_factory = functools.partial(
ArqRedis,
host=settings.host,
port=settings.port,
unix_socket_path=settings.unix_socket_path,
socket_connect_timeout=settings.conn_timeout,
ssl=settings.ssl,
ssl_keyfile=settings.ssl_keyfile,
ssl_certfile=settings.ssl_certfile,
ssl_cert_reqs=settings.ssl_cert_reqs,
ssl_ca_certs=settings.ssl_ca_certs,
ssl_ca_data=settings.ssl_ca_data,
ssl_check_hostname=settings.ssl_check_hostname,
retry=settings.retry,
retry_on_timeout=settings.retry_on_timeout,
retry_on_error=settings.retry_on_error,
max_connections=settings.max_connections,
)
while True:
try:
pool = pool_factory(
db=settings.database, username=settings.username, password=settings.password, encoding='utf8'
)
pool.job_serializer = job_serializer
pool.job_deserializer = job_deserializer
pool.default_queue_name = default_queue_name
pool.expires_extra_ms = expires_extra_ms
await pool.ping()
except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
if retry < settings.conn_retries:
logger.warning(
'redis connection error %s:%s %s %s, %s retries remaining...',
settings.host,
settings.port,
e.__class__.__name__,
e,
settings.conn_retries - retry,
)
await asyncio.sleep(settings.conn_retry_delay)
retry = retry + 1
else:
raise
else:
if retry > 0:
logger.info('redis connection successful')
return pool
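# A minimal sketch of creating a pool and enqueueing a job through it; the DSN
# and the 'send_email' function name are illustrative assumptions.
async def _example_create_pool() -> None:
    pool = await create_pool(RedisSettings.from_dsn('redis://localhost:6379/0'))
    try:
        await pool.enqueue_job('send_email', to='user@example.com')
    finally:
        # close() is deprecated in favour of aclose() on recent redis-py releases
        await pool.close()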
async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) -> None:
async with redis.pipeline(transaction=False) as pipe:
pipe.info(section='Server')
pipe.info(section='Memory')
pipe.info(section='Clients')
pipe.dbsize()
info_server, info_memory, info_clients, key_count = await pipe.execute()
redis_version = info_server.get('redis_version', '?')
mem_usage = info_memory.get('used_memory_human', '?')
clients_connected = info_clients.get('connected_clients', '?')
log_func(
f'redis_version={redis_version} '
f'mem_usage={mem_usage} '
f'clients_connected={clients_connected} '
f'db_keys={key_count}'
)
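# A small usage sketch for log_redis_info: any callable taking a single string
# works as log_func, e.g. logger.info; the pool here assumes a local Redis.
async def _example_log_redis_info() -> None:
    redis = await create_pool(RedisSettings())
    await log_redis_info(redis, logger.info)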