Source code for advanced_alchemy.filters

"""Collection filter datastructures."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import abc  # noqa: TCH003
from dataclasses import dataclass
from datetime import datetime  # noqa: TCH003
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

from sqlalchemy import BinaryExpression, and_, any_, or_, text

if TYPE_CHECKING:
    from typing import Callable

    from sqlalchemy import ColumnElement, Select, StatementLambdaElement
    from sqlalchemy.orm import InstrumentedAttribute
    from typing_extensions import TypeAlias

    from advanced_alchemy import base


__all__ = (
    "BeforeAfter",
    "CollectionFilter",
    "FilterTypes",
    "LimitOffset",
    "OrderBy",
    "SearchFilter",
    "NotInCollectionFilter",
    "OnBeforeAfter",
    "NotInSearchFilter",
    "PaginationFilter",
    "InAnyFilter",
    "StatementFilter",
    "StatementFilterT",
)

T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="base.ModelProtocol")
StatementFilterT = TypeVar("StatementFilterT", bound="StatementFilter")
FilterTypes: TypeAlias = "BeforeAfter | OnBeforeAfter | CollectionFilter[Any] | LimitOffset | OrderBy | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter"
"""Aggregate type alias of the types supported for collection filtering."""


[docs]class StatementFilter(ABC):
[docs] @abstractmethod def append_to_statement(self, statement: Select[tuple[ModelT]], model: type[ModelT]) -> Select[tuple[ModelT]]: return statement
[docs] @abstractmethod def append_to_lambda_statement( self, statement: StatementLambdaElement, *args: Any, **kwargs: Any, ) -> StatementLambdaElement: return statement
@staticmethod def _get_instrumented_attr(model: Any, key: str | InstrumentedAttribute[Any]) -> InstrumentedAttribute[Any]: # copy this here to avoid a circular import of `get_instrumented_attribute`. Maybe we move that function somewhere else? if isinstance(key, str): return cast("InstrumentedAttribute[Any]", getattr(model, key)) return key
[docs]@dataclass class BeforeAfter(StatementFilter): """Data required to filter a query on a ``datetime`` column.""" field_name: str """Name of the model attribute to filter on.""" before: datetime | None """Filter results where field earlier than this.""" after: datetime | None """Filter results where field later than this."""
[docs] def append_to_statement(self, statement: Select[tuple[ModelT]], model: type[ModelT]) -> Select[tuple[ModelT]]: field = self._get_instrumented_attr(model, self.field_name) if self.before is not None: statement = statement.where(field < self.before) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] if self.after is not None: statement = statement.where(field > self.after) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: field = self._get_instrumented_attr(model, self.field_name) if self.before is not None: before = self.before statement += lambda s: s.where(field < before) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] if self.after is not None: after = self.after statement += lambda s: s.where(field > after) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs]@dataclass class OnBeforeAfter(StatementFilter): """Data required to filter a query on a ``datetime`` column.""" field_name: str """Name of the model attribute to filter on.""" on_or_before: datetime | None """Filter results where field is on or earlier than this.""" on_or_after: datetime | None """Filter results where field on or later than this."""
[docs] def append_to_statement(self, statement: Select[tuple[ModelT]], model: type[ModelT]) -> Select[tuple[ModelT]]: field = self._get_instrumented_attr(model, self.field_name) if self.on_or_before is not None: statement = statement.where(field <= self.on_or_before) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] if self.on_or_after is not None: statement = statement.where(field >= self.on_or_after) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: field = self._get_instrumented_attr(model, self.field_name) if self.on_or_before is not None: on_or_before = self.on_or_before statement += lambda s: s.where(field <= on_or_before) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] if self.on_or_after is not None: on_or_after = self.on_or_after statement += lambda s: s.where(field >= on_or_after) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs]class InAnyFilter(StatementFilter, ABC): """Subclass for methods that have a `prefer_any` attribute."""
[docs]@dataclass class CollectionFilter(InAnyFilter, Generic[T]): """Data required to construct a ``WHERE ... IN (...)`` clause.""" field_name: str """Name of the model attribute to filter on.""" values: abc.Collection[T] | None """Values for ``IN`` clause. An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """
[docs] def append_to_statement( self, statement: Select[tuple[ModelT]], model: type[ModelT], prefer_any: bool = False, ) -> Select[tuple[ModelT]]: field = self._get_instrumented_attr(model, self.field_name) if self.values is None: return statement if not self.values: return statement.where(text("1=-1")) if prefer_any: return statement.where(any_(self.values) == field) # type: ignore[arg-type] return statement.where(field.in_(self.values))
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], prefer_any: bool = False, ) -> StatementLambdaElement: field = self._get_instrumented_attr(model, self.field_name) if self.values is None: return statement if not self.values: statement += lambda s: s.where(text("1=-1")) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement if prefer_any: values = self.values statement += lambda s: s.where(any_(values) == field) # type: ignore[arg-type] return statement values = self.values statement += lambda s: s.where(field.in_(values)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] return statement
[docs]@dataclass class NotInCollectionFilter(InAnyFilter, Generic[T]): """Data required to construct a ``WHERE ... NOT IN (...)`` clause.""" field_name: str """Name of the model attribute to filter on.""" values: abc.Collection[T] | None """Values for ``NOT IN`` clause. An empty list or ``None`` will return all rows."""
[docs] def append_to_statement( self, statement: Select[tuple[ModelT]], model: type[ModelT], prefer_any: bool = False, ) -> Select[tuple[ModelT]]: field = self._get_instrumented_attr(model, self.field_name) if not self.values: return statement if prefer_any: return statement.where(any_(self.values) == field) # type: ignore[arg-type] return statement.where(field.in_(self.values))
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], prefer_any: bool = False, ) -> StatementLambdaElement: field = self._get_instrumented_attr(model, self.field_name) if not self.values: return statement if prefer_any: values = self.values statement += lambda s: s.where(any_(values) != field) # type: ignore[arg-type] return statement values = self.values statement += lambda s: s.where(field.notin_(values)) # pyright: ignore[reportUnknownLambdaType,reportArgumentType,reportUnknownMemberType] return statement
[docs]class PaginationFilter(StatementFilter, ABC): """Subclass for methods that function as a pagination type."""
[docs]@dataclass class LimitOffset(PaginationFilter): """Data required to add limit/offset filtering to a query.""" limit: int """Value for ``LIMIT`` clause of query.""" offset: int """Value for ``OFFSET`` clause of query."""
[docs] def append_to_statement(self, statement: Select[tuple[ModelT]], model: type[ModelT]) -> Select[tuple[ModelT]]: return statement.limit(self.limit).offset(self.offset)
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: limit = self.limit offset = self.offset statement += lambda s: s.limit(limit).offset(offset) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs]@dataclass class OrderBy(StatementFilter): """Data required to construct a ``ORDER BY ...`` clause.""" field_name: str """Name of the model attribute to sort on.""" sort_order: Literal["asc", "desc"] = "asc" """Sort ascending or descending"""
[docs] def append_to_statement(self, statement: Select[tuple[ModelT]], model: type[ModelT]) -> Select[tuple[ModelT]]: field = self._get_instrumented_attr(model, self.field_name) if self.sort_order == "desc": return statement.order_by(field.desc()) return statement.order_by(field.asc())
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: field = self._get_instrumented_attr(model, self.field_name) fragment = field.desc() if self.sort_order == "desc" else field.asc() statement += lambda s: s.order_by(fragment) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs]@dataclass class SearchFilter(StatementFilter): """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause.""" field_name: str | set[str] """Name of the model attribute to search on.""" value: str """Search value.""" ignore_case: bool | None = False """Should the search be case insensitive.""" @property def _operator(self) -> Callable[..., ColumnElement[bool]]: return or_ @property def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: return attrgetter("ilike" if self.ignore_case else "like") @property def normalized_field_names(self) -> set[str]: return {self.field_name} if isinstance(self.field_name, str) else self.field_name
[docs] def get_search_clauses(self, model: type[ModelT]) -> list[BinaryExpression[bool]]: search_clause: list[BinaryExpression[bool]] = [] for field_name in self.normalized_field_names: field = self._get_instrumented_attr(model, field_name) search_text = f"%{self.value}%" search_clause.append(self._func(field)(search_text)) return search_clause
[docs] def append_to_statement( self, statement: Select[tuple[ModelT]], model: type[ModelT], ) -> Select[tuple[ModelT]]: where_clause = self._operator(*self.get_search_clauses(model)) return statement.where(where_clause)
[docs] def append_to_lambda_statement( self, statement: StatementLambdaElement, model: type[ModelT], ) -> StatementLambdaElement: where_clause = self._operator(*self.get_search_clauses(model)) statement += lambda s: s.where(where_clause) # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType] return statement
[docs]@dataclass class NotInSearchFilter(SearchFilter): """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause.""" @property def _operator(self) -> Callable[..., ColumnElement[bool]]: return and_ @property def _func(self) -> attrgetter[Callable[[str], BinaryExpression[bool]]]: return attrgetter("not_ilike" if self.ignore_case else "not_like")