Source code for advanced_alchemy.service._util

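"""Utilities shared by the SQLAlchemy service layer.

Provides :func:`find_filter`, which locates the first filter of a given type,
and the :class:`ResultConverter` mixin. The mixin converts model instances or
row mappings, optionally into a msgspec/pydantic schema, and wraps sequences
in an :class:`~advanced_alchemy.service.pagination.OffsetPagination` response.
"""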

from __future__ import annotations

from functools import partial
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, cast, overload
from uuid import UUID

from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination
from advanced_alchemy.service.typing import (
    MSGSPEC_INSTALLED,
    PYDANTIC_INSTALLED,
    BaseModel,
    ModelDTOT,
    Struct,
    convert,
    get_type_adapter,
)

if TYPE_CHECKING:
    from sqlalchemy import ColumnElement, RowMapping

    from advanced_alchemy.base import ModelProtocol
    from advanced_alchemy.filters import StatementFilter
    from advanced_alchemy.service.typing import FilterTypeT

# Public API of this module.
__all__ = (
    "ResultConverter",
    "find_filter",
)


def _default_msgspec_deserializer(
    target_type: Any,
    value: Any,
    type_decoders: Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] | None = None,
) -> Any:  # pragma: no cover
    """Coerce ``value`` into ``target_type`` for types ``msgspec`` cannot handle natively.

    Strategies are tried in order; the first applicable one wins:

    1. ``value`` is already an instance of ``target_type``: it is returned untouched.
    2. The first ``(predicate, decoder)`` pair in ``type_decoders`` whose predicate
       accepts ``target_type``: ``decoder(target_type, value)`` is returned.
    3. ``target_type`` is a subclass of ``Path``, ``PurePath`` or ``UUID``:
       ``target_type(value)`` is returned.

    Args:
        target_type: The type being decoded into.
        value: The raw value to coerce.
        type_decoders: Optional ``(predicate, decoder)`` pairs, consulted in order.

    Raises:
        TypeError: If none of the strategies applies.

    Returns:
        ``value`` expressed as ``target_type``.
    """
    if isinstance(value, target_type):
        return value

    # ``or ()`` treats both ``None`` and an empty sequence as "no custom decoders".
    for accepts, decode in type_decoders or ():
        if accepts(target_type):
            return decode(target_type, value)

    if not issubclass(target_type, (Path, PurePath, UUID)):
        raise TypeError(f"Unsupported type: {type(value)!r}")
    return target_type(value)


def find_filter(
    filter_type: type[FilterTypeT],
    filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter],
) -> FilterTypeT | None:
    """Return the first entry of ``filters`` that is an instance of ``filter_type``.

    Args:
        filter_type: The filter class to search for.
        filters: The filters (and/or column expressions) applied to a query.

    Returns:
        The first matching filter instance, or ``None`` when nothing matches.
    """
    for candidate in filters:
        if isinstance(candidate, filter_type):
            return cast("FilterTypeT | None", candidate)
    return None
def _resolve_pagination(
    item_count: int,
    total: int | None,
    filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter],
) -> tuple[int, int, int]:
    """Compute the ``(limit, offset, total)`` triple describing a page of results.

    Args:
        item_count: Number of items in the page being returned.
        total: Total row count supplied by the caller, if any.
        filters: Filters applied to the query. The first
            :class:`~advanced_alchemy.filters.LimitOffset` among them, if any,
            supplies ``limit`` and ``offset``.

    Returns:
        ``(limit, offset, total)``. Without a ``LimitOffset`` filter the page is
        described as ``limit=item_count, offset=0``. A falsy ``total`` (``None``
        or ``0``) falls back to ``item_count``.
    """
    limit_offset = find_filter(LimitOffset, filters=filters)
    if limit_offset is None:
        limit_offset = LimitOffset(limit=item_count, offset=0)
    # NOTE: a truthiness check (not ``is None``) is intentional. It preserves the
    # existing behaviour where an explicit ``total=0`` also falls back to the page length.
    return limit_offset.limit, limit_offset.offset, total or item_count


class ResultConverter:
    """Mixin that converts service results into response objects.

    A single model instance or row mapping is returned unchanged, or converted to
    ``schema_type`` when one is supplied. A sequence of them is wrapped in an
    :class:`~advanced_alchemy.service.pagination.OffsetPagination`, with every item
    converted to ``schema_type`` when one is supplied.
    """

    @overload
    def to_schema(
        self,
        data: ModelOrRowMappingT,
        total: int | None = None,
        filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter] | None = None,
        *,
        schema_type: None = None,
    ) -> ModelOrRowMappingT: ...

    @overload
    def to_schema(
        self,
        data: Sequence[ModelOrRowMappingT],
        total: int | None = None,
        filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter] | None = None,
        *,
        schema_type: None = None,
    ) -> OffsetPagination[ModelOrRowMappingT]: ...

    @overload
    def to_schema(
        self,
        data: ModelProtocol | RowMapping,
        total: int | None = None,
        filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter] | None = None,
        *,
        schema_type: type[ModelDTOT],
    ) -> ModelDTOT: ...

    @overload
    def to_schema(
        self,
        data: Sequence[ModelProtocol] | Sequence[RowMapping],
        total: int | None = None,
        filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter] | None = None,
        *,
        schema_type: type[ModelDTOT],
    ) -> OffsetPagination[ModelDTOT]: ...

    def to_schema(
        self,
        data: ModelOrRowMappingT
        | Sequence[ModelOrRowMappingT]
        | ModelProtocol
        | Sequence[ModelProtocol]
        | RowMapping
        | Sequence[RowMapping],
        total: int | None = None,
        filters: Sequence[StatementFilter | ColumnElement[bool]] | Sequence[StatementFilter] | None = None,
        *,
        schema_type: type[ModelDTOT] | None = None,
    ) -> ModelOrRowMappingT | OffsetPagination[ModelOrRowMappingT] | ModelDTOT | OffsetPagination[ModelDTOT]:
        """Convert the object to a response schema.

        When `schema_type` is None, the model is returned with no conversion.
        Any ``Sequence`` input is wrapped in an ``OffsetPagination``. Its
        ``limit``/``offset`` come from the first ``LimitOffset`` in ``filters``;
        without one, ``limit=len(data)`` and ``offset=0`` are used.

        Args:
            data: The return from one of the service calls.
                Type: :class:`~advanced_alchemy.repository.typing.ModelOrRowMappingT`
            total: The total number of rows in the data. A falsy value falls
                back to ``len(data)``.
            filters: :class:`~advanced_alchemy.filters.StatementFilter`| :class:`sqlalchemy.sql.expression.ColumnElement` Collection of route filters.
            schema_type: :class:`~advanced_alchemy.service.typing.ModelDTOT` Optional schema type to convert the data to

        Raises:
            AdvancedAlchemyError: If ``schema_type`` is neither a msgspec ``Struct``
                nor a pydantic ``BaseModel`` (or the matching library is not installed).

        Returns:
            - :class:`~advanced_alchemy.base.ModelProtocol` | :class:`sqlalchemy.orm.RowMapping` | :class:`~advanced_alchemy.service.pagination.OffsetPagination` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`
        """
        if filters is None:
            filters = []

        if schema_type is None:
            if not isinstance(data, Sequence):
                return cast("ModelOrRowMappingT", data)
            limit, offset, page_total = _resolve_pagination(len(data), total, filters)
            return OffsetPagination[ModelOrRowMappingT](
                items=cast("Sequence[ModelOrRowMappingT]", data),
                limit=limit,
                offset=offset,
                total=page_total,
            )

        if MSGSPEC_INSTALLED and issubclass(schema_type, Struct):
            # Built once and shared by the single-item and sequence paths. When the
            # target type is exactly ``UUID``, the value is rebuilt from its ``.hex``.
            # That decoder only runs for values that are not already ``UUID``
            # instances, because the deserializer returns those first. Presumably it
            # targets UUID-like objects from DB drivers; confirm.
            dec_hook = partial(
                _default_msgspec_deserializer,
                type_decoders=[
                    (lambda x: x is UUID, lambda t, v: t(v.hex)),
                ],
            )
            if not isinstance(data, Sequence):
                return cast(
                    "ModelDTOT",
                    convert(
                        obj=data,
                        type=schema_type,
                        from_attributes=True,
                        dec_hook=dec_hook,
                    ),
                )
            limit, offset, page_total = _resolve_pagination(len(data), total, filters)
            return OffsetPagination[ModelDTOT](
                items=convert(
                    obj=data,
                    type=List[schema_type],  # type: ignore[valid-type]
                    from_attributes=True,
                    dec_hook=dec_hook,
                ),
                limit=limit,
                offset=offset,
                total=page_total,
            )

        if PYDANTIC_INSTALLED and issubclass(schema_type, BaseModel):
            if not isinstance(data, Sequence):
                return cast(
                    "ModelDTOT",
                    get_type_adapter(schema_type).validate_python(data, from_attributes=True),
                )
            limit, offset, page_total = _resolve_pagination(len(data), total, filters)
            return OffsetPagination[ModelDTOT](
                items=get_type_adapter(List[schema_type]).validate_python(data, from_attributes=True),  # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
                limit=limit,
                offset=offset,
                total=page_total,
            )

        msg = "`schema_type` should be a valid Pydantic or Msgspec schema"
        raise AdvancedAlchemyError(msg)