from __future__ import annotations
import uuid
import asyncio
import logging
import functools
from collections import deque
from dataclasses import dataclass
from collections.abc import Callable
from concurrent.futures import Future as ConcurrentFuture
from typing import Any
from .manager import (
LeasedExecutorManager,
_coerce_positive_int,
_coerce_finite_duration
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _WorkItem:
    """One queued unit of synchronous work together with its result future."""

    # Random hex identifier assigned at enqueue time; appears in log messages.
    work_id: str
    # Synchronous callable to run on the leased executor.
    fn: Callable[..., Any]
    # Positional arguments bound to ``fn`` when the batch is submitted.
    args: tuple[Any, ...]
    # Keyword arguments bound to ``fn`` when the batch is submitted.
    kwargs: dict[str, Any]
    # Future created on the grinder's loop; receives the result, the raised
    # exception, or a cancellation.
    result_future: asyncio.Future[Any]
    # Event-loop clock reading taken at enqueue time; drives the max-wait check.
    submitted_at: float
    # Optional caller-supplied label; only used in log messages.
    owner: str | None
[docs]
class WorkGrinder:
    """Async work batcher backed by leased executors.

    Multiple async callers submit sync work. The grinder starts processing a
    batch when either:

    - the oldest pending work has waited at least ``max_wait_seconds``, or
    - the pending work count reaches ``batch_size_threshold``.

    Once a batch is ready, it leases one executor and submits the whole batch.
    Batches are handled one at a time: the next batch is only collected after
    the current one has finished and its lease has been released.
    """
[docs]
def __init__(
    self,
    *,
    executor_manager: LeasedExecutorManager,
    max_wait_seconds: float = 10.0,
    batch_size_threshold: int = 20,
    lease_seconds: float = 60.0,
    owner_prefix: str = "work-grinder",
    logger: logging.Logger | None = None,
):
    """Configure a grinder; nothing runs until ``start()`` is awaited.

    Args:
        executor_manager (LeasedExecutorManager): Source of leased executors,
            one lease per batch.
        max_wait_seconds (float, optional): Longest time the oldest pending
            item may wait before a batch is formed. Defaults to 10.0.
        batch_size_threshold (int, optional): Pending-item count that
            triggers a batch immediately. Defaults to 20.
        lease_seconds (float, optional): Lease duration requested for each
            batch. Defaults to 60.0.
        owner_prefix (str, optional): Prefix of the lease owner label built
            for each batch. Defaults to "work-grinder".
        logger (logging.Logger | None, optional): Logger to use; when
            omitted, a logger named after this module is used.

    Raises:
        ValueError: Expected from the coercion helpers when
            ``max_wait_seconds``, ``batch_size_threshold`` or
            ``lease_seconds`` is not a valid positive value (the helpers
            live in ``.manager``; confirm their exact rules there).
    """
    # Validate every numeric option, in declaration order, before any
    # instance state is assigned.
    wait_limit = _coerce_finite_duration("max_wait_seconds", max_wait_seconds)
    size_trigger = _coerce_positive_int(
        "batch_size_threshold",
        batch_size_threshold,
    )
    lease_length = _coerce_finite_duration("lease_seconds", lease_seconds)

    # Collaborators.
    self._executor_manager = executor_manager
    self._logger = logger if logger else logging.getLogger(__name__)

    # Validated configuration.
    self._max_wait_seconds = wait_limit
    self._batch_size_threshold = size_trigger
    self._lease_seconds = lease_length
    self._owner_prefix = owner_prefix

    # Queue of not-yet-batched work, guarded by the condition.
    self._pending: deque[_WorkItem] = deque()
    self._condition = asyncio.Condition()

    # Lifecycle state; populated by start() and cleared by stop().
    self._loop: asyncio.AbstractEventLoop | None = None
    self._task: asyncio.Task[None] | None = None
    self._started = False
    self._stopping = False

    # Monotonic counter used to label batches and their leases.
    self._batch_seq = 0
[docs]
async def start(self) -> None:
    """Bind the grinder to the running event loop and launch its loop task.

    Calling ``start()`` again on the same loop is a no-op.

    Raises:
        RuntimeError: If the grinder is already started on another loop.
    """
    current_loop = asyncio.get_running_loop()

    if self._started and self._loop is not current_loop:
        raise RuntimeError(
            "WorkGrinder is already started on a different event loop"
        )
    if self._started:
        # Already running on this very loop: nothing to do.
        return

    self._loop = current_loop
    self._stopping = False
    grinder_coro = self._grinder_loop()
    self._task = asyncio.create_task(grinder_coro, name="leasepool-work-grinder")
    self._started = True

    self._logger.info(
        "WorkGrinder started max_wait_seconds=%s batch_size_threshold=%s "
        "lease_seconds=%s",
        self._max_wait_seconds,
        self._batch_size_threshold,
        self._lease_seconds,
    )
[docs]
async def stop(self, *, cancel_pending: bool = False) -> None:
    """Shut the grinder down and wait for its loop task to finish.

    With ``cancel_pending=False`` the loop task is only woken up; it sees the
    stopping flag and is awaited until it exits on its own. With
    ``cancel_pending=True`` every queued item is cancelled and the loop task
    itself is cancelled.

    Does nothing when the grinder is not started.

    Args:
        cancel_pending (bool, optional): Cancel queued work and the loop task
            instead of letting queued work run. Defaults to False.

    Raises:
        RuntimeError: If called from a loop other than the owning one.
        asyncio.CancelledError: If awaiting the loop task is cancelled while
            ``cancel_pending`` is False.
    """
    if not self._started:
        return

    self._require_owner_loop()  # only called for its wrong-loop check
    self._stopping = True

    async with self._condition:
        if cancel_pending:
            abandoned = list(self._pending)
            self._pending.clear()
            for entry in abandoned:
                if not entry.result_future.done():
                    entry.result_future.cancel()
        # Wake the loop task so it notices the stopping flag.
        self._condition.notify_all()

    grinder_task = self._task
    if grinder_task is not None:
        if cancel_pending:
            grinder_task.cancel()
        try:
            await grinder_task
        except asyncio.CancelledError:
            # Expected when we cancelled the task ourselves; otherwise the
            # cancellation belongs to our caller and must propagate.
            if not cancel_pending:
                raise
        finally:
            self._task = None
            self._started = False
            self._stopping = False
            self._loop = None

    self._logger.info("WorkGrinder stopped")
[docs]
async def submit(
    self,
    fn: Callable[..., Any],
    /,
    *args: Any,
    owner: str | None = None,
    **kwargs: Any,
) -> Any:
    """Queue ``fn(*args, **kwargs)`` and wait for its outcome.

    The ``owner`` keyword is consumed here and is not forwarded to ``fn``.

    Args:
        fn (Callable[..., Any]): The function to execute.
        owner (str | None, optional): Label attached to the work item.
            Defaults to None.

    Returns:
        Any: Whatever the work item's result future resolves to; an exception
        set on that future is raised here instead.
    """
    outcome_future = await self.enqueue(fn, *args, owner=owner, **kwargs)
    outcome = await outcome_future
    return outcome
[docs]
async def enqueue(
    self,
    fn: Callable[..., Any],
    /,
    *args: Any,
    owner: str | None = None,
    **kwargs: Any,
) -> asyncio.Future[Any]:
    """Queue ``fn(*args, **kwargs)`` without waiting for it to run.

    The ``owner`` keyword is consumed here and is not forwarded to ``fn``.

    Args:
        fn (Callable[..., Any]): The function to execute.
        owner (str | None, optional): Label attached to the work item.
            Defaults to None.

    Raises:
        RuntimeError: If the WorkGrinder is not started, is called from a
            loop other than its owning one, or is stopping.

    Returns:
        asyncio.Future[Any]: Future on the owning loop that receives the
        work's result or exception. Cancelling it while the item is still
        queued schedules the item's removal from the queue.
    """
    owner_loop = self._require_owner_loop()
    if self._stopping:
        raise RuntimeError("WorkGrinder is stopping")

    outcome: asyncio.Future[Any] = owner_loop.create_future()
    entry = _WorkItem(
        work_id=uuid.uuid4().hex,
        fn=fn,
        args=args,
        kwargs=kwargs,
        result_future=outcome,
        submitted_at=owner_loop.time(),
        owner=owner,
    )
    # A cancelled caller should not leave a dead entry in the queue.
    outcome.add_done_callback(self._on_result_future_done)

    async with self._condition:
        self._pending.append(entry)
        backlog = len(self._pending)
        self._logger.debug(
            "Queued work work_id=%s owner=%s pending=%s",
            entry.work_id,
            owner,
            backlog,
        )
        # Wake the grinder so it can re-evaluate its batch conditions.
        threshold_hit = backlog >= self._batch_size_threshold
        wake = self._condition.notify_all if threshold_hit else self._condition.notify
        wake()

    return outcome
[docs]
def submit_from_thread(
    self,
    fn: Callable[..., Any],
    /,
    *args: Any,
    owner: str | None = None,
    **kwargs: Any,
) -> ConcurrentFuture[Any]:
    """Schedule ``submit()`` on the owning loop from another thread.

    The ``owner`` keyword is consumed here and is not forwarded to ``fn``.

    Args:
        fn (Callable[..., Any]): The function to execute.
        owner (str | None, optional): Label attached to the work item.
            Defaults to None.

    Raises:
        RuntimeError: If the WorkGrinder is not started (or its loop is
            closed), if it is stopping, or if this method is called from
            the owning event loop itself.

    Returns:
        ConcurrentFuture[Any]: Thread-safe future carrying the outcome of
        the scheduled ``submit()`` call.
    """
    owner_loop = self._loop
    usable = (
        self._started
        and owner_loop is not None
        and not owner_loop.is_closed()
    )
    if not usable:
        raise RuntimeError("WorkGrinder is not started")

    self._reject_owner_loop_thread_sync_call("submit_from_thread")

    if self._stopping:
        raise RuntimeError("WorkGrinder is stopping")

    # The coroutine is only created once every check has passed.
    submission = self.submit(fn, *args, owner=owner, **kwargs)
    return asyncio.run_coroutine_threadsafe(submission, owner_loop)
[docs]
def stats(self) -> dict[str, Any]:
    """Return a snapshot of the grinder's state and configuration.

    While the grinder is running this must be called from its owning event
    loop (use ``stats_from_thread()`` from other threads). Before start or
    after stop it can be called from anywhere.

    Raises:
        RuntimeError: If the grinder is running and the caller is not on
            its owning event loop.

    Returns:
        dict[str, Any]: Keys ``started``, ``stopping``, ``pending``,
        ``batch_size_threshold``, ``max_wait_seconds``, ``lease_seconds``
        and ``oldest_wait_seconds`` (0.0 when nothing is queued).
    """
    head_age = 0.0
    if self._started:
        owner_loop = self._require_owner_loop()
        if self._pending:
            # Clamp at zero so the reported age is never negative.
            elapsed = owner_loop.time() - self._pending[0].submitted_at
            head_age = max(0.0, elapsed)

    snapshot: dict[str, Any] = dict(
        started=self._started,
        stopping=self._stopping,
        pending=len(self._pending),
        batch_size_threshold=self._batch_size_threshold,
        max_wait_seconds=self._max_wait_seconds,
        lease_seconds=self._lease_seconds,
        oldest_wait_seconds=head_age,
    )
    return snapshot
[docs]
async def astats(self) -> dict[str, Any]:
    """Return a snapshot of the grinder's state, taken under its condition lock.

    Raises:
        RuntimeError: If the grinder is not started or the caller is not on
            its owning event loop.

    Returns:
        dict[str, Any]: Keys ``started``, ``stopping``, ``pending``,
        ``batch_size_threshold``, ``max_wait_seconds``, ``lease_seconds``
        and ``oldest_wait_seconds`` (0.0 when nothing is queued).
    """
    owner_loop = self._require_owner_loop()
    async with self._condition:
        if self._pending:
            # Clamp at zero so the reported age is never negative.
            elapsed = owner_loop.time() - self._pending[0].submitted_at
            head_age = max(0.0, elapsed)
        else:
            head_age = 0.0

        snapshot: dict[str, Any] = dict(
            started=self._started,
            stopping=self._stopping,
            pending=len(self._pending),
            batch_size_threshold=self._batch_size_threshold,
            max_wait_seconds=self._max_wait_seconds,
            lease_seconds=self._lease_seconds,
            oldest_wait_seconds=head_age,
        )
        return snapshot
[docs]
def stats_from_thread(self, timeout: float | None = None) -> dict[str, Any]:
    """Get the current statistics of the WorkGrinder from a different thread.

    Schedules ``astats()`` on the owning event loop and blocks the calling
    thread until the snapshot is available.

    Args:
        timeout (float | None, optional): The maximum time in seconds to wait
            for the statistics. ``None`` waits indefinitely. Defaults to None.

    Raises:
        RuntimeError: If the WorkGrinder is not started (or its loop is
            closed), or if this method is called from the owning event loop.
        concurrent.futures.TimeoutError: If the snapshot is not available
            within ``timeout``. The scheduled ``astats()`` call is cancelled
            before the error is re-raised.

    Returns:
        dict[str, Any]: A dictionary containing the current statistics.
    """
    # Local import: the file only imports ``Future`` from concurrent.futures.
    from concurrent.futures import TimeoutError as FutureTimeoutError

    loop = self._loop
    if not self._started or loop is None or loop.is_closed():
        raise RuntimeError("WorkGrinder is not started")
    self._reject_owner_loop_thread_sync_call("stats_from_thread")

    future = asyncio.run_coroutine_threadsafe(self.astats(), loop)
    try:
        return future.result(timeout=timeout)
    except FutureTimeoutError:
        # Nobody will read the result any more; cancel the scheduled
        # coroutine instead of leaving it queued on the loop.
        future.cancel()
        raise
def _require_owner_loop(self) -> asyncio.AbstractEventLoop:
    """Return the owning event loop, verifying the caller is running on it.

    Raises:
        RuntimeError: If the grinder is not started, if no event loop is
            running in the calling thread, or if the running loop is not
            the owning one.
    """
    wrong_loop_message = (
        "WorkGrinder async methods must be called from its owning "
        "event loop; use submit_from_thread() or stats_from_thread() "
        "from other threads."
    )

    expected_loop = self._loop
    if expected_loop is None or not self._started:
        raise RuntimeError("WorkGrinder is not started")

    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError as exc:
        # No loop is running in this thread at all.
        raise RuntimeError(wrong_loop_message) from exc

    if active_loop is expected_loop:
        return expected_loop
    raise RuntimeError(wrong_loop_message)
def _reject_owner_loop_thread_sync_call(self, method_name: str) -> None:
    """Refuse a blocking thread-side API when invoked on the owning event loop.

    Args:
        method_name (str): Name of the public method, used in the error text.

    Raises:
        RuntimeError: If the event loop currently running in this thread is
            the grinder's owning loop.
    """
    expected_loop = self._loop
    if expected_loop is None:
        return

    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: a plain worker thread, which is fine.
        return

    if active_loop is not expected_loop:
        return

    raise RuntimeError(
        f"WorkGrinder.{method_name}() cannot be called from the "
        "owning event-loop thread; use the async API instead."
    )
def _on_result_future_done(self, future: asyncio.Future[Any]) -> None:
    """Done-callback for result futures: schedule queue cleanup on cancellation.

    Futures that completed with a result or an exception are ignored, and
    nothing is scheduled once the grinder has no usable event loop.

    Args:
        future (asyncio.Future[Any]): The result future that just completed.
    """
    if not future.cancelled():
        return

    owner_loop = self._loop
    if owner_loop is not None and not owner_loop.is_closed():
        cleanup = self._remove_cancelled_pending_item(future)
        owner_loop.create_task(cleanup)
async def _remove_cancelled_pending_item(
    self,
    future: asyncio.Future[Any],
) -> None:
    """Drop the queued item that owns ``future``, if it is still queued.

    Does nothing when no queued item references the future (for example
    because the item was already drained into a batch).

    Args:
        future (asyncio.Future[Any]): Cancelled result future to look up.
    """
    async with self._condition:
        stale = next(
            (entry for entry in self._pending if entry.result_future is future),
            None,
        )
        if stale is None:
            return

        self._pending.remove(stale)
        self._logger.debug(
            "Removed cancelled pending work work_id=%s owner=%s pending=%s",
            stale.work_id,
            stale.owner,
            len(self._pending),
        )
        # The head of the queue may have changed, so let waiters re-evaluate.
        self._condition.notify_all()
async def _grinder_loop(self) -> None:
    """Run the grinder: repeatedly collect a batch and process it.

    The loop exits when the grinder is stopping and no work is pending.
    Cancellation is propagated unchanged.

    If anything else escapes the loop body, the failure is logged, the
    grinder is flagged as stopping so that new submissions are rejected
    instead of being queued for a loop that no longer runs, and every item
    still pending is failed with ``RuntimeError("WorkGrinder loop crashed")``.
    """
    try:
        while True:
            batch = await self._wait_for_next_batch()
            if not batch:
                # An empty batch is only returned while stopping with an
                # empty queue.
                if self._stopping:
                    break
                continue

            await self._process_batch(batch)

            if self._stopping:
                # Work may have been queued while the batch was running;
                # only exit once the queue is really empty.
                async with self._condition:
                    if not self._pending:
                        break
    except asyncio.CancelledError:
        raise
    except Exception:
        self._logger.exception("WorkGrinder loop crashed")
        # This task is about to end. Raise the stopping flag before yielding
        # control so enqueue()/submit_from_thread() refuse new work rather
        # than queueing futures that nothing would ever resolve. stop()
        # clears the flag again.
        self._stopping = True
        async with self._condition:
            while self._pending:
                item = self._pending.popleft()
                if not item.result_future.done():
                    item.result_future.set_exception(
                        RuntimeError("WorkGrinder loop crashed")
                    )
async def _wait_for_next_batch(self) -> list[_WorkItem]:
    """Block until a batch is due, then drain and return it.

    A batch is due when the queue holds at least ``batch_size_threshold``
    items, when the oldest item has waited ``max_wait_seconds``, or when the
    grinder is stopping and anything is queued.

    Returns:
        list[_WorkItem]: Every queued item at the moment the batch became
        due, or an empty list when stopping with an empty queue.
    """
    assert self._loop is not None
    async with self._condition:
        while True:
            if not self._pending:
                if self._stopping:
                    return []
                # Nothing queued: sleep until someone notifies the condition.
                await self._condition.wait()
                continue

            head_age = self._loop.time() - self._pending[0].submitted_at
            batch_due = (
                len(self._pending) >= self._batch_size_threshold
                or head_age >= self._max_wait_seconds
                or self._stopping
            )
            if batch_due:
                return self._drain_pending_locked()

            # Sleep until the oldest item hits its deadline, waking early on
            # any notification so the conditions are re-evaluated.
            time_left = self._max_wait_seconds - head_age
            try:
                await asyncio.wait_for(self._condition.wait(), timeout=time_left)
            except asyncio.TimeoutError:
                pass
def _drain_pending_locked(self) -> list[_WorkItem]:
    """Empty the pending queue in place and hand back its former contents.

    Returns:
        list[_WorkItem]: The items that were queued, oldest first.
    """
    drained = [*self._pending]
    self._pending.clear()
    return drained
async def _process_batch(self, batch: list[_WorkItem]) -> None:
"""Process a batch of work items.
Args:
batch (list[_WorkItem]): The batch of work items to process.
"""
live_batch = [item for item in batch if not item.result_future.cancelled()]
if not live_batch:
return
self._batch_seq += 1
batch_id = self._batch_seq
lease_owner = f"{self._owner_prefix}-batch-{batch_id}"
self._logger.info(
"Processing batch batch_id=%s size=%s lease_seconds=%s",
batch_id,
len(live_batch),
self._lease_seconds,
)
lease = None
submitted_items: list[_WorkItem] = []
executor_futures: list[asyncio.Future[Any]] = []
try:
lease = await self._executor_manager.acquire(
lease_seconds=self._lease_seconds,
owner=lease_owner,
wait=True,
)
loop = asyncio.get_running_loop()
for item in live_batch:
call = functools.partial(item.fn, *item.args, **item.kwargs)
try:
executor_future = loop.run_in_executor(
lease.executor,
call,
)
except Exception as exc:
self._logger.exception(
"Failed to submit work batch_id=%s work_id=%s owner=%s",
batch_id,
item.work_id,
item.owner,
)
for unsubmitted_item in live_batch[len(submitted_items):]:
if not unsubmitted_item.result_future.done():
unsubmitted_item.result_future.set_exception(exc)
break
submitted_items.append(item)
executor_futures.append(executor_future)
if executor_futures:
results = await asyncio.gather(
*executor_futures,
return_exceptions=True,
)
for item, result in zip(submitted_items, results, strict=True):
if item.result_future.done():
continue
if isinstance(result, BaseException):
item.result_future.set_exception(result)
else:
item.result_future.set_result(result)
self._logger.info(
"Finished batch batch_id=%s size=%s",
batch_id,
len(live_batch),
)
except asyncio.CancelledError:
self._logger.info(
"Cancelled batch batch_id=%s size=%s",
batch_id,
len(live_batch),
)
for executor_future in executor_futures:
executor_future.cancel()
for item in live_batch:
if not item.result_future.done():
item.result_future.cancel()
raise
except Exception as exc:
self._logger.exception(
"Batch failed batch_id=%s size=%s",
batch_id,
len(live_batch),
)
for item in live_batch:
if not item.result_future.done():
item.result_future.set_exception(exc)
finally:
if lease is not None:
await lease.release()