Source code for arq.drain

"""
:mod:`drain`
============

Drain class used by :class:`arq.worker.BaseWorker` and reusable elsewhere.
"""
import asyncio
import logging
from typing import Optional, Set  # noqa

from aioredis import Redis
from async_timeout import timeout

from arq.utils import gen_random

from .jobs import ArqError

__all__ = ['Drain']

# these loggers could do with more sensible names
work_logger = logging.getLogger('arq.work')
jobs_logger = logging.getLogger('arq.jobs')


class TaskError(ArqError, RuntimeError):
    pass


class Drain:
    """
    Drains popping jobs from redis lists and managing a set of tasks with a limited size to execute those jobs.
    """
    def __init__(self, *,
                 redis: Redis,
                 max_concurrent_tasks: int=50,
                 shutdown_delay: float=6,
                 timeout_seconds: int=60,
                 burst_mode: bool=True,
                 raise_task_exception: bool=False,
                 semaphore_timeout: float=60) -> None:
        """
        :param redis: redis pool to get connections from to pop items from list, also used to optionally
          re-enqueue pending jobs on termination
        :param max_concurrent_tasks: maximum number of jobs which can be executed at the same time by the event loop
        :param shutdown_delay: number of seconds to wait for tasks to finish
        :param timeout_seconds: maximum duration of a job, after which the job will be cancelled by the event loop
        :param burst_mode: break the iter loop as soon as no more jobs are available by adding a sentinel quit queue
        :param raise_task_exception: whether or not to raise an exception which occurs in a processed task
        """
        self.redis = redis
        self.loop = redis._pool_or_conn._loop
        self.max_concurrent_tasks = max_concurrent_tasks
        self.task_semaphore = asyncio.Semaphore(value=max_concurrent_tasks, loop=self.loop)
        self.shutdown_delay = max(shutdown_delay, 0.1)
        self.timeout_seconds = timeout_seconds
        self.burst_mode = burst_mode
        self.raise_task_exception = raise_task_exception
        self.pending_tasks: Set[asyncio.futures.Future] = set()
        self.task_exception: Optional[Exception] = None
        self.semaphore_timeout = semaphore_timeout
        self.jobs_complete, self.jobs_failed, self.jobs_timed_out = 0, 0, 0
        self.running = False
        self._finish_lock = asyncio.Lock(loop=self.loop)

    async def __aenter__(self):
        self.running = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        cancelled_tasks = await self.finish()
        if cancelled_tasks:
            raise TaskError(f'finishing the drain required {cancelled_tasks} tasks to be cancelled')
        elif self.raise_task_exception and self.task_exception:
            e = self.task_exception
            raise TaskError(f'A processed task failed: {e.__class__.__name__}, {e}') from e

    @property
    def jobs_in_progress(self):
        return self.max_concurrent_tasks - self.task_semaphore._value

    async def iter(self, *raw_queues: bytes, pop_timeout=1):
        """
        blpop jobs from redis queues and yield them. Waits for the number of tasks to drop below
        max_concurrent_tasks before popping.

        :param raw_queues: tuple of bytes defining queue(s) to pop from.
        :param pop_timeout: how long to wait on each blpop before yielding None.
        :yields: tuple ``(raw_queue_name, raw_data)`` or ``(None, None)`` if all queues are empty
        """
        work_logger.debug('starting main blpop loop')
        quit_queue = None
        assert self.running, 'drain iter will only work when the drain is running'
        if self.burst_mode:
            quit_queue = b'arq:quit-' + gen_random()
            work_logger.debug('populating quit queue to prompt exit: %s', quit_queue.decode())
            await self.redis.rpush(quit_queue, b'1')
            raw_queues = tuple(raw_queues) + (quit_queue,)
        while True:
            work_logger.debug('task semaphore locked: %r', self.task_semaphore.locked())
            try:
                with timeout(self.semaphore_timeout):
                    await self.task_semaphore.acquire()
            except asyncio.TimeoutError:
                work_logger.warning('task semaphore acquisition timed out after %0.1fs', self.semaphore_timeout)
                continue
            if not self.running:
                break
            with await self.redis as r:
                msg = await r.blpop(*raw_queues, timeout=pop_timeout)
            if msg is None:
                yield None, None
                self.task_semaphore.release()
                continue
            raw_queue, raw_data = msg
            if self.burst_mode and raw_queue == quit_queue:
                work_logger.debug('got job from the quit queue, stopping')
                break
            work_logger.debug('yielding job, jobs in progress %d', self.jobs_in_progress)
            yield raw_queue, raw_data

    def add(self, coro, job, re_enqueue=False):
        """
        Start a job and add it to the pending tasks set.

        :param coro: coroutine to execute the job
        :param job: job object, instance of :class:`arq.jobs.Job` or similar
        :param re_enqueue: whether or not to re-enqueue the job on finish if the job won't finish in time.
        """
        task = self.loop.create_task(coro(job))
        if re_enqueue:
            task.job = job
        task.re_enqueue = re_enqueue
        task.add_done_callback(self._job_callback)
        self.loop.call_later(self.timeout_seconds, self._cancel_job, task, job)
        self.pending_tasks.add(task)

    async def finish(self, timeout=None):
        """
        Cancel all pending tasks and optionally re-enqueue jobs which haven't finished after the timeout.

        :param timeout: how long to wait for tasks to finish, defaults to ``shutdown_delay``
        """
        timeout = timeout or self.shutdown_delay
        self.running = False
        cancelled_tasks = 0
        if self.pending_tasks:
            with await self._finish_lock:
                work_logger.info('drain waiting %0.1fs for %d tasks to finish', timeout, len(self.pending_tasks))
                _, pending = await asyncio.wait(self.pending_tasks, timeout=timeout, loop=self.loop)
                if pending:
                    pipe = self.redis.pipeline()
                    for task in pending:
                        if task.re_enqueue:
                            pipe.rpush(task.job.raw_queue, task.job.raw_data)
                        task.cancel()
                        cancelled_tasks += 1
                    if pipe._results:
                        await pipe.execute()
                self.pending_tasks = set()
        return cancelled_tasks

    def _job_callback(self, task):
        self.task_semaphore.release()
        self.jobs_complete += 1
        task_exception = task.exception()
        if task_exception:
            self.running = False
            self.task_exception = task_exception
        elif task.result():
            self.jobs_failed += 1
        self._remove_task(task)
        jobs_logger.debug('task complete, %d jobs done, %d failed', self.jobs_complete, self.jobs_failed)

    def _cancel_job(self, task, job):
        if not task.cancel():
            return
        self.jobs_timed_out += 1
        jobs_logger.error('task timed out %r', job)
        self._remove_task(task)

    def _remove_task(self, task):
        try:
            self.pending_tasks.remove(task)
        except KeyError:
            pass
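
# ---------------------------------------------------------------------------
# Example usage: an editor's sketch, not part of the arq source. It shows how
# a caller might drive Drain directly, assuming an aioredis 1.x pool created
# with create_redis_pool, a queue named b'arq:q' and a trivial process
# coroutine (all hypothetical). The real consumer is arq.worker.BaseWorker,
# which decodes raw_data into arq.jobs.Job instances before calling add().

import asyncio

from aioredis import create_redis_pool

from arq.drain import Drain


async def run_drain() -> None:
    redis = await create_redis_pool(('localhost', 6379))

    async def process(raw_data):
        # a real worker would decode raw_data into a job object and execute it
        print('processing', raw_data)

    drain = Drain(redis=redis, max_concurrent_tasks=10, burst_mode=True)
    async with drain:  # sets running = True; finish() is awaited on exit
        async for raw_queue, raw_data in drain.iter(b'arq:q'):
            if raw_queue is None:
                continue  # blpop timed out, nothing to pop this iteration
            drain.add(process, raw_data)

    redis.close()
    await redis.wait_closed()


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(run_drain())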