diff --git a/requirements/test.in b/requirements/test.in index 189b30c..279ce0b 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -2,3 +2,4 @@ pytest pytest-cov pytest-xdist coverage +pytest-asyncio diff --git a/requirements/test.txt b/requirements/test.txt index 82f8ff9..f9a6eb9 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -17,8 +17,11 @@ pluggy==1.5.0 pytest==8.2.2 # via # -r requirements/test.in + # pytest-asyncio # pytest-cov # pytest-xdist +pytest-asyncio==0.23.7 + # via -r requirements/test.in pytest-cov==5.0.0 # via -r requirements/test.in pytest-xdist==3.6.1 diff --git a/src/statsd/__init__.py b/src/statsd/__init__.py index 004bf0c..71e0570 100644 --- a/src/statsd/__init__.py +++ b/src/statsd/__init__.py @@ -4,6 +4,8 @@ Version: v\ |version|. """ +from .async_client import BaseAsyncStatsdClient, DebugAsyncStatsdClient +from .base import Sample from .client import ( BaseStatsdClient, DebugStatsdClient, @@ -14,8 +16,11 @@ __all__ = ( + "BaseAsyncStatsdClient", "BaseStatsdClient", + "DebugAsyncStatsdClient", "DebugStatsdClient", + "Sample", "StatsdClient", "UDPStatsdClient", "__version__", diff --git a/src/statsd/async_client.py b/src/statsd/async_client.py new file mode 100644 index 0000000..9b77d48 --- /dev/null +++ b/src/statsd/async_client.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import abc +import contextlib +import functools +import logging +import time +from typing import Any, AsyncIterator, Awaitable, Callable, Mapping, TypeVar +from typing_extensions import ParamSpec + +from statsd.base import AbstractStatsdClient + + +P = ParamSpec("P") +T = TypeVar("T") +U = TypeVar("U") + +logger = logging.getLogger("statsd") + + +class BaseAsyncStatsdClient(AbstractStatsdClient[Awaitable[None]]): + """ + Base async client. + + This class exposes the public interface and takes care of packet formatting + as well as sampling. 
It does not actually send packets anywhere, which is + left to concrete subclasses implementing :meth:`_emit`. + """ + + @abc.abstractmethod + async def _emit(self, packets: list[str]) -> None: + """ + Async send implementation. + + This method is responsible for actually sending the formatted packets + and should be implemented by all subclasses. + + It may batch or buffer packets but should not modify them in any way. It + should be agnostic to the Statsd format. + """ + raise NotImplementedError() + + def timed( + self, + name: str | None = None, + *, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + use_distribution: bool = False, + ) -> Callable[[Callable[P, Awaitable[U]]], Callable[P, Awaitable[U]]]: + """ + Wrap a function to record its execution time. + + This just wraps the function call with a :meth:`timer` context manager. + + If a name is not provided, the function name will be used. + + Passing ``use_distribution=True`` will report the value as a globally + aggregated :meth:`distribution` metric instead of a :meth:`timing` + metric. + + >>> client = AsyncStatsdClient() + >>> @client.timed() + ... async def do_something(): + ... pass + """ + + def decorator( + fn: Callable[P, Awaitable[U]], + ) -> Callable[P, Awaitable[U]]: + # TODO: Should the fallback include the module? Class (for methods)? + # or func.__name__ + metric_name = name or fn.__name__ + + @functools.wraps(fn) + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> U: + async with self.timer( + metric_name, + tags=tags, + use_distribution=use_distribution, + sample_rate=sample_rate, + ): + return await fn(*args, **kwargs) + + return wrapped + + return decorator + + @contextlib.asynccontextmanager + async def timer( + self, + name: str, + *, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + use_distribution: bool = False, + ) -> AsyncIterator[None]: + """ + Context manager to measure the execution time of an async block. 
+
+        Passing ``use_distribution=True`` will report the value as a globally
+        aggregated :meth:`distribution` metric instead of a :meth:`timing`
+        metric.
+
+        >>> client = AsyncStatsdClient()
+        >>> async def operation():
+        ...     async with client.timer("download_duration"):
+        ...         pass
+        """
+        start = time.perf_counter()
+        try:
+            yield
+        finally:
+            duration_ms = int(1000 * (time.perf_counter() - start))
+            if use_distribution:
+                await self.distribution(
+                    name,
+                    duration_ms,
+                    tags=tags,
+                    sample_rate=sample_rate,
+                )
+            else:
+                await self.timing(
+                    name,
+                    duration_ms,
+                    tags=tags,
+                    sample_rate=sample_rate,
+                )
+
+
+class DebugAsyncStatsdClient(BaseAsyncStatsdClient):
+    """
+    Verbose client for development or debugging purposes.
+
+    All Statsd packets will be logged and optionally forwarded to a wrapped
+    client.
+    """
+
+    def __init__(
+        self,
+        level: int = logging.INFO,
+        logger: logging.Logger = logger,
+        inner: BaseAsyncStatsdClient | None = None,
+        **kwargs: Any,
+    ) -> None:
+        r"""
+        Initialize DebugAsyncStatsdClient.
+
+        :param level: Log level to use, defaults to ``INFO``.
+
+        :param logger: Logger instance to use, defaults to ``statsd``.
+
+        :param inner: Wrapped client.
+
+        :param \**kwargs: Extra arguments forwarded to :class:`BaseAsyncStatsdClient`.
+        """
+        super().__init__(**kwargs)
+        self.level = level
+        self.logger = logger
+        self.inner = inner
+
+    async def _emit(self, packets: list[str]) -> None:
+        for packet in packets:
+            self.logger.log(self.level, "> %s", packet)
+        if self.inner:
+            await self.inner._emit(packets)
+
+
+AsyncStatsdClient = DebugAsyncStatsdClient
diff --git a/src/statsd/base.py b/src/statsd/base.py
new file mode 100644
index 0000000..04c3183
--- /dev/null
+++ b/src/statsd/base.py
@@ -0,0 +1,371 @@
+from __future__ import annotations
+
+import abc
+import datetime
+import logging
+import random
+from typing import Generic, Literal, Mapping, NamedTuple, TypeVar
+from typing_extensions import ParamSpec
+
+from statsd.exceptions import InvalidMetricType, InvalidSampleRate
+from statsd.formats import DefaultSerializer, Serializer
+
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+logger = logging.getLogger("statsd")
+
+
+class Sample(NamedTuple):
+    """
+    Container for a metric sample.
+
+    :param metric_name: Metric name. This will be namespaced if the client
+        is namespaced.
+
+    :param metric_type: One of the supported Statsd metric types, counter
+        ("c"), gauge ("g"), timing ("ms") and set ("s"), histogram ("h") or
+        distribution ("d").
+
+    :param value: Metric value formatted according to the rules for the type.
+
+    :param tags: A mapping of tag name to their value. This will be merged
+        with the client's tags if relevant, overriding any tag already set.
+    """
+
+    metric_name: str
+    metric_type: Literal["c", "g", "s", "ms", "h", "d"]
+    value: str
+    tags: Mapping[str, str] | None = None
+
+
+class AbstractStatsdClient(abc.ABC, Generic[T]):
+    """
+    Abstract Statsd client interface.
+
+    This class exists to share implementation details between the async and sync
+    base classes. For most use cases you should not be subclassing this
+    directly; prefer ``BaseStatsdClient`` and ``BaseAsyncStatsdClient``, which drop the
+    generic type parameter.
+ """ + + KNOWN_METRIC_TYPES = ("c", "g", "s", "ms", "h", "d") + + def __init__( + self, + *, + namespace: str | None = None, + tags: Mapping[str, str] | None = None, + sample_rate: float = 1, + serializer: Serializer | None = None, + ) -> None: + """ + :param namespace: Optional prefix to add all metrics. + + If this is set to ``foo``, then all metrics will be prefixed with + ``foo.``; so for instance sending out ``bar`` would actually be sent + as ``foo.bar``. + + :param tags: Default tags applied to all metrics. + + :param sample_rate: Default sampling rate applied to all metrics. + This should be between 0 and 1 inclusive, 1 meaning that all metrics + will be forwarded and 0 that none will be forwarded. + Defaults to 1. + + :param serializer: A serializer defining the wire format of the metrics. + This allows supporting diverging server implementation such as how + Telegraf and Dogstatsd handle tags. See :mod:`statsd.formats` for + more details. + """ + if not (0 <= sample_rate <= 1): + raise InvalidSampleRate(sample_rate) + + self.namespace = namespace + self.default_tags = tags or {} + self.default_sample_rate = sample_rate + self.serializer = ( + DefaultSerializer() if serializer is None else serializer + ) + + def _serialize( + self, + sample: Sample, + sample_rate: float, + ) -> str: + if sample.metric_type not in self.KNOWN_METRIC_TYPES: + raise InvalidMetricType(sample.metric_type) + + return self.serializer.serialize( + ( + # TODO: Is defaulting to ``.`` separator the right call here? + # Alternative 1: Use a prefix that simply prepended + # Alternative 2: Make the separator configurable + # Alternative 3: Make this configurable through an override of + # some sort (`serialize_name` or similar.) 
+ f"{self.namespace}.{sample.metric_name}" + if self.namespace + else sample.metric_name + ), + sample.metric_type, + sample.value, + sample_rate=sample_rate, + tags={**self.default_tags, **(sample.tags or {})}, + ) + + # Shared interface + + def emit(self, *samples: Sample, sample_rate: float | None = None) -> T: + """ + Send samples to the underlying implementation. + + This method takes care of making any sampling decision and building the + actual packets that will be sent to the server through :meth:`_emit`. + which is purely responsible for sending the packet. + + This may modify the samples in various ways: + + - The metric name will be namespaced if the client is namespaced. + - The tags will be merged with the client's tags if relevant but the + sample tags have precedence. + + .. note:: + Calling this method with multiple samples will result in the sampling + decision being applied to all of them as a whole. + + :param samples: List of samples to send. + + :param sample_rate: Sampling rate applied to this particular call. + Should be between 0 and 1 inclusive, 1 meaning that all metrics + will be forwarded and 0 that none will be forwarded. If left + unspecified this will use the client's sample rate. + """ + sample_rate = ( + sample_rate if sample_rate is not None else self.default_sample_rate + ) + if not (0 <= sample_rate <= 1): + raise InvalidSampleRate(sample_rate) + + filtered_out = sample_rate < 1 and random.random() > sample_rate + + if not filtered_out: + return self._emit([ + self._serialize(x, sample_rate=sample_rate) for x in samples + ]) + else: + # WARN: This is weird. + return self._emit([]) + + def increment( + self, + name: str, + value: int = 1, + *, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + ) -> T: + """ + Increment a counter by the specified value (defaults to 1). + + :param metric_name: Metric name. + + :param value: Increment step. + + :param tags: A mapping of tag name to their value. 
+ """ + return self.emit( + Sample(name, "c", str(value), tags=tags), + sample_rate=sample_rate, + ) + + def decrement( + self, + name: str, + value: int = 1, + *, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + ) -> T: + """ + Decrement a counter by the specified value (defaults to 1). + + :param metric_name: Metric name. + + :param value: Decrement step. + + :param tags: A mapping of tag name to their value. + """ + return self.emit( + Sample(name, "c", str(-1 * value), tags=tags), + sample_rate=sample_rate, + ) + + def gauge( + self, + name: str, + value: int | float, + *, + is_update: bool = False, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + ) -> T: + """ + Update a gauge value. + + The ``is_update`` parameter can be used to control whether this sets the + value of the gauge or sends a gauge delta packet (prepended with ``+`` + or ``-``). + + .. warning:: + Not all Statsd server implementations support Gauge deltas. Notably + Datadog protocol does not (see: + https://github.com/DataDog/dd-agent/issues/573 for more info). + + .. warning:: + Gauges can be integers or floats although floats may not be + supported by all servers. + + :param metric_name: Metric name. + + :param value: The updated gauge value. + + :param tags: A mapping of tag name to their value. + """ + if is_update: + _value = f"{'+' if value >= 0 else ''}{value}" + else: + _value = str(value) + + samples = [] + if value < 0 and not is_update: + samples.append(Sample(name, "g", "0", tags=tags)) + samples.append(Sample(name, "g", _value, tags=tags)) + + return self.emit(*samples, sample_rate=sample_rate) + + def timing( + self, + name: str, + value: int | datetime.timedelta, + *, + tags: Mapping[str, str] | None = None, + sample_rate: float | None = None, + ) -> T: + """ + Send a timing value. + + Timing usually are aggregated by the StatsD server receiving them. + + :param metric_name: Metric name. + + :param value: Timing value. 
Expected to be in milliseconds.
+
+        :param tags: A mapping of tag name to their value.
+        """
+        # TODO: Some server implementation support higher resolution timers
+        # using floats. We could support this with a flag.
+        if isinstance(value, datetime.timedelta):
+            value = int(1000 * value.total_seconds())
+
+        return self.emit(
+            Sample(name, "ms", str(value), tags=tags),
+            sample_rate=sample_rate,
+        )
+
+    def set(
+        self,
+        name: str,
+        value: int,
+        *,
+        tags: Mapping[str, str] | None = None,
+        sample_rate: float | None = None,
+    ) -> T:
+        """
+        Update a set counter.
+
+        :param metric_name: Metric name.
+
+        :param value: The number of occurrences.
+
+        :param tags: A mapping of tag name to their value.
+        """
+        return self.emit(
+            Sample(name, "s", str(value), tags=tags),
+            sample_rate=sample_rate,
+        )
+
+    def histogram(
+        self,
+        name: str,
+        value: float,
+        *,
+        tags: Mapping[str, str] | None = None,
+        sample_rate: float | None = None,
+    ) -> T:
+        """
+        Send a histogram sample.
+
+        Histograms, like timings are usually aggregated locally by the StatsD
+        server receiving them.
+
+        .. warning::
+            This is not a standard metric type and is not supported by all
+            StatsD backends.
+
+        :param metric_name: Metric name.
+
+        :param value: The recorded value.
+
+        :param tags: A mapping of tag name to their value.
+        """
+        return self.emit(
+            Sample(name, "h", str(value), tags=tags),
+            sample_rate=sample_rate,
+        )
+
+    def distribution(
+        self,
+        name: str,
+        value: float,
+        *,
+        tags: Mapping[str, str] | None = None,
+        sample_rate: float | None = None,
+    ) -> T:
+        """
+        Send a distribution sample.
+
+        Distributions are usually aggregated globally by a centralised service
+        (e.g. Veneur, Datadog) and not locally by any intermediary StatsD
+        server.
+
+        .. warning::
+            This is not a standard metric type and is not supported by all
+            StatsD backends.
+
+        :param metric_name: Metric name.
+
+        :param value: The recorded value.
+
+        :param tags: A mapping of tag name to their value.
+ """ + return self.emit( + Sample(name, "d", str(value), tags=tags), + sample_rate=sample_rate, + ) + + # Implementation specific methods. + + @abc.abstractmethod + def _emit(self, packets: list[str]) -> T: + """ + Send implementation. + + This method is responsible for actually sending the formatted packets + and should be implemented by all subclasses. + + It may batch or buffer packets but should not modify them in any way. It + should be agnostic to the Statsd format. + """ + raise NotImplementedError() diff --git a/src/statsd/client.py b/src/statsd/client.py index 26d8a7b..8785194 100644 --- a/src/statsd/client.py +++ b/src/statsd/client.py @@ -2,19 +2,16 @@ import abc import contextlib -import datetime import errno import functools import logging -import random import socket import threading import time from typing import Any, Callable, Iterator, Mapping, TypeVar from typing_extensions import ParamSpec -from statsd.exceptions import InvalidMetricType, InvalidSampleRate -from statsd.formats import DefaultSerializer, Serializer +from statsd.base import AbstractStatsdClient P = ParamSpec("P") @@ -23,237 +20,32 @@ logger = logging.getLogger("statsd") -class BaseStatsdClient(abc.ABC): +class BaseStatsdClient(AbstractStatsdClient[None]): """ - Generic Statsd client interface. + Base client. This class exposes the public interface and takes care of packet formatting as well as sampling. It does not actually send packets anywhere, which is - left to concrete subclasses. + left to concrete subclasses implementing :meth:`_emit`. .. warning:: This class makes no assumption around the underlying implementation behaviour. Delivery guarantees, thread safety, robustness to error are all left to specific implementations. - - :param namespace: Optional prefix to add all metrics. - - If this is set to ``foo``, then all metrics will be prefixed with - ``foo.``; so for instance sending out ``bar`` would actually be sent - as ``foo.bar``. 
- - :param tags: Default tags applied to all metrics. - - :param sample_rate: Default sampling rate applied to all metrics. - This should be between 0 and 1 inclusive, 1 meaning that all metrics - will be forwarded and 0 that none will be forwarded. - Defaults to 1. - - :param serializer: A serializer defining the wire format of the metrics. - This allows supporting diverging server implementation such as how - Telegraf and Dogstatsd handle tags. See :mod:`statsd.formats` for - more details. """ - KNOWN_METRIC_TYPES = ("c", "g", "s", "ms", "h", "d") - - def __init__( - self, - *, - namespace: str | None = None, - tags: Mapping[str, str] | None = None, - sample_rate: float = 1, - serializer: Serializer | None = None, - ) -> None: - if not (0 <= sample_rate <= 1): - raise InvalidSampleRate(sample_rate) - - self.namespace = namespace - self.default_tags = tags or {} - self.default_sample_rate = sample_rate - self.serializer = ( - DefaultSerializer() if serializer is None else serializer - ) - - # Shared interface - - def emit( - self, - metric_name: str, - metric_type: str, - value: str, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Send a metric to the underlying implementation. - - This method takes care sampling metrics and builds the actual packet - that will be sent to the server; implementations are purely responsible - for sending the packet and have nothing to do with the Statsd format. - - :param metric_name: Metric name. This will be namespaced if the client - is namespace. - - :param metric_type: One of the supported Statsd metric types, counter - ("c"), gauge ("g"), timing ("ms") and set ("s"). - - :param value: etric value formatted according to the rules for the type. - - :param tags: A mapping of tag name to their value. This will be merged - with the client's tags if relevant, overriding any tag already set. - - :param sample_rate: Sampling rate applied to this particular call. 
- Should be between 0 and 1 inclusive, 1 meaning that all metrics - will be forwarded and 0 that none will be forwarded. If left - unspecified this will use the client's sample rate. - """ - sample_rate = ( - sample_rate if sample_rate is not None else self.default_sample_rate - ) - if not (0 <= sample_rate <= 1): - raise InvalidSampleRate(sample_rate) - - if sample_rate < 1 and random.random() > sample_rate: - return - - self._emit_packet( - self._serialize_metric( - metric_name, - metric_type, - value, - sample_rate=sample_rate, - tags=tags, - ), - ) - - def _serialize_metric( - self, - metric_name: str, - metric_type: str, - value: str, - sample_rate: float, - tags: Mapping[str, str] | None, - ) -> str: - if metric_type not in self.KNOWN_METRIC_TYPES: - raise InvalidMetricType(metric_type) - - return self.serializer.serialize( - ( - # TODO: Is defaulting to ``.`` separator the right call here? - # Alternative 1: Use a prefix that simply prepended - # Alternative 2: Make the separator configurable - # Alternative 3: Make this configurable through an override of - # some sort (`serialize_name` or similar.) - f"{self.namespace}.{metric_name}" - if self.namespace - else metric_name - ), - metric_type, - value, - sample_rate=sample_rate, - tags={**self.default_tags, **(tags or {})}, - ) - - def increment( - self, - name: str, - value: int = 1, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Increment a counter by the specified value (defaults to 1). - - See :meth:`emit` for details on optional parameters. - """ - self.emit(name, "c", str(value), tags=tags, sample_rate=sample_rate) - - def decrement( - self, - name: str, - value: int = 1, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Decrement a counter by the specified value (defaults to 1). - - See :meth:`emit` for details on optional parameters. 
- """ - self.emit( - name, - "c", - str(-1 * value), - tags=tags, - sample_rate=sample_rate, - ) - - def gauge( - self, - name: str, - value: int | float, - *, - is_update: bool = False, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Update a gauge value. - - The ``is_update`` parameter can be used to control whether this sets the - value of the gauge or sends a gauge delta packet (prepended with ``+`` - or ``-``). - - .. warning:: - Not all Statsd server implementations support Gauge deltas. Notably - Datadog protocol does not (see: - https://github.com/DataDog/dd-agent/issues/573 for more info). - - .. warning:: - Gauges can be integers or floats although floats may not be - supported by all servers. - - See :meth:`emit` for details on other parameters. - """ - if is_update: - _value = f"{'+' if value >= 0 else ''}{value}" - else: - _value = str(value) - - with _Batcher(self, sample_rate=sample_rate) as batch: - if value < 0 and not is_update: - # WARN: This could be subject to race condition depending on the - # underlying transport and buffering settings. - batch.queue(name, "g", "0", tags=tags) - batch.queue(name, "g", _value, tags=tags) - - def timing( - self, - name: str, - value: int | datetime.timedelta, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: + @abc.abstractmethod + def _emit(self, packets: list[str]) -> None: """ - Send a timing value. - - Timing usually are aggregated by the StatsD server receiving them. + Send implementation. - The ``value`` is expected to be in milliseconds. + This method is responsible for actually sending the formatted packets + and should be implemented by all subclasses. - See :meth:`emit` for details on optional parameters. + It may batch or buffer packets but should not modify them in any way. It + should be agnostic to the Statsd format. 
""" - # TODO: Some server implementation support higher resolution timers - # using floats. We could support this with a flag. - if isinstance(value, datetime.timedelta): - value = int(1000 * value.total_seconds()) - - self.emit(name, "ms", str(value), tags=tags, sample_rate=sample_rate) + raise NotImplementedError() def timed( self, @@ -309,7 +101,7 @@ def timer( use_distribution: bool = False, ) -> Iterator[None]: """ - Context manager to measure the execution time of a block in milliseconds. + Context manager to measure the execution time of a block. Passing ``use_distribution=True`` will report the value as a globally aggregated :meth:`distribution` metric instead of a :meth:`timing` @@ -339,122 +131,6 @@ def timer( sample_rate=sample_rate, ) - def set( - self, - name: str, - value: int, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """Update a set counter.""" - self.emit(name, "s", str(value), tags=tags, sample_rate=sample_rate) - - def histogram( - self, - name: str, - value: float, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Send an histogram sample. - - Histograms, like timings are usually aggregated locally but the StatsD - server receiving them. - - .. warning:: - This is not a standard metric type and is not supported by all - StatsD backends. - """ - self.emit(name, "h", str(value), tags=tags, sample_rate=sample_rate) - - def distribution( - self, - name: str, - value: float, - *, - tags: Mapping[str, str] | None = None, - sample_rate: float | None = None, - ) -> None: - """ - Send a distribution sample. - - Distributions are usually aggregated globally by a centralised service - (e.g. Veneur, Datadog) and not locally by any intermediary StatsD - server. - - .. warning:: - This is not a standard metric type and is not supported by all - StatsD backends. 
- """ - self.emit(name, "d", str(value), tags=tags, sample_rate=sample_rate) - - # Implementation specific methods. - - @abc.abstractmethod - def _emit_packet(self, packet: str) -> None: - """ - Send implementation. - - This method is responsible for actually sending the formatted packets - and should be implemented by all subclasses. - """ - raise NotImplementedError() - - -class _Batcher: - def __init__( - self, - inner: BaseStatsdClient, - sample_rate: float | None = None, - ) -> None: - sample_rate = ( - sample_rate - if sample_rate is not None - else inner.default_sample_rate - ) - if not (0 <= sample_rate <= 1): - raise InvalidSampleRate(sample_rate) - - self.sample_rate = sample_rate - self.batch: list[str] = [] - self.inner = inner - - def flush(self) -> None: - if self.sample_rate < 1 and random.random() > self.sample_rate: - return - - for x in self.batch: - self.inner._emit_packet(x) - - self.batch[:] = [] - - def queue( - self, - metric_name: str, - metric_type: str, - value: str, - *, - tags: Mapping[str, str] | None = None, - ) -> None: - self.batch.append( - self.inner._serialize_metric( - metric_name, - metric_type, - value, - sample_rate=self.sample_rate, - tags=tags, - ), - ) - - def __enter__(self) -> _Batcher: - return self - - def __exit__(self, *args: Any) -> None: - self.flush() - class DebugStatsdClient(BaseStatsdClient): """ @@ -487,10 +163,11 @@ def __init__( self.logger = logger self.inner = inner - def _emit_packet(self, packet: str) -> None: - self.logger.log(self.level, "> %s", packet) + def _emit(self, packets: list[str]) -> None: + for packet in packets: + self.logger.log(self.level, "> %s", packet) if self.inner: - self.inner._emit_packet(packet) + self.inner._emit(packets) class UDPStatsdClient(BaseStatsdClient): @@ -545,6 +222,11 @@ def __init__( self.port = port self.sock: socket.socket | None = None + def _emit(self, packets: list[str]) -> None: + with self.lock: + for x in packets: + self._emit_packet(x) + def _socket(self) -> 
socket.socket: """Lazily instantiate the socket, this should only happen once.""" if self.sock is None: @@ -575,25 +257,24 @@ def _flush_buffer(self) -> None: def _emit_packet(self, packet: str) -> None: """Handle metric packets, buffering and flusing the buffer accordingly.""" - with self.lock: - msg = packet.encode("ascii") + msg = packet.encode("ascii") - # Buffering disabled, send immediately. - if not self.max_buffer_size: - return self._send(msg) + # Buffering disabled, send immediately. + if not self.max_buffer_size: + return self._send(msg) - msg_size = len(msg) + msg_size = len(msg) - would_overflow = ( - self.buffer_size + len(self.buffer) + msg_size - > self.max_buffer_size - ) + would_overflow = ( + self.buffer_size + len(self.buffer) + msg_size + > self.max_buffer_size + ) - if would_overflow: - self._flush_buffer() + if would_overflow: + self._flush_buffer() - self.buffer.append(msg) - self.buffer_size += msg_size + self.buffer.append(msg) + self.buffer_size += msg_size def _send(self, data: bytes) -> None: """Actually send data.""" diff --git a/tests/test_base_async_client.py b/tests/test_base_async_client.py new file mode 100644 index 0000000..1511535 --- /dev/null +++ b/tests/test_base_async_client.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import logging +from typing import Any +from unittest import mock + +import pytest + +from statsd import BaseAsyncStatsdClient, DebugAsyncStatsdClient + + +class MockClient(BaseAsyncStatsdClient): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.mock = mock.Mock() + + async def _emit(self, packets: list[str]) -> None: + if packets: + self.mock(packets) + + def assert_emitted(self, expected: list[str] | str) -> None: + self.mock.assert_called_once_with( + expected if isinstance(expected, list) else [expected], + ) + + def assert_did_not_emit(self) -> None: + self.mock.assert_not_called() + + +@pytest.mark.asyncio() +async def 
test_timed_decorator() -> None: + client = MockClient() + + @client.timed("foo", tags={"foo": "1"}) + async def fn() -> None: + pass + + with mock.patch( + "time.perf_counter", + side_effect=[7.886838544, 20.181117592], + ): + await fn() + + client.mock.assert_called_once_with(["foo:12294|ms|#foo:1"]) + + +@pytest.mark.asyncio() +async def test_timed_decorator_use_distribution() -> None: + client = MockClient() + + @client.timed("foo", tags={"foo": "1"}, use_distribution=True) + async def fn() -> None: + pass + + with mock.patch( + "time.perf_counter", + side_effect=[7.886838544, 20.181117592], + ): + await fn() + + client.mock.assert_called_once_with(["foo:12294|d|#foo:1"]) + + +SIMPLE_TEST_CASES: list[ + tuple[str, tuple[Any, ...], dict[str, Any], list[str] | str] +] = [ + ("increment", ("foo",), {}, "foo:1|c"), + ("increment", ("foo", 10), {}, "foo:10|c"), +] + + +@pytest.mark.parametrize( + ("method", "args", "kwargs", "expected"), + SIMPLE_TEST_CASES, +) +@pytest.mark.asyncio() +async def test_debug_client_no_inner( + method: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + expected: list[str] | str, + caplog: Any, +) -> None: + client = DebugAsyncStatsdClient() + with caplog.at_level(logging.INFO, logger="statsd"): + await getattr(client, method)(*args, **kwargs) + + if isinstance(expected, list): + assert len(caplog.records) == len(expected) + for x in expected: + assert x in caplog.text + else: + assert len(caplog.records) == 1 + assert expected in caplog.text + + +@pytest.mark.parametrize( + ("method", "args", "kwargs", "expected"), + SIMPLE_TEST_CASES, +) +@pytest.mark.asyncio() +async def test_debug_client( + method: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + expected: str | list[str], +) -> None: + inner = MockClient() + client = DebugAsyncStatsdClient(inner=inner) + await getattr(client, method)(*args, **kwargs) + inner.assert_emitted(expected) + + +@pytest.mark.parametrize( + ("method", "args", "kwargs", "expected"), + 
SIMPLE_TEST_CASES, +) +@pytest.mark.asyncio() +async def test_debug_client_custom_logger_and_level( + method: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + expected: list[str] | str, + caplog: Any, +) -> None: + client = DebugAsyncStatsdClient( + logger=logging.getLogger("foo"), + level=logging.DEBUG, + ) + with caplog.at_level(logging.DEBUG, logger="foo"): + await getattr(client, method)(*args, **kwargs) + + if isinstance(expected, list): + assert len(caplog.records) == len(expected) + for x in expected: + assert x in caplog.text + else: + assert len(caplog.records) == 1 + assert expected in caplog.text diff --git a/tests/test_base_client.py b/tests/test_base_client.py index 02ca9f8..88a18bb 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -7,7 +7,7 @@ import pytest -from statsd import BaseStatsdClient, DebugStatsdClient +from statsd import BaseStatsdClient, DebugStatsdClient, Sample from statsd.exceptions import InvalidMetricType, InvalidSampleRate @@ -16,41 +16,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.mock = mock.Mock() - def _emit_packet(self, packet: str) -> None: - self.mock(packet) - - -def _assert_calls(fn: mock.MagicMock, *expected: tuple[Any, ...]) -> None: - assert expected - if len(expected) > 1: - assert len(fn.call_args_list) == len(expected) - assert fn.call_args_list == [mock.call(*x) for x in expected] - else: - fn.assert_called_once_with(*expected[0]) - - -def assert_emits( - client: MockClient, - method: str, - args: tuple[Any, ...], - kwargs: dict[str, Any], - expected: list[str] | str, -) -> None: - client.mock.reset_mock() - getattr(client, method)(*args, **kwargs) - expected = expected if isinstance(expected, list) else [expected] - _assert_calls(client.mock, *((x,) for x in expected)) + def _emit(self, packets: list[str]) -> None: + if packets: + self.mock(packets) + def assert_emitted(self, expected: list[str] | str) -> None: + 
self.mock.assert_called_once_with( + expected if isinstance(expected, list) else [expected], + ) -def assert_does_not_emit( - client: MockClient, - method: str, - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> None: - client.mock.reset_mock() - getattr(client, method)(*args, **kwargs) - client.mock.assert_not_called() + def assert_did_not_emit(self) -> None: + self.mock.assert_not_called() SIMPLE_TEST_CASES: list[ @@ -89,13 +65,9 @@ def test_basic_metrics( kwargs: dict[str, Any], expected: str | list[str], ) -> None: - assert_emits( - MockClient(), - method, - args, - kwargs, - expected, - ) + client = MockClient() + getattr(client, method)(*args, **kwargs) + client.assert_emitted(expected) @pytest.mark.parametrize("value", [-1, 100, 1.1]) @@ -109,77 +81,52 @@ def test_invalid_sample_rate(value: float) -> None: MockClient(sample_rate=value) +def test_validates_invalid_metric_type() -> None: + client = MockClient() + with pytest.raises(InvalidMetricType): + client.emit(Sample("foo", "p", "54")) # type: ignore[arg-type] + + def test_sample_rate_out() -> None: client = MockClient() with mock.patch("random.random", side_effect=lambda: 0.75): - assert_does_not_emit( - client, - "increment", - ("foo", 5), - {"sample_rate": 0.5}, - ) + client.increment("foo", 5, sample_rate=0.5) + client.assert_did_not_emit() def test_sample_rate_in() -> None: client = MockClient() with mock.patch("random.random", side_effect=lambda: 0.25): - assert_emits( - client, - "increment", - ("foo", 5), - {"sample_rate": 0.5}, - ["foo:5|c|@0.5"], - ) - - -def test_validates_invalid_metric_type() -> None: - client = MockClient() - with pytest.raises(InvalidMetricType): - client.emit("foo", "p", "54") + client.increment("foo", 5, sample_rate=0.5) + client.assert_emitted(["foo:5|c|@0.5"]) def test_default_sample_rate_out() -> None: client = MockClient(sample_rate=0.5) with mock.patch("random.random", side_effect=lambda: 0.75): - assert_does_not_emit(client, "increment", ("foo", 5), {}) + 
client.increment("foo", 5)
+        client.assert_did_not_emit()
 
 
 def test_default_sample_rate_in() -> None:
     client = MockClient(sample_rate=0.5)
 
     with mock.patch("random.random", side_effect=lambda: 0.25):
-        assert_emits(
-            client,
-            "increment",
-            ("foo", 5),
-            {},
-            ["foo:5|c|@0.5"],
-        )
+        client.increment("foo", 5)
+        client.assert_emitted(["foo:5|c|@0.5"])
 
 
 def test_batched_messages_are_sampled_as_one_in() -> None:
     client = MockClient(sample_rate=0.5)
 
     with mock.patch("random.random", side_effect=lambda: 0.25):
-        assert_emits(
-            client,
-            "gauge",
-            ("foo", -5),
-            {},
-            [
-                "foo:0|g|@0.5",
-                "foo:-5|g|@0.5",
-            ],
-        )
+        client.gauge("foo", -5)
+        client.assert_emitted(["foo:0|g|@0.5", "foo:-5|g|@0.5"])
 
 
 def test_batched_messages_are_sampled_as_one_out() -> None:
     client = MockClient(sample_rate=0.5)
 
     with mock.patch("random.random", side_effect=lambda: 0.75):
-        assert_does_not_emit(
-            client,
-            "gauge",
-            ("foo", -5),
-            {},
-        )
+        client.gauge("foo", -5)
+        client.assert_did_not_emit()
 
 
 @pytest.mark.parametrize(
@@ -201,13 +148,9 @@ def test_basic_with_tags(
     expected: str | list[str],
 ) -> None:
     tags = {"foo": "1", "bar": "value"}
-    assert_emits(
-        MockClient(),
-        method,
-        args,
-        {**kwargs, "tags": tags},
-        expected,
-    )
+    client = MockClient()
+    getattr(client, method)(*args, **{**kwargs, "tags": tags})
+    client.assert_emitted(expected)
 
 
 @pytest.mark.parametrize(
@@ -229,18 +172,15 @@ def test_default_tags(
     expected: str | list[str],
 ) -> None:
     tags = {"foo": "1", "bar": "value"}
-    assert_emits(MockClient(tags=tags), method, args, {**kwargs}, expected)
+    client = MockClient(tags=tags)
+    getattr(client, method)(*args, **kwargs)
+    client.assert_emitted(expected)
 
 
 def test_metric_tag_overrides_default_tags() -> None:
     client = MockClient(tags={"foo": "1", "bar": "value"})
-    assert_emits(
-        client,
-        "increment",
-        ("foo",),
-        {"tags": {"foo": "2", "baz": "other_value"}},
-        ["foo:1|c|#foo:2,bar:value,baz:other_value"],
-    )
+    client.increment("foo", tags={"foo": "2", 
"baz": "other_value"}) + client.assert_emitted(["foo:1|c|#foo:2,bar:value,baz:other_value"]) def test_timed_decorator() -> None: @@ -256,7 +196,7 @@ def fn() -> None: ): fn() - client.mock.assert_called_once_with("foo:12294|ms|#foo:1") + client.mock.assert_called_once_with(["foo:12294|ms|#foo:1"]) def test_timed_decorator_use_distribution() -> None: @@ -272,7 +212,7 @@ def fn() -> None: ): fn() - client.mock.assert_called_once_with("foo:12294|d|#foo:1") + client.mock.assert_called_once_with(["foo:12294|d|#foo:1"]) @pytest.mark.parametrize( @@ -309,11 +249,10 @@ def test_debug_client( kwargs: dict[str, Any], expected: str | list[str], ) -> None: - mock_inner = mock.Mock() - client = DebugStatsdClient(inner=mock_inner) + inner = MockClient() + client = DebugStatsdClient(inner=inner) getattr(client, method)(*args, **kwargs) - expected = expected if isinstance(expected, list) else [expected] - _assert_calls(mock_inner._emit_packet, *((x,) for x in expected)) + inner.assert_emitted(expected) @pytest.mark.parametrize(