from collections.abc import Iterable, Sequence
from typing import Any, Literal, Optional, Protocol, Union, cast, overload
from sqlalchemy import (
Delete,
Dialect,
Select,
UnaryExpression,
Update,
)
from sqlalchemy.orm import (
InstrumentedAttribute,
MapperProperty,
RelationshipProperty,
joinedload,
lazyload,
selectinload,
)
from sqlalchemy.orm.strategy_options import (
_AbstractLoad,  # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.sql import ColumnElement, ColumnExpressionArgument
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from sqlalchemy.sql.elements import Label
from typing_extensions import TypeAlias
from advanced_alchemy.base import ModelProtocol
from advanced_alchemy.exceptions import ErrorMessages
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception as _wrap_sqlalchemy_exception
from advanced_alchemy.filters import (
InAnyFilter,
PaginationFilter,
StatementFilter,
StatementTypeT,
)
from advanced_alchemy.repository._typing import arrays_equal, is_numpy_array
from advanced_alchemy.repository.typing import ModelT, OrderingPair
WhereClauseT = ColumnExpressionArgument[bool]
SingleLoad: TypeAlias = Union[
_AbstractLoad,
Literal["*"],
InstrumentedAttribute[Any],
RelationshipProperty[Any],
MapperProperty[Any],
]
LoadCollection: TypeAlias = Sequence[Union[SingleLoad, Sequence[SingleLoad]]]
ExecutableOptions: TypeAlias = Sequence[ExecutableOption]
LoadSpec: TypeAlias = Union[LoadCollection, SingleLoad, ExecutableOption, ExecutableOptions]
OrderByT: TypeAlias = Union[
str,
InstrumentedAttribute[Any],
RelationshipProperty[Any],
]
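# Illustrative sketch (not part of the public API): values accepted as ``LoadSpec``.
# ``Author`` and ``Book`` are hypothetical mapped models used only for the example.
#
#     load: LoadSpec = Author.books                      # single relationship attribute
#     load: LoadSpec = "*"                               # wildcard eager loading
#     load: LoadSpec = [Author.books, "*"]               # sequence of single loads
#     load: LoadSpec = [(Author.books, Book.publisher)]  # nested chain: books -> publisher
#     load: LoadSpec = selectinload(Author.books)        # pre-built SQLAlchemy loader option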
# NOTE: Re-exported for backward compatibility with Litestar, which imports ``wrap_sqlalchemy_exception`` from this module.
wrap_sqlalchemy_exception = _wrap_sqlalchemy_exception
DEFAULT_ERROR_MESSAGE_TEMPLATES: ErrorMessages = {
"integrity": "There was a data validation error during processing",
"foreign_key": "A foreign key is missing or invalid",
"multiple_rows": "Multiple matching rows found",
"duplicate_key": "A record matching the supplied data already exists.",
"other": "There was an error during data processing",
"check_constraint": "The data failed a check constraint during processing",
"not_found": "The requested resource was not found",
}
"""Default error messages for repository errors."""
def get_instrumented_attr(
model: type[ModelProtocol],
key: Union[str, InstrumentedAttribute[Any]],
) -> InstrumentedAttribute[Any]:
"""Get an instrumented attribute from a model.
Args:
model: SQLAlchemy model class.
key: Either a string attribute name or an :class:`sqlalchemy.orm.InstrumentedAttribute`.
Returns:
:class:`sqlalchemy.orm.InstrumentedAttribute`: The instrumented attribute from the model.
"""
if isinstance(key, str):
return cast("InstrumentedAttribute[Any]", getattr(model, key))
return key
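# Usage sketch (``User`` is a hypothetical mapped model with a ``name`` column;
# ``select`` would come from ``sqlalchemy``):
#
#     field = get_instrumented_attr(User, "name")      # string key -> InstrumentedAttribute
#     field = get_instrumented_attr(User, User.name)   # attributes pass through unchanged
#     stmt = select(User).where(field == "alice")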
def model_from_dict(model: type[ModelT], **kwargs: Any) -> ModelT:
"""Create an ORM model instance from a dictionary of attributes.
Args:
model: The SQLAlchemy model class to instantiate.
**kwargs: Keyword arguments containing model attribute values.
Returns:
ModelT: A new instance of the model populated with the provided values.
"""
data = {
attr_name: kwargs[attr_name]
for attr_name in model.__mapper__.attrs.keys() # noqa: SIM118 # pyright: ignore[reportUnknownMemberType]
if attr_name in kwargs
}
return model(**data)
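# Usage sketch (hypothetical ``User`` model): keyword arguments that do not match a
# mapped attribute are silently dropped rather than raising ``TypeError``.
#
#     user = model_from_dict(User, name="alice", email="a@example.com", unknown="ignored")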
def get_abstract_loader_options(
loader_options: Union[LoadSpec, None],
default_loader_options: Union[list[_AbstractLoad], None] = None,
default_options_have_wildcards: bool = False,
merge_with_default: bool = True,
inherit_lazy_relationships: bool = True,
cycle_count: int = 0,
) -> tuple[list[_AbstractLoad], bool]:
"""Generate SQLAlchemy loader options for eager loading relationships.
Args:
loader_options: Specification for how to load relationships (:class:`~advanced_alchemy.repository.typing.LoadSpec` or ``None``). Can be one of:
- None: Use defaults
- :class:`sqlalchemy.orm.strategy_options._AbstractLoad`: Direct SQLAlchemy loader option
- :class:`sqlalchemy.orm.InstrumentedAttribute`: Model relationship attribute
- :class:`sqlalchemy.orm.RelationshipProperty`: SQLAlchemy relationship
- str: "*" for wildcard loading
- :class:`typing.Sequence` of the above
default_loader_options: Optional list of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` loader options to start with.
default_options_have_wildcards: Whether the default options contain wildcards.
merge_with_default: Whether to merge the default options with the loader options.
inherit_lazy_relationships: Whether to inherit the ``lazy`` configuration from the model's relationships.
cycle_count: Number of times this function has been called recursively.
Returns:
tuple[:class:`list`[:class:`sqlalchemy.orm.strategy_options._AbstractLoad`], bool]: A tuple containing:
- :class:`list` of :class:`sqlalchemy.orm.strategy_options._AbstractLoad` SQLAlchemy loader option objects
- Boolean indicating if any wildcard loaders are present
"""
loads: list[_AbstractLoad] = []
if cycle_count == 0 and not inherit_lazy_relationships:
loads.append(lazyload("*"))
if cycle_count == 0 and merge_with_default and default_loader_options is not None:
loads.extend(default_loader_options)
options_have_wildcards = default_options_have_wildcards
if loader_options is None:
return (loads, options_have_wildcards)
if isinstance(loader_options, _AbstractLoad):
return ([loader_options], options_have_wildcards)
if isinstance(loader_options, InstrumentedAttribute):
loader_options = [loader_options.property]
if isinstance(loader_options, RelationshipProperty):
class_ = loader_options.class_attribute
return (
[selectinload(class_)]
if loader_options.uselist
else [joinedload(class_, innerjoin=loader_options.innerjoin)],
options_have_wildcards if loader_options.uselist else True,
)
if isinstance(loader_options, str) and loader_options == "*":
options_have_wildcards = True
return ([joinedload("*")], options_have_wildcards)
if isinstance(loader_options, (list, tuple)):
for attribute in loader_options: # pyright: ignore[reportUnknownVariableType]
if isinstance(attribute, (list, tuple)):
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loader = load_chain[-1]
for sub_load in load_chain[-2::-1]:
loader = sub_load.options(loader)
loads.append(loader)
else:
load_chain, options_have_wildcards = get_abstract_loader_options(
loader_options=attribute, # pyright: ignore[reportUnknownArgumentType]
default_options_have_wildcards=options_have_wildcards,
inherit_lazy_relationships=inherit_lazy_relationships,
merge_with_default=merge_with_default,
cycle_count=cycle_count + 1,
)
loads.extend(load_chain)
return (loads, options_have_wildcards)
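# Usage sketch (hypothetical ``Author``/``Book`` models): the returned options can be
# applied directly to a select. Collection relationships resolve to ``selectinload``,
# scalar ones to ``joinedload``, and ``"*"`` produces a wildcard ``joinedload("*")``.
#
#     options, has_wildcards = get_abstract_loader_options(
#         loader_options=[Author.books, (Author.books, Book.publisher)],
#     )
#     stmt = select(Author).options(*options)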
class FilterableRepositoryProtocol(Protocol[ModelT]):
"""Protocol defining the interface for filterable repositories.
This protocol defines the required attributes and methods that any
filterable repository implementation must provide.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
class FilterableRepository(FilterableRepositoryProtocol[ModelT]):
"""Default implementation of a filterable repository.
Provides core filtering, ordering and pagination functionality for
SQLAlchemy models.
"""
model_type: type[ModelT]
"""The SQLAlchemy model class this repository manages."""
prefer_any_dialects: Optional[tuple[str]] = ("postgresql",)
"""List of dialects that prefer to use ``field.id = ANY(:1)`` instead of ``field.id IN (...)``."""
order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
"""List or single :class:`~advanced_alchemy.repository.typing.OrderingPair` to use for sorting."""
_prefer_any: bool = False
"""Whether to prefer ANY() over IN() in queries."""
_dialect: Dialect
"""The SQLAlchemy :class:`sqlalchemy.dialects.Dialect` being used."""
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Select[tuple[ModelT]],
) -> Select[tuple[ModelT]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Delete,
) -> Delete: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]],
) -> Union[ReturningDelete[tuple[ModelT]], ReturningUpdate[tuple[ModelT]]]: ...
@overload
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: Update,
) -> Update: ...
def _apply_filters(
self,
*filters: Union[StatementFilter, ColumnElement[bool]],
apply_pagination: bool = True,
statement: StatementTypeT,
) -> StatementTypeT:
"""Apply filters to a SQL statement.
Args:
*filters: Filter conditions to apply.
apply_pagination: Whether to apply pagination filters.
statement: The base SQL statement to filter.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for filter_ in filters:
if isinstance(filter_, (PaginationFilter,)):
if apply_pagination:
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, (InAnyFilter,)):
statement = filter_.append_to_statement(statement, self.model_type)
elif isinstance(filter_, ColumnElement):
statement = cast("StatementTypeT", statement.where(filter_))
else:
statement = filter_.append_to_statement(statement, self.model_type)
return statement
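# Usage sketch (on a concrete repository instance ``repo``; ``LimitOffset`` is the
# pagination filter from ``advanced_alchemy.filters``, ``User`` is a hypothetical model):
# StatementFilter instances and plain column expressions can be mixed freely.
#
#     stmt = repo._apply_filters(
#         LimitOffset(limit=20, offset=0),
#         User.is_active.is_(True),
#         statement=select(User),
#     )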
def _filter_select_by_kwargs(
self,
statement: StatementTypeT,
kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
) -> StatementTypeT:
"""Filter a statement using keyword arguments.
Args:
statement: The SQL statement to filter (e.g. :class:`sqlalchemy.sql.Select`).
kwargs: Dictionary or iterable of ``(key, value)`` tuples containing filter criteria.
Keys are model attribute names; values are the values to filter on.
Returns:
StatementTypeT: The filtered SQL statement.
"""
for key, val in dict(kwargs).items():
field = get_instrumented_attr(self.model_type, key)
statement = cast("StatementTypeT", statement.where(field == val))
return statement
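# Usage sketch (hypothetical ``User`` model): each key/value pair becomes an
# equality WHERE clause combined with AND.
#
#     stmt = repo._filter_select_by_kwargs(select(User), {"name": "alice", "is_active": True})
#     # ... WHERE user.name = :name_1 AND user.is_active = :is_active_1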
def _apply_order_by(
self,
statement: StatementTypeT,
order_by: Union[
OrderingPair,
list[OrderingPair],
],
) -> StatementTypeT:
"""Apply ordering to a SQL statement.
Args:
statement: The SQL statement to order.
order_by: Ordering specification. Either a single tuple or a list of tuples where:
- First element is the field name or :class:`sqlalchemy.orm.InstrumentedAttribute` to order by
- Second element is a boolean indicating descending (True) or ascending (False)
A pre-built :class:`sqlalchemy.sql.expression.UnaryExpression` (e.g. ``Model.field.desc()``) is also accepted and applied as-is.
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(order_by, list):
order_by = [order_by]
for order_field in order_by:
if isinstance(order_field, UnaryExpression):
statement = statement.order_by(order_field) # type: ignore
else:
field = get_instrumented_attr(self.model_type, order_field[0])
statement = self._order_by_attribute(statement, field, order_field[1])
return statement
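# Usage sketch (hypothetical ``User`` model): each ordering pair is (field, descending).
#
#     stmt = repo._apply_order_by(select(User), [("created_at", True), (User.name, False)])
#     # ORDER BY user.created_at DESC, user.name ASC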
@staticmethod
def _order_by_attribute(
statement: StatementTypeT,
field: InstrumentedAttribute[Any],
is_desc: bool,
) -> StatementTypeT:
"""Apply ordering by a single attribute to a SQL statement.
Args:
statement: The SQL statement to order.
field: The model attribute to order by.
is_desc: Whether to order in descending (True) or ascending (False) order.
Returns:
StatementTypeT: The ordered SQL statement.
"""
if not isinstance(statement, Select):
return statement
return statement.order_by(field.desc() if is_desc else field.asc())
def column_has_defaults(column: Any) -> bool:
"""Check if a column has any type of default value or update handler.
This includes:
- Python-side defaults (column.default)
- Server-side defaults (column.server_default)
- Python-side onupdate handlers (column.onupdate)
- Server-side onupdate handlers (column.server_onupdate)
Args:
column: SQLAlchemy column object to check
Returns:
bool: True if the column has any type of default or update handler
"""
# Label objects (from column_property) don't have default/onupdate attributes
# Return False for these as they represent computed values, not defaulted columns
if isinstance(column, Label):
return False
# Use defensive attribute checking for safety with other column-like objects
return (
getattr(column, "default", None) is not None
or getattr(column, "server_default", None) is not None
or getattr(column, "onupdate", None) is not None
or getattr(column, "server_onupdate", None) is not None
)
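# Usage sketch (hypothetical ``User`` model): given a column such as
#
#     created_at = mapped_column(DateTime, default=datetime.utcnow)
#
# ``column_has_defaults(User.__table__.c.created_at)`` would return True, while a plain
# ``mapped_column(String)`` with no default or onupdate handlers would return False.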
def compare_values(existing_value: Any, new_value: Any) -> bool:
"""Safely compare two values, handling numpy arrays and other special types.
This function handles the comparison of values that may include numpy arrays
(such as pgvector's Vector type) which cannot be directly compared using
standard equality operators due to their element-wise comparison behavior.
Args:
existing_value: The current value to compare.
new_value: The new value to compare against.
Returns:
bool: True if values are equal, False otherwise.
"""
# Handle None comparisons
if existing_value is None and new_value is None:
return True
if existing_value is None or new_value is None:
return False
# Handle numpy arrays or array-like objects
if is_numpy_array(existing_value) or is_numpy_array(new_value):
# Both values must be arrays for them to be considered equal
if not (is_numpy_array(existing_value) and is_numpy_array(new_value)):
return False
return arrays_equal(existing_value, new_value)
# Standard equality comparison for all other types
try:
return bool(existing_value == new_value)
except (ValueError, TypeError):
# If comparison fails for any reason, consider them different
# This is a safe fallback that will trigger updates when unsure
return False
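# Usage sketch (assumes ``numpy`` is available as ``np``): scalar values fall back to
# ``==``, while numpy arrays (e.g. pgvector ``Vector`` values) are compared via
# ``arrays_equal``; mixed array/non-array pairs are always considered different.
#
#     compare_values(1, 1)                        # True
#     compare_values(None, "x")                   # False
#     compare_values(np.array([1, 2]), [1, 2])    # False -- only one side is an array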