Source code for advanced_alchemy.repository._sync

# Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/_async.py
import contextlib
import datetime
import decimal
import random
import string
from collections.abc import Iterable, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Literal,
    Optional,
    Protocol,
    Union,
    cast,
    runtime_checkable,
)

from sqlalchemy import (
    Delete,
    Result,
    Row,
    Select,
    TextClause,
    Update,
    any_,
    delete,
    inspect,
    over,
    select,
    text,
    update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.exc import MissingGreenlet, NoInspectionAvailable
from sqlalchemy.orm import InstrumentedAttribute, Session
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad  # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from sqlalchemy.sql.selectable import ForUpdateParameter

from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
    DEFAULT_ERROR_MESSAGE_TEMPLATES,
    FilterableRepository,
    FilterableRepositoryProtocol,
    LoadSpec,
    column_has_defaults,
    compare_values,
    get_abstract_loader_options,
    get_instrumented_attr,
    was_attribute_set,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.service.typing import schema_dump
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify

if TYPE_CHECKING:
    from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams  # pyright: ignore[reportPrivateUsage]

# Fallback chunk size returned by ``_get_insertmanyvalues_max_parameters``
# when the caller does not supply an explicit ``chunk_size``.
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
# NOTE(review): presumably the first PostgreSQL major version that supports
# MERGE; it is not referenced in the visible part of this module -- confirm.
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
# Exact Python types that ``_type_must_use_in_instead_of_any`` accepts for the
# ``ANY()`` operator; a value of any other type makes it choose ``IN`` instead.
DEFAULT_SAFE_TYPES: Final[set[type[Any]]] = {
    int,
    float,
    str,
    bool,
    bytes,
    decimal.Decimal,
    datetime.date,
    datetime.datetime,
    datetime.time,
    datetime.timedelta,
}


@runtime_checkable
class SQLAlchemySyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
    """Base Protocol

    Structural interface for synchronous SQLAlchemy repositories. Every method
    declared here is a stub (``...``); the behavior is supplied by implementers.
    """

    # Name of the model attribute used as the unique identifier.
    id_attribute: str
    # NOTE(review): presumably the default field name(s) used to match existing
    # rows in the upsert-style methods below -- confirm against implementers.
    match_fields: Optional[Union[list[str], str]] = None
    # Select statement the repository builds its queries from.
    statement: Select[tuple[ModelT]]
    # Session (or scoped session) managing the unit-of-work.
    session: Union[Session, scoped_session[Session]]
    # Remove objects from the session before returning them.
    auto_expunge: bool
    # Refresh objects from the session before returning them.
    auto_refresh: bool
    # Commit objects before returning them.
    auto_commit: bool
    # Default ordering applied to queries.
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
    # Custom error message templates used for operations.
    error_messages: Optional[ErrorMessages] = None
    # When ``False`` the original SQLAlchemy exception is raised unwrapped.
    wrap_exceptions: bool = True

    def __init__(
        self,
        *,
        statement: Optional[Select[tuple[ModelT]]] = None,
        session: Union[Session, scoped_session[Session]],
        auto_expunge: bool = False,
        auto_refresh: bool = True,
        auto_commit: bool = False,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        wrap_exceptions: bool = True,
        **kwargs: Any,
    ) -> None:
        """Configure the repository with a session and its default behaviors."""
        ...

    @classmethod
    def get_id_attribute_value(
        cls,
        item: Union[ModelT, type[ModelT]],
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    ) -> Any:
        """Return the value of the identifier attribute on ``item``."""
        ...

    @classmethod
    def set_id_attribute_value(
        cls,
        item_id: Any,
        item: ModelT,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    ) -> ModelT:
        """Set ``item_id`` on the identifier attribute of ``item`` and return ``item``."""
        ...

    @staticmethod
    def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
        """Return the item, or signal that it was not found when it is ``None``."""
        ...

    def add(
        self,
        data: ModelT,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    ) -> ModelT:
        """Add ``data`` to the collection and return the added instance."""
        ...

    def add_many(
        self,
        data: list[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    ) -> Sequence[ModelT]:
        """Add every instance in ``data`` to the collection and return them."""
        ...

    def delete(
        self,
        item_id: Any,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> ModelT:
        """Delete the instance identified by ``item_id`` and return it."""
        ...

    def delete_many(
        self,
        item_ids: list[Any],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        chunk_size: Optional[int] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> Sequence[ModelT]:
        """Delete the instances identified by ``item_ids`` and return them."""
        ...

    def delete_where(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        sanity_check: bool = True,
        **kwargs: Any,
    ) -> Sequence[ModelT]:
        """Delete the instances selected by ``filters``/``kwargs`` and return them."""
        ...

    def exists(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` when an object matching ``filters``/``kwargs`` exists."""
        ...

    def get(
        self,
        item_id: Any,
        *,
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> ModelT:
        """Get the instance identified by ``item_id``."""
        ...

    def get_one(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> ModelT:
        """Return a single instance selected by ``filters``/``kwargs``."""
        ...

    def get_one_or_none(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]:
        """Return a single instance selected by ``filters``/``kwargs``, or ``None``."""
        ...

    def get_or_upsert(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Optional[Union[list[str], str]] = None,
        upsert: bool = True,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]:
        """Return an instance together with a boolean flag.

        NOTE(review): the meaning of the flag is defined by implementers and is
        not visible in this stub -- confirm before relying on it.
        """
        ...

    def get_and_update(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Optional[Union[list[str], str]] = None,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]:
        """Return an instance together with a boolean flag.

        NOTE(review): the meaning of the flag is defined by implementers and is
        not visible in this stub -- confirm before relying on it.
        """
        ...

    def count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        statement: Optional[Select[tuple[ModelT]]] = None,
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> int:
        """Return the number of rows selected by ``filters``/``kwargs``."""
        ...

    def update(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> ModelT:
        """Update an instance from ``data`` and return a model instance."""
        ...

    def update_many(
        self,
        data: list[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> list[ModelT]:
        """Update several instances from ``data`` and return model instances."""
        ...

    def _get_update_many_statement(
        self,
        model_type: type[ModelT],
        supports_returning: bool,
        loader_options: Optional[list[_AbstractLoad]],
        execution_options: Optional[dict[str, Any]],
    ) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
        """Build the UPDATE statement (optionally RETURNING) for bulk updates."""
        ...

    def upsert(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        match_fields: Optional[Union[list[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> ModelT:
        """Upsert ``data`` and return a model instance (semantics defined by implementers)."""
        ...

    def upsert_many(
        self,
        data: list[ModelT],
        *,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        no_merge: bool = False,
        match_fields: Optional[Union[list[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
    ) -> list[ModelT]:
        """Upsert several instances and return model instances (semantics defined by implementers)."""
        ...

    def list_and_count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        count_with_window_function: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        **kwargs: Any,
    ) -> tuple[list[ModelT], int]:
        """Return the selected instances together with a count."""
        ...

    def list(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        **kwargs: Any,
    ) -> list[ModelT]:
        """Return the instances selected by ``filters``/``kwargs``."""
        ...

    @classmethod
    def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
        """Return a boolean health indicator for ``session`` (presumably database reachability -- confirm)."""
        ...
@runtime_checkable
class SQLAlchemySyncSlugRepositoryProtocol(SQLAlchemySyncRepositoryProtocol[ModelT], Protocol[ModelT]):
    """Protocol for SQLAlchemy repositories that support slug-based operations.

    Extends the base repository protocol to add slug-related functionality.
    Both methods are stubs; implementers provide the behavior.

    Type Parameters:
        ModelT: The SQLAlchemy model type this repository handles.
    """

    def get_by_slug(
        self,
        slug: str,
        *,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]:
        """Get a model instance by its slug.

        Args:
            slug: The slug value to search for.
            error_messages: Optional custom error message templates.
            load: Specification for eager loading of relationships.
            execution_options: Options for statement execution.
            **kwargs: Additional filtering criteria.

        Returns:
            ModelT | None: The found model instance or None if not found.
        """
        ...

    def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Generate a unique slug for a given value.

        Args:
            value_to_slugify: The string to convert to a slug.
            **kwargs: Additional parameters for slug generation.

        Returns:
            str: A unique slug derived from the input value.
        """
        ...
class SQLAlchemySyncRepository(SQLAlchemySyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
    """Sync SQLAlchemy repository implementation.

    Provides a complete implementation of synchronous database operations using
    SQLAlchemy, including CRUD operations, filtering, and relationship loading.

    Type Parameters:
        ModelT: The SQLAlchemy model type this repository handles.

    .. seealso::
        :class:`~advanced_alchemy.repository._util.FilterableRepository`
    """

    id_attribute: str = "id"
    """Name of the unique identifier for the model."""
    loader_options: Optional[LoadSpec] = None
    """Default loader options for the repository."""
    error_messages: Optional[ErrorMessages] = None
    """Default error messages for the repository."""
    wrap_exceptions: bool = True
    """Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
    inherit_lazy_relationships: bool = True
    """Optionally ignore the default ``lazy`` configuration for model relationships. This is useful for when you want to
    replace instead of merge the model's loaded relationships with the ones specified in the ``load`` or
    ``default_loader_options`` configuration."""
    merge_loader_options: bool = True
    """Merges the default loader options with the loader options specified in the ``load`` argument. This is useful for
    when you want to totally replace instead of merge the model's loaded relationships with the ones specified in the
    ``load`` or ``default_loader_options`` configuration."""
    execution_options: Optional[dict[str, Any]] = None
    """Default execution options for the repository."""
    match_fields: Optional[Union[list[str], str]] = None
    """Repository-level default for the ``match_fields`` argument accepted by the upsert-style operations.

    NOTE(review): how the default is consumed is decided by those operations -- confirm against their implementation."""
    uniquify: bool = False
    """Optionally apply the ``unique()`` method to results before returning.

    This is useful for certain SQLAlchemy use cases such as applying ``contains_eager`` to a query containing a
    one-to-many relationship
    """
    count_with_window_function: bool = True
    """Use an analytical window function to count results. This allows the count to be performed in a single query.
    """
def __init__(
    self,
    *,
    statement: Optional[Select[tuple[ModelT]]] = None,
    session: Union[Session, scoped_session[Session]],
    auto_expunge: bool = False,
    auto_refresh: bool = True,
    auto_commit: bool = False,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    wrap_exceptions: bool = True,
    uniquify: Optional[bool] = None,
    count_with_window_function: Optional[bool] = None,
    **kwargs: Any,
) -> None:
    """Set up a repository bound to ``session``.

    Arguments left as ``None`` fall back to the matching class-level attribute.

    Args:
        statement: Custom base select; defaults to ``select(self.model_type)``.
        session: Session managing the unit-of-work for the operation.
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        auto_commit: Commit objects before returning.
        order_by: Set default order options for queries.
        error_messages: A set of custom error messages to use for operations.
        load: Set default relationships to be loaded.
        execution_options: Set default execution options.
        wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``.
            When set to ``False``, the original exception will be raised.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        count_with_window_function: When false, list and count will use two
            queries instead of an analytical window function.
        **kwargs: Additional arguments (not used by this initializer).
    """
    # Plain per-instance flags.
    self.auto_expunge = auto_expunge
    self.auto_refresh = auto_refresh
    self.auto_commit = auto_commit
    self.order_by = order_by
    self.session = session

    # Instance messages are layered over the class-level ones.
    self.error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    self.wrap_exceptions = wrap_exceptions

    if uniquify is None:
        uniquify = self.uniquify
    self._uniquify = uniquify

    if count_with_window_function is None:
        count_with_window_function = self.count_with_window_function
    self.count_with_window_function = count_with_window_function

    effective_load = self.loader_options if load is None else load
    self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
        loader_options=effective_load,
        inherit_lazy_relationships=self.inherit_lazy_relationships,
        merge_with_default=self.merge_loader_options,
    )

    if execution_options is None:
        execution_options = self.execution_options
    self._default_execution_options = execution_options or {}

    if statement is None:
        self.statement = select(self.model_type)
    else:
        self.statement = statement

    # Prefer the session's direct bind; otherwise ask the session to resolve one.
    if self.session.bind is not None:
        self._dialect = self.session.bind.dialect
    else:
        self._dialect = self.session.get_bind().dialect

    candidate_dialects = self.prefer_any_dialects or ()
    self._prefer_any = any(self._dialect.name == engine_type for engine_type in candidate_dialects)
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
    """Get the uniquify value, preferring the method parameter over instance setting.

    Args:
        uniquify: Optional override for the uniquify setting.

    Returns:
        bool: The uniquify value to use.
    """
    return bool(uniquify) if uniquify is not None else self._uniquify

def _type_must_use_in_instead_of_any(self, matched_values: "list[Any]", field_type: "Any" = None) -> bool:
    """Determine if field.in_() should be used instead of any_() for compatibility.

    When ``field_type`` exposes a ``python_type``, any non-``None`` value that is
    not an instance of it forces ``IN``. A ``python_type`` lookup that raises
    ``AttributeError`` or ``NotImplementedError`` also forces ``IN``. Otherwise
    the decision falls back to checking the exact Python type of each value
    against ``DEFAULT_SAFE_TYPES``.

    Args:
        matched_values: Values to be used in the filter
        field_type: Optional SQLAlchemy TypeEngine from the column

    Returns:
        bool: True if field.in_() should be used instead of any_()
    """
    if not matched_values:
        return False

    if field_type is not None:
        try:
            expected_python_type = getattr(field_type, "python_type", None)
            if expected_python_type is not None:
                for value in matched_values:
                    if value is not None and not isinstance(value, expected_python_type):
                        return True
        except (AttributeError, NotImplementedError):
            # The type cannot report a Python type; take the conservative path.
            return True

    # Exact type match on purpose: subclasses of safe types are not assumed safe.
    return any(value is not None and type(value) not in DEFAULT_SAFE_TYPES for value in matched_values)

def _get_unique_values(self, values: "list[Any]") -> "list[Any]":
    """Get unique values from a list, handling unhashable types safely.

    Args:
        values: List of values to deduplicate

    Returns:
        list[Any]: List of unique values preserving order
    """
    if not values:
        return []

    try:
        # Fast path for hashable types
        seen: set[Any] = set()
        unique_values: list[Any] = []
        for value in values:
            if value not in seen:
                unique_values.append(value)
                seen.add(value)
    except TypeError:
        # Fallback for unhashable types (e.g., dicts from JSONB)
        unique_values = []
        for value in values:
            if value not in unique_values:
                unique_values.append(value)
    return unique_values

@staticmethod
def _get_error_messages(
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
    """Merge error message templates into a fresh mapping.

    Layers, lowest priority first: ``DEFAULT_ERROR_MESSAGE_TEMPLATES``, then
    ``default_messages``, then ``error_messages``.

    Args:
        error_messages: Per-call overrides; ``Empty`` or ``None`` means none.
        default_messages: Repository-level overrides; ``Empty`` or ``None`` means none.

    Returns:
        A new mapping holding the merged templates.
    """
    if error_messages is Empty:
        error_messages = None
    if default_messages is Empty:
        default_messages = None
    # Work on a copy: updating the module-level templates in place would leak
    # one repository's (or one call's) overrides into every other repository.
    messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
    if default_messages and isinstance(default_messages, dict):
        messages.update(default_messages)
    if error_messages:
        messages.update(cast("ErrorMessages", error_messages))
    return messages
@classmethod
def get_id_attribute_value(
    cls,
    item: Union[ModelT, type[ModelT]],
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
    """Read the identifier attribute from ``item``.

    Args:
        item: Object (or model class) carrying the identifier attribute.
        id_attribute: Attribute name or instrumented attribute to read.
            When ``None``, the class-level :attr:`id_attribute` is used.

    Returns:
        Whatever ``getattr`` yields for the resolved attribute name on ``item``.
    """
    attr_name = id_attribute.key if isinstance(id_attribute, InstrumentedAttribute) else id_attribute
    if attr_name is None:
        attr_name = cls.id_attribute
    return getattr(item, attr_name)
@classmethod
def set_id_attribute_value(
    cls,
    item_id: Any,
    item: ModelT,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
    """Write ``item_id`` to the identifier attribute of ``item``.

    Args:
        item_id: Value of ID to be set on instance.
        item: Object carrying the identifier attribute.
        id_attribute: Attribute name or instrumented attribute to write.
            When ``None``, the class-level :attr:`id_attribute` is used.

    Returns:
        The same ``item``, after the attribute has been set.
    """
    attr_name = id_attribute.key if isinstance(id_attribute, InstrumentedAttribute) else id_attribute
    if attr_name is None:
        attr_name = cls.id_attribute
    setattr(item, attr_name, item_id)
    return item
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
    """Pass ``item_or_none`` through unless it is ``None``.

    Args:
        item_or_none: Item (:class:`T <T>`) to be tested for existence.

    Raises:
        NotFoundError: If ``item_or_none`` is ``None``

    Returns:
        The item, if it exists.
    """
    if item_or_none is not None:
        return item_or_none
    msg = "No item found when one was expected"
    raise NotFoundError(msg)
def _get_execution_options(
    self,
    execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Return ``execution_options``, or the defaults stored at initialization when it is ``None``."""
    return self._default_execution_options if execution_options is None else execution_options

def _get_loader_options(
    self,
    loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
    """Resolve loader options for a call.

    Args:
        loader_options: Per-call load specification, or ``None`` to use the
            defaults computed at initialization.

    Returns:
        A pair of (loader options, flag). For the default path the flag is
        ``self._loader_options_have_wildcards or self._uniquify``; otherwise
        the pair comes from ``get_abstract_loader_options``.
    """
    wildcard_or_unique = self._loader_options_have_wildcards or self._uniquify
    if loader_options is not None:
        return get_abstract_loader_options(
            loader_options=loader_options,
            default_loader_options=self._default_loader_options,
            default_options_have_wildcards=wildcard_or_unique,
            inherit_lazy_relationships=self.inherit_lazy_relationships,
            merge_with_default=self.merge_loader_options,
        )
    # use the defaults set at initialization
    return self._default_loader_options, wildcard_or_unique
def add(
    self,
    data: ModelT,
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> ModelT:
    """Attach ``data`` to the session and persist it.

    The instance is attached, flushed or committed, refreshed and expunged, in
    that order; each step receives the corresponding ``auto_*`` argument.

    Args:
        data: Instance to be added to the collection.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients

    Returns:
        The added instance.
    """
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        attached = self._attach_to_session(data)
        self._flush_or_commit(auto_commit=auto_commit)
        self._refresh(attached, auto_refresh=auto_refresh)
        self._expunge(attached, auto_expunge=auto_expunge)
        return attached
def add_many(
    self,
    data: list[ModelT],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Sequence[ModelT]:
    """Add every instance in ``data`` to the session in one call.

    Args:
        data: list of Instances to be added to the collection.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients

    Returns:
        The added instances (the same ``data`` list that was passed in).
    """
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        self.session.add_all(data)
        self._flush_or_commit(auto_commit=auto_commit)
        for added in data:
            self._expunge(added, auto_expunge=auto_expunge)
        return data
def delete(
    self,
    item_id: Any,
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> ModelT:
    """Fetch the instance identified by ``item_id`` and delete it.

    Note that the resolved ``uniquify`` value is stored on the repository
    instance (``self._uniquify``), so it stays in effect after this call.

    Args:
        item_id: Identifier of instance to be deleted.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        id_attribute: Allows customization of the unique identifier to use for model fetching.
            Defaults to `id`, but can reference any surrogate or candidate key for the table.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The deleted instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        resolved_options = self._get_execution_options(execution_options)
        # The row is loaded first so the deleted instance can be returned.
        doomed = self.get(
            item_id,
            id_attribute=id_attribute,
            load=load,
            execution_options=resolved_options,
        )
        self.session.delete(doomed)
        self._flush_or_commit(auto_commit=auto_commit)
        self._expunge(doomed, auto_expunge=auto_expunge)
        return doomed
def delete_many(
    self,
    item_ids: list[Any],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    chunk_size: Optional[int] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> Sequence[ModelT]:
    """Delete the instances identified by ``item_ids``, chunk by chunk.

    When the dialect reports ``delete_executemany_returning``, each chunk is
    removed with a single DELETE ... RETURNING. Otherwise the rows are
    selected first and then deleted with a second statement. The resolved
    ``uniquify`` value is stored on the repository instance.

    Args:
        item_ids: Identifiers of the instances to be deleted.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        id_attribute: Allows customization of the unique identifier to use for model fetching.
            Defaults to `id`, but can reference any surrogate or candidate key for the table.
        chunk_size: Allows customization of the ``insertmanyvalues_max_parameters`` setting for the driver.
            Defaults to `950` if left unset. Ignored when the dialect prefers
            ``ANY()``, in which case all ids go into one chunk.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The deleted instances.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        resolved_options = self._get_execution_options(execution_options)
        loader_options, _unused_wildcard_flag = self._get_loader_options(load)
        id_column = get_instrumented_attr(
            self.model_type,
            self.id_attribute if id_attribute is None else id_attribute,
        )

        if self._prefer_any:
            # One chunk large enough to hold every id.
            chunk_size = len(item_ids) + 1
        batch_size = self._get_insertmanyvalues_max_parameters(chunk_size)
        returning_supported = self._dialect.delete_executemany_returning

        def build_statement(kind: Any, ids: Any) -> Any:
            # Shared arguments for the select/delete statements of one chunk.
            return self._get_delete_many_statement(
                statement_type=kind,
                model_type=self.model_type,
                id_attribute=id_column,
                id_chunk=ids,
                supports_returning=returning_supported,
                loader_options=loader_options,
                execution_options=resolved_options,
            )

        removed: list[ModelT] = []
        for start in range(0, len(item_ids), batch_size):
            id_batch = item_ids[start : start + batch_size]
            if returning_supported:
                removed.extend(self.session.scalars(build_statement("delete", id_batch)))
                continue
            # No RETURNING support: read the rows, then delete them.
            removed.extend(self.session.scalars(build_statement("select", id_batch)))
            self.session.execute(build_statement("delete", id_batch))

        self._flush_or_commit(auto_commit=auto_commit)
        for removed_instance in removed:
            self._expunge(removed_instance, auto_expunge=auto_expunge)
        return removed
@staticmethod
def _get_insertmanyvalues_max_parameters(chunk_size: Optional[int] = None) -> int:
    """Return ``chunk_size``, or ``DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS`` when it is ``None``."""
    if chunk_size is None:
        return DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
    return chunk_size
def delete_where(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    sanity_check: bool = True,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> Sequence[ModelT]:
    """Delete the instances selected by ``filters`` and ``kwargs``.

    With ``delete_executemany_returning`` the rows come back from a single
    DELETE ... RETURNING. Otherwise they are listed first and deleted
    afterwards, and the reported row count can be compared against the number
    of listed instances. The resolved ``uniquify`` value is stored on the
    repository instance.

    Args:
        *filters: Types for specific filtering operations.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        sanity_check: When true, the length of selected instances is compared to the deleted row count
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Arguments to apply to a delete

    Raises:
        RepositoryError: If the number of deleted rows does not match the number of selected instances

    Returns:
        The deleted instances.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        resolved_options = self._get_execution_options(execution_options)
        loader_options, _unused_wildcard_flag = self._get_loader_options(load)
        model_type = self.model_type

        delete_stmt = self._get_base_stmt(
            statement=delete(model_type),
            loader_options=loader_options,
            execution_options=resolved_options,
        )
        delete_stmt = self._filter_select_by_kwargs(statement=delete_stmt, kwargs=kwargs)
        delete_stmt = self._apply_filters(*filters, statement=delete_stmt, apply_pagination=False)

        if self._dialect.delete_executemany_returning:
            removed: list[ModelT] = [*self.session.scalars(delete_stmt.returning(model_type))]
        else:
            removed = [
                *self.list(
                    *filters,
                    load=load,
                    execution_options=resolved_options,
                    auto_expunge=auto_expunge,
                    **kwargs,
                )
            ]
            outcome = self.session.execute(delete_stmt)
            affected_rows = getattr(outcome, "rowcount", -2)
            # backends will return a -1 if they can't determine impacted rowcount
            # only compare length of selected instances to results if it's >= 0
            if sanity_check and affected_rows >= 0 and len(removed) != affected_rows:  # pyright: ignore
                self.session.rollback()
                raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")

        self._flush_or_commit(auto_commit=auto_commit)
        for removed_instance in removed:
            self._expunge(removed_instance, auto_expunge=auto_expunge)
        return removed
def exists(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> bool:
    """Tell whether at least one object matches ``filters`` and ``kwargs``.

    Delegates to :meth:`count` and checks that the result is positive.

    Args:
        *filters: Types for specific filtering operations.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Accepted for signature parity; this method does not use it.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        True if the instance was found. False if not found.
    """
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    return (
        self.count(
            *filters,
            load=load,
            execution_options=execution_options,
            error_messages=resolved_messages,
            **kwargs,
        )
        > 0
    )
@staticmethod
def _get_base_stmt(
    *,
    statement: StatementTypeT,
    loader_options: Optional[list[_AbstractLoad]],
    execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
    """Return ``statement`` with loader and execution options applied.

    Args:
        statement: The statement to decorate.
        loader_options: Loader options passed to ``statement.options``; skipped when empty or ``None``.
        execution_options: Keyword options passed to ``statement.execution_options``;
            skipped when empty or ``None``.

    Returns:
        The statement with whichever of the two option sets was provided applied to it.
    """
    if loader_options:
        statement = cast("StatementTypeT", statement.options(*loader_options))
    if execution_options:
        statement = cast("StatementTypeT", statement.execution_options(**execution_options))
    return statement

def _get_delete_many_statement(
    self,
    *,
    model_type: type[ModelT],
    id_attribute: InstrumentedAttribute[Any],
    id_chunk: list[Any],
    supports_returning: bool,
    statement_type: Literal["delete", "select"] = "delete",
    loader_options: Optional[list[_AbstractLoad]],
    execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
    """Build the DELETE (or matching SELECT) statement for one chunk of identifiers.

    Args:
        model_type: The mapped model the statement targets.
        id_attribute: The attribute compared against ``id_chunk``.
        id_chunk: The identifier values to match.
        supports_returning: When true and a delete is being built, a ``RETURNING`` clause
            for ``model_type`` is added.
        statement_type: ``"delete"`` builds a DELETE; any other value builds a SELECT.
        loader_options: Loader options applied through ``_get_base_stmt``.
        execution_options: Execution options applied through ``_get_base_stmt``.

    Returns:
        The statement restricted to rows whose ``id_attribute`` is in ``id_chunk``.
    """
    # ``_get_base_stmt`` already applies ``execution_options``; applying the same
    # options a second time here would be a no-op, so it is not repeated.
    statement = self._get_base_stmt(
        statement=delete(model_type) if statement_type == "delete" else select(model_type),
        loader_options=loader_options,
        execution_options=execution_options,
    )
    if supports_returning and statement_type != "select":
        statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type))  # type: ignore[union-attr,assignment]  # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
    # Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
    use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(id_chunk, id_attribute.type)
    if use_in:
        return statement.where(id_attribute.in_(id_chunk))  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    return statement.where(any_(id_chunk) == id_attribute)  # type: ignore[arg-type]
def get(
    self,
    item_id: Any,
    *,
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> ModelT:
    """Fetch the single instance whose identifier equals ``item_id``.

    Args:
        item_id: Identifier value to look up.
        auto_expunge: Passed to ``_expunge`` to decide whether the instance is removed
            from the session before it is returned.
        statement: Select statement to start from; ``self.statement`` is used when omitted.
        id_attribute: Attribute matched against ``item_id``; ``self.id_attribute`` is used
            when omitted.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification.
        execution_options: Statement execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The instance returned by ``check_not_found`` for the query result.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        exec_opts = self._get_execution_options(execution_options)
        base_select = statement if statement is not None else self.statement
        loaders, has_wildcard = self._get_loader_options(load)
        lookup_attribute = self.id_attribute if id_attribute is None else id_attribute
        query = self._get_base_stmt(
            statement=base_select,
            loader_options=loaders,
            execution_options=exec_opts,
        )
        query = self._filter_select_by_kwargs(query, [(lookup_attribute, item_id)])
        result = self._execute(query, uniquify=has_wildcard)
        # check_not_found is the not-found guard; presumably it raises when the
        # query produced no row - confirm against its definition.
        found = self.check_not_found(result.scalar_one_or_none())
        self._expunge(found, auto_expunge=auto_expunge)
        return found
def get_one(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> ModelT:
    """Fetch the single instance matching ``filters`` and ``kwargs``.

    Args:
        *filters: Statement filters or boolean column expressions; applied without pagination.
        auto_expunge: Passed to ``_expunge`` to decide whether the instance is removed
            from the session before it is returned.
        statement: Select statement to start from; ``self.statement`` is used when omitted.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification.
        execution_options: Statement execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Attribute/value pairs the instance must match.

    Returns:
        The instance returned by ``check_not_found`` for the query result.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        exec_opts = self._get_execution_options(execution_options)
        base_select = statement if statement is not None else self.statement
        loaders, has_wildcard = self._get_loader_options(load)
        query = self._get_base_stmt(
            statement=base_select,
            loader_options=loaders,
            execution_options=exec_opts,
        )
        # Positional filters first, then keyword equality filters.
        query = self._apply_filters(*filters, apply_pagination=False, statement=query)
        query = self._filter_select_by_kwargs(query, kwargs)
        result = self._execute(query, uniquify=has_wildcard)
        found = self.check_not_found(result.scalar_one_or_none())
        self._expunge(found, auto_expunge=auto_expunge)
        return found
def get_one_or_none(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> Union[ModelT, None]:
    """Get the instance identified by ``kwargs``, or ``None`` if not found.

    Args:
        *filters: Types for specific filtering operations; applied without pagination.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
            ``self.statement`` is used when omitted.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients.
        load: Set relationships to be loaded.
        execution_options: Set default execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        The retrieved instance, or ``None`` when the query matched no row.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        execution_options = self._get_execution_options(execution_options)
        statement = self.statement if statement is None else statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        instance = cast(
            "Result[tuple[ModelT]]",
            (self._execute(statement, uniquify=loader_options_have_wildcard)),
        ).scalar_one_or_none()
        # Compare against None explicitly: a model that defines ``__bool__`` or
        # ``__len__`` may be falsy and must still be expunged when requested.
        if instance is not None:
            self._expunge(instance, auto_expunge=auto_expunge)
        return instance
def get_or_upsert(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    match_fields: Optional[Union[list[str], str]] = None,
    upsert: bool = True,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Union[bool, None] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> tuple[ModelT, bool]:
    """Get the instance identified by ``kwargs``, or create it if it doesn't exist.

    Args:
        *filters: Types for specific filtering operations.
        match_fields: A list of keys to use to match the existing model. When
            empty, all fields are matched. Keys whose value in ``kwargs`` is
            ``None`` are left out of the match.
        upsert: When using match_fields and actual model values differ from
            ``kwargs``, automatically perform an update operation on the model.
        attribute_names: An iterable of attribute names to pass into the ``update`` method.
        with_for_update: Indicating FOR UPDATE should be used, or may be a
            dictionary containing flags to indicate a more specific set of
            FOR UPDATE flags for the SELECT.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients.
        load: Set relationships to be loaded.
        execution_options: Set default execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        A tuple of the instance and whether it needed to be created.
        When using match_fields and actual model values differ from ``kwargs``,
        the model value will be updated.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if match_fields := self._get_match_fields(match_fields=match_fields):
            match_filter = {
                field_name: field_value
                for field_name in match_fields
                if (field_value := kwargs.get(field_name)) is not None
            }
        else:
            match_filter = kwargs
        existing = self.get_one_or_none(
            *filters,
            **match_filter,
            load=load,
            execution_options=execution_options,
        )
        # Compare against None explicitly: a model that defines ``__bool__`` or
        # ``__len__`` may be falsy, and must not be mistaken for "no row found"
        # (which would insert a duplicate).
        if existing is None:
            return (
                self.add(
                    self.model_type(**kwargs),
                    auto_commit=auto_commit,
                    auto_expunge=auto_expunge,
                    auto_refresh=auto_refresh,
                ),
                True,
            )
        if upsert:
            for field_name, new_field_value in kwargs.items():
                field = getattr(existing, field_name, MISSING)
                if field is not MISSING and not compare_values(field, new_field_value):  # pragma: no cover
                    setattr(existing, field_name, new_field_value)
            existing = self._attach_to_session(existing, strategy="merge")
            self._flush_or_commit(auto_commit=auto_commit)
            self._refresh(
                existing,
                attribute_names=attribute_names,
                with_for_update=with_for_update,
                auto_refresh=auto_refresh,
            )
        self._expunge(existing, auto_expunge=auto_expunge)
    return existing, False
def get_and_update(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    match_fields: Optional[Union[list[str], str]] = None,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> tuple[ModelT, bool]:
    """Fetch the instance matching ``kwargs`` and apply any differing values to it.

    Args:
        *filters: Statement filters or boolean column expressions used for the lookup.
        match_fields: Keys of ``kwargs`` used to find the existing model. When this
            resolves to nothing, all of ``kwargs`` is used. Keys whose value is ``None``
            are left out of the lookup.
        attribute_names: Attribute names forwarded to ``_refresh``.
        with_for_update: FOR UPDATE setting forwarded to ``_refresh``.
        auto_commit: Forwarded to ``_flush_or_commit``.
        auto_expunge: Forwarded to ``_expunge``.
        auto_refresh: Forwarded to ``_refresh``.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification for the lookup.
        execution_options: Statement execution options for the lookup.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Attribute/value pairs identifying the instance and carrying the new values.

    Returns:
        A tuple of the instance and a flag that is ``True`` when at least one attribute
        value was changed.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        lookup_fields = self._get_match_fields(match_fields=match_fields)
        if lookup_fields:
            criteria = {name: kwargs[name] for name in lookup_fields if kwargs.get(name) is not None}
        else:
            criteria = kwargs
        existing = self.get_one(*filters, **criteria, load=load, execution_options=execution_options)

        changed = False
        for attr_name, desired_value in kwargs.items():
            current_value = getattr(existing, attr_name, MISSING)
            # Nothing to do for unknown attributes or values that already match.
            if current_value is MISSING or compare_values(current_value, desired_value):
                continue
            changed = True  # pragma: no cover
            setattr(existing, attr_name, desired_value)  # pragma: no cover

        existing = self._attach_to_session(existing, strategy="merge")
        self._flush_or_commit(auto_commit=auto_commit)
        self._refresh(
            existing,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(existing, auto_expunge=auto_expunge)
    return existing, changed
def count(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> int:
    """Count the records matched by a query.

    Args:
        *filters: Statement filters or boolean column expressions; applied without pagination.
        statement: Select statement to start from; ``self.statement`` is used when omitted.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification.
        execution_options: Statement execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Attribute/value filters.

    Returns:
        The number of matching records, ignoring pagination.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        exec_opts = self._get_execution_options(execution_options)
        base_select = statement if statement is not None else self.statement
        loaders, has_wildcard = self._get_loader_options(load)
        query = self._get_base_stmt(
            statement=base_select,
            loader_options=loaders,
            execution_options=exec_opts,
        )
        query = self._apply_filters(*filters, apply_pagination=False, statement=query)
        query = self._filter_select_by_kwargs(query, kwargs)
        # Wrap the filtered select in the count statement, then run it.
        count_query = self._get_count_stmt(
            statement=query,
            loader_options=loaders,
            execution_options=exec_opts,
        )
        outcome = self._execute(statement=count_query, uniquify=has_wildcard)
        return cast("int", outcome.scalar_one())
def update(
    self,
    data: ModelT,
    *,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> ModelT:
    """Update instance with the attribute values present on `data`.

    The stored instance is fetched with :meth:`get` using the identifier read
    from ``data``. Column values and relationship values found on ``data`` are
    then copied onto that instance, which is merged into the session, flushed
    or committed, optionally refreshed, and optionally expunged.

    Args:
        data: An instance that should have a value for `self.id_attribute` that
            exists in the collection.
        attribute_names: an iterable of attribute names to pass into the ``update``
            method.
        with_for_update: indicating FOR UPDATE should be used, or may be a
            dictionary containing flags to indicate a more specific set of
            FOR UPDATE flags for the SELECT
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        auto_commit: Commit objects before returning.
        id_attribute: Allows customization of the unique identifier to use for model fetching.
            Defaults to `id`, but can reference any surrogate or candidate key for the table.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The updated instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        item_id = self.get_id_attribute_value(
            data,
            id_attribute=id_attribute,
        )
        existing_instance = self.get(
            item_id, id_attribute=id_attribute, load=load, execution_options=execution_options
        )
        mapper = None
        # NOTE(review): the nesting below was reconstructed from a flattened listing;
        # confirm against the generator source (_async.py) that the copy loops sit
        # inside the no_autoflush/suppress block.
        # MissingGreenlet and NoInspectionAvailable raised inside the block are suppressed,
        # which leaves any attributes not yet copied untouched.
        with (
            self.session.no_autoflush,
            contextlib.suppress(MissingGreenlet, NoInspectionAvailable),
        ):
            mapper = inspect(data)
            if mapper is not None:
                for column in mapper.mapper.columns:
                    field_name = column.key
                    new_field_value = getattr(data, field_name, MISSING)
                    if new_field_value is not MISSING:
                        # Skip setting columns with defaults/onupdate to None during updates
                        # This prevents overwriting columns that should use their defaults
                        if new_field_value is None and column_has_defaults(column):
                            continue
                        # Only copy attributes that were explicitly set on the input instance
                        # This prevents overwriting existing values with uninitialized None values
                        if not was_attribute_set(data, mapper, field_name):
                            continue
                        existing_field_value = getattr(existing_instance, field_name, MISSING)
                        # Write only when the stored instance has the attribute and the value differs
                        if existing_field_value is not MISSING and not compare_values(
                            existing_field_value, new_field_value
                        ):
                            setattr(existing_instance, field_name, new_field_value)
                # Handle relationships by merging objects into session first
                for relationship in mapper.mapper.relationships:
                    if relationship.viewonly or relationship.lazy in {  # pragma: no cover
                        "write_only",
                        "dynamic",
                        "raise",
                        "raise_on_sql",
                    }:
                        # Skip relationships with incompatible lazy loading strategies
                        continue
                    if (new_value := getattr(data, relationship.key, MISSING)) is not MISSING:
                        # Lists are merged item by item, a single related object is merged
                        # directly, and an explicit None is assigned as-is.
                        if isinstance(new_value, list):
                            merged_values = [  # pyright: ignore
                                self.session.merge(item, load=False)  # pyright: ignore
                                for item in new_value  # pyright: ignore
                            ]
                            setattr(existing_instance, relationship.key, merged_values)
                        elif new_value is not None:
                            merged_value = self.session.merge(new_value, load=False)
                            setattr(existing_instance, relationship.key, merged_value)
                        else:
                            setattr(existing_instance, relationship.key, new_value)
        instance = self._attach_to_session(existing_instance, strategy="merge")
        self._flush_or_commit(auto_commit=auto_commit)
        self._refresh(
            instance,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(instance, auto_expunge=auto_expunge)
        return instance
def update_many(
    self,
    data: list[ModelT],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> list[ModelT]:
    """Bulk-update the records described by ``data``.

    Each item is first converted to a dictionary payload. Two execution paths follow:

    - When the dialect reports ``update_executemany_returning`` and is not Oracle, one
      UPDATE ... RETURNING statement is executed and its rows are returned.
    - Otherwise the UPDATE is executed, then the affected rows are re-read with
      :meth:`list` using the identifiers taken from the payloads.

    Args:
        data: Instances to update. Each payload must contain ``self.id_attribute`` on
            the non-RETURNING path.
        auto_commit: Forwarded to ``_flush_or_commit``.
        auto_expunge: Forwarded to ``_expunge`` for every returned instance.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification.
        execution_options: Statement execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The updated instances.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )

    # Build one dict payload per input item.
    has_updated_at = hasattr(self.model_type, "updated_at")
    payloads: list[dict[str, Any]] = []
    for item in data:
        dumps_itself = isinstance(item, self.model_type) or (hasattr(item, "to_dict") and callable(item.to_dict))
        payload = item.to_dict() if dumps_itself else cast("dict[str, Any]", schema_dump(item))
        # Stamp a timezone-aware "now" when the model has updated_at and none was supplied.
        if has_updated_at and payload.get("updated_at") is None:
            payload["updated_at"] = datetime.datetime.now(datetime.timezone.utc)
        payloads.append(payload)

    with wrap_sqlalchemy_exception(
        error_messages=resolved_messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        exec_opts = self._get_execution_options(execution_options)
        loaders, _ = self._get_loader_options(load)
        returning_supported = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
        bulk_update = self._get_update_many_statement(
            self.model_type,
            returning_supported,
            loader_options=loaders,
            execution_options=exec_opts,
        )

        if not returning_supported:
            self.session.execute(bulk_update, payloads, execution_options=exec_opts)
            self._flush_or_commit(auto_commit=auto_commit)
            # For non-RETURNING backends, fetch updated instances from database
            id_values: list[Any] = [payload[self.id_attribute] for payload in payloads]
            refetched = self.list(
                getattr(self.model_type, self.id_attribute).in_(id_values),
                load=loaders,
                execution_options=exec_opts,
            )
            for row in refetched:
                self._expunge(row, auto_expunge=auto_expunge)
            return refetched

        # The cast is not correct but is currently the only way to deal with an
        # SQLAlchemy typing issue.
        # See https://github.com/sqlalchemy/sqlalchemy/discussions/9925
        returned_rows = self.session.scalars(
            bulk_update,
            cast("_CoreSingleExecuteParams", payloads),
            execution_options=exec_opts,
        )
        updated = list(returned_rows)
        self._flush_or_commit(auto_commit=auto_commit)
        for row in updated:
            self._expunge(row, auto_expunge=auto_expunge)
        return updated
def _get_update_many_statement(
    self,
    model_type: type[ModelT],
    supports_returning: bool,
    loader_options: Union[list[_AbstractLoad], None],
    execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
    """Build the UPDATE statement used by bulk updates.

    Args:
        model_type: The mapped model the UPDATE targets.
        supports_returning: When truthy, a ``RETURNING`` clause for ``model_type`` is added.
        loader_options: Loader options applied through ``_get_base_stmt``.
        execution_options: Execution options applied through ``_get_base_stmt``.

    Returns:
        The UPDATE statement, with or without ``RETURNING``.
    """
    base_update = self._get_base_stmt(
        statement=update(table=model_type),
        loader_options=loader_options,
        execution_options=execution_options,
    )
    return base_update.returning(model_type) if supports_returning else base_update
def list_and_count(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    statement: Optional[Select[tuple[ModelT]]] = None,
    auto_expunge: Optional[bool] = None,
    count_with_window_function: Optional[bool] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records together with their total count.

    The work is handed to ``_list_and_count_basic`` when the dialect name is
    ``"spanner"`` or ``"spanner+spanner"``, or when window-function counting is
    disabled; otherwise it is handed to ``_list_and_count_window``.

    Args:
        *filters: Statement filters or boolean column expressions.
        statement: Select statement to start from.
        auto_expunge: Remove objects from session before returning.
        count_with_window_function: Whether to count with a window function;
            ``self.count_with_window_function`` is used when omitted.
        order_by: Ordering options for the query.
        error_messages: Optional templates for friendlier error messages, resolved against
            ``self.error_messages``.
        load: Relationship loading specification.
        execution_options: Statement execution options.
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Attribute/value filters.

    Returns:
        A tuple of the listed instances and the total count.
    """
    use_window = (
        self.count_with_window_function if count_with_window_function is None else count_with_window_function
    )
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    # Both strategies take the same keyword arguments, so pick one and call it once.
    if self._dialect.name in {"spanner", "spanner+spanner"} or not use_window:
        strategy = self._list_and_count_basic
    else:
        strategy = self._list_and_count_window
    return strategy(
        *filters,
        auto_expunge=auto_expunge,
        statement=statement,
        load=load,
        execution_options=execution_options,
        order_by=order_by,
        error_messages=resolved_messages,
        **kwargs,
    )
def _expunge(self, instance: "ModelT", auto_expunge: "Optional[bool]") -> None:
    """Detach ``instance`` from the session when auto-expunge is in effect.

    Args:
        instance: The model instance to expunge
        auto_expunge: Whether to expunge the instance. If None, uses self.auto_expunge

    Note:
        Instances whose inspected state reports ``deleted`` or ``detached``
        are left alone: deleted objects are detached by SQLAlchemy on commit
        and objects that are already detached (e.g. from DELETE...RETURNING)
        cannot be expunged again without an InvalidRequestError.
    """
    should_expunge = self.auto_expunge if auto_expunge is None else auto_expunge
    if not should_expunge:
        return

    # Look at the object's state first; expunging is only valid for
    # instances that are still attached and not pending deletion.
    instance_state = inspect(instance)
    if instance_state is not None and (instance_state.deleted or instance_state.detached):
        return

    self.session.expunge(instance)


def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
    """Commit the session when auto-commit is in effect, otherwise flush it.

    Args:
        auto_commit: Whether to commit. If None, uses ``self.auto_commit``.
    """
    if auto_commit is None:
        auto_commit = self.auto_commit
    if auto_commit:
        return self.session.commit()
    return self.session.flush()


def _refresh(
    self,
    instance: ModelT,
    auto_refresh: Optional[bool],
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
) -> None:
    """Refresh ``instance`` from the session when auto-refresh is in effect.

    Args:
        instance: The model instance to refresh.
        auto_refresh: Whether to refresh. If None, uses ``self.auto_refresh``.
        attribute_names: Forwarded to ``session.refresh``.
        with_for_update: Forwarded to ``session.refresh``.
    """
    if auto_refresh is None:
        auto_refresh = self.auto_refresh
    if not auto_refresh:
        return None
    return self.session.refresh(
        instance=instance,
        attribute_names=attribute_names,
        with_for_update=with_for_update,
    )


def _list_and_count_window(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records with total count, using a single windowed query.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        order_by: Ordering to apply; ``None`` falls back to ``self.order_by``.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        **kwargs: Instance attribute value filters.

    Returns:
        Count of records returned by query using an analytical window function, ignoring pagination.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        execution_options = self._get_execution_options(execution_options)
        if statement is None:
            statement = self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        statement = self._apply_order_by(statement=statement, order_by=order_by)
        statement = self._apply_filters(*filters, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        # Every row carries the same windowed total as an extra column.
        windowed = statement.add_columns(over(sql_func.count()))
        result = self._execute(windowed, uniquify=loader_options_have_wildcard)
        total: int = 0
        records: list[ModelT] = []
        for position, row in enumerate(result):
            record, row_total = row
            self._expunge(record, auto_expunge=auto_expunge)
            records.append(record)
            if position == 0:
                total = row_total
        return records, total


def _list_and_count_basic(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records with total count, using separate count and row queries.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        order_by: Set default order options for queries.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        **kwargs: Instance attribute value filters.

    Returns:
        Count of records returned by query using 2 queries, ignoring pagination.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        execution_options = self._get_execution_options(execution_options)
        if statement is None:
            statement = self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        statement = self._apply_order_by(statement=statement, order_by=order_by)
        statement = self._apply_filters(*filters, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        count_statement = self._get_count_stmt(
            statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        count = self.session.execute(count_statement).scalar_one()
        # Nothing matched: skip the second round trip entirely.
        if count == 0:
            return [], 0
        result = self._execute(statement, uniquify=loader_options_have_wildcard)
        records: list[ModelT] = []
        for row in result:
            (record,) = row
            self._expunge(record, auto_expunge=auto_expunge)
            records.append(record)
        return records, count


@staticmethod
def _get_count_stmt(
    statement: Select[tuple[ModelT]],
    loader_options: Optional[list[_AbstractLoad]],  # noqa: ARG004
    execution_options: Optional[dict[str, Any]],  # noqa: ARG004
) -> Select[tuple[int]]:
    """Derive a ``COUNT(1)`` statement from ``statement``.

    The selected columns are replaced by the count while the FROM clause is
    kept; LIMIT, OFFSET and ORDER BY are cleared. ``loader_options`` and
    ``execution_options`` are accepted but not used.
    """
    # Count statement transformations are static
    counted = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
    unpaginated = counted.limit(None).offset(None)
    return unpaginated.order_by(None)
def upsert(
    self,
    data: ModelT,
    *,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_expunge: Optional[bool] = None,
    auto_commit: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    match_fields: Optional[Union[list[str], str]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> ModelT:
    """Modify or create instance.

    Updates instance with the attribute values present on `data`, or creates a new instance if
    one doesn't exist.

    The existing row is looked up with, in order of preference: the non-``None``
    values of ``match_fields``; the value of ``self.id_attribute`` when it is set
    on ``data``; otherwise every field of ``data.to_dict()`` except the identifier.

    Args:
        data: Instance to update existing, or be created. Identifier used to determine if an
            existing instance exists is the value of an attribute on `data` named as value of
            `self.id_attribute`.
        attribute_names: an iterable of attribute names forwarded to the
            refresh performed after the update.
        with_for_update: indicating FOR UPDATE should be used, or may be a
            dictionary containing flags to indicate a more specific set of
            FOR UPDATE flags for the SELECT
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        auto_commit: Commit objects before returning.
        match_fields: a list of keys to use to match the existing model.  When
            empty, all fields are matched.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The updated or created instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    if match_fields := self._get_match_fields(match_fields=match_fields):
        match_filter = {
            field_name: field_value
            for field_name in match_fields
            if (field_value := getattr(data, field_name, None)) is not None
        }
    elif getattr(data, self.id_attribute, None) is not None:
        match_filter = {self.id_attribute: getattr(data, self.id_attribute, None)}
    else:
        match_filter = data.to_dict(exclude={self.id_attribute})
    # Forward ``uniquify`` and ``error_messages`` so the lookup - the only
    # SELECT issued here - honours the caller's settings instead of the defaults.
    existing = self.get_one_or_none(
        load=load,
        execution_options=execution_options,
        error_messages=error_messages,
        uniquify=uniquify,
        **match_filter,
    )
    if not existing:
        return self.add(
            data,
            auto_commit=auto_commit,
            auto_expunge=auto_expunge,
            auto_refresh=auto_refresh,
        )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        # Copy over only the values that differ, and only for attributes the
        # existing instance actually has.
        for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
            field = getattr(existing, field_name, MISSING)
            if field is not MISSING and not compare_values(field, new_field_value):  # pragma: no cover
                setattr(existing, field_name, new_field_value)
        instance = self._attach_to_session(existing, strategy="merge")
        self._flush_or_commit(auto_commit=auto_commit)
        self._refresh(
            instance,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(instance, auto_expunge=auto_expunge)
        return instance
def upsert_many(
    self,
    data: list[ModelT],
    *,
    auto_expunge: Optional[bool] = None,
    auto_commit: Optional[bool] = None,
    no_merge: bool = False,
    match_fields: Optional[Union[list[str], str]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
) -> list[ModelT]:
    """Modify or create multiple instances.

    Update instances with the attribute values present on `data`, or create a new instance if
    one doesn't exist.

    !!! tip
        In most cases, you will want to set `match_fields` to the combination of attributes, excluded the primary key, that define uniqueness for a row.

    Args:
        data: Instance to update existing, or be created. Identifier used to determine if an
            existing instance exists is the value of an attribute on ``data`` named as value of
            :attr:`id_attribute`.
        auto_expunge: Remove object from session before returning.
        auto_commit: Commit objects before returning.
        no_merge: Skip the usage of optimized Merge statements. Not read by
            this implementation.
        match_fields: a list of keys to use to match the existing model.  When
            empty, automatically uses ``self.id_attribute`` (`id` by default) to match .
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.

    Returns:
        The updated or created instances: inserted ones first, then updated ones.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    instances: list[ModelT] = []
    data_to_update: list[ModelT] = []
    data_to_insert: list[ModelT] = []
    match_fields = self._get_match_fields(match_fields=match_fields)
    if match_fields is None:
        match_fields = [self.id_attribute]
    match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
    if match_fields:
        for field_name in match_fields:
            field = get_instrumented_attr(self.model_type, field_name)
            matched_values = [
                field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
            ]
            # Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
            use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
            match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field)  # type: ignore[arg-type]

    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        # ``list`` re-derives its own uniquify/error-message settings from its
        # arguments, so the caller's choices have to be forwarded explicitly.
        existing_objs = self.list(
            *match_filter,
            load=load,
            execution_options=execution_options,
            auto_expunge=False,
            error_messages=error_messages,
            uniquify=uniquify,
        )
        # Give incoming rows the identifier of the stored row they match, then
        # split them into updates (known identifier) and inserts.
        existing_ids = self._get_object_ids(existing_objs=existing_objs)
        data = self._merge_on_match_fields(data, existing_objs, match_fields)
        for datum in data:
            if getattr(datum, self.id_attribute, None) in existing_ids:
                data_to_update.append(datum)
            else:
                data_to_insert.append(datum)
        if data_to_insert:
            instances.extend(
                self.add_many(data_to_insert, auto_commit=False, auto_expunge=False),
            )
        if data_to_update:
            instances.extend(
                self.update_many(
                    data_to_update,
                    auto_commit=False,
                    auto_expunge=False,
                    load=load,
                    execution_options=execution_options,
                ),
            )
        # Commit/flush once for the whole batch, then detach if requested.
        self._flush_or_commit(auto_commit=auto_commit)
        for instance in instances:
            self._expunge(instance, auto_expunge=auto_expunge)
    return instances
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
    """Collect the non-``None`` ``self.id_attribute`` values of ``existing_objs``.

    Args:
        existing_objs: Instances to read the identifier from.

    Returns:
        The identifiers, in input order, with ``None`` values dropped.
    """
    id_attribute = self.id_attribute
    return [obj_id for datum in existing_objs if (obj_id := getattr(datum, id_attribute)) is not None]


def _get_match_fields(
    self,
    match_fields: Optional[Union[list[str], str]] = None,
    id_attribute: Optional[str] = None,
) -> Optional[list[str]]:
    """Resolve the match fields to use, normalised to a list.

    Args:
        match_fields: Explicit field name(s). A falsy value (``None``, ``""``,
            ``[]``) falls back to ``self.match_fields``.
        id_attribute: Accepted for backwards compatibility; it does not
            influence the result.

    Returns:
        A list of field names, or whatever falsy value ``self.match_fields``
        holds when neither source provides any.
    """
    resolved = match_fields or self.match_fields
    if isinstance(resolved, str):
        return [resolved]
    return resolved


def _merge_on_match_fields(
    self,
    data: list[ModelT],
    existing_data: list[ModelT],
    match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
    """Copy identifiers from ``existing_data`` onto the matching items of ``data``.

    An item of ``data`` matches an existing item when every match field
    compares equal. Matching items are mutated in place: their
    ``self.id_attribute`` is set to the existing item's identifier.

    Args:
        data: Incoming instances; mutated in place.
        existing_data: Already stored instances to take identifiers from.
        match_fields: Field name(s) to compare on. Falls back to
            ``self.match_fields`` and then to ``[self.id_attribute]``.

    Returns:
        The same ``data`` list that was passed in.
    """
    match_fields = self._get_match_fields(match_fields=match_fields)
    if match_fields is None:
        match_fields = [self.id_attribute]
    id_attribute = self.id_attribute
    for existing_datum in existing_data:
        existing_id = getattr(existing_datum, id_attribute)
        # An existing row without an identifier has nothing to hand over.
        if existing_id is None:
            continue
        for datum in data:
            if all(getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields):
                setattr(datum, id_attribute, existing_id)
    return data
def list(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    **kwargs: Any,
) -> list[ModelT]:
    """Fetch the instances matching the given filters.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
            ``None`` falls back to ``self.statement``.
        order_by: Set default order options for queries. ``None`` falls back
            to ``self.order_by``.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        **kwargs: Instance attribute value filters.

    Returns:
        The list of instances, after filtering applied.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        execution_options = self._get_execution_options(execution_options)
        if statement is None:
            statement = self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        statement = self._apply_order_by(statement=statement, order_by=order_by)
        statement = self._apply_filters(*filters, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        result = self._execute(statement, uniquify=loader_options_have_wildcard)
        # Materialise the whole result before detaching anything from the session.
        records = [*result.scalars()]
        for record in records:
            self._expunge(record, auto_expunge=auto_expunge)
        return cast("list[ModelT]", records)
@classmethod
def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
    """Run the health-check statement and verify it yields ``1``.

    Args:
        session: through which we run a check statement

    Returns:
        ``True`` if healthy.
    """
    with wrap_sqlalchemy_exception():
        probe = cls._get_health_check_statement(session)
        outcome = session.execute(probe)
        return outcome.scalar_one() == 1  # type: ignore[no-any-return]
@staticmethod
def _get_health_check_statement(session: Union[Session, scoped_session[Session]]) -> TextClause:
    """Build the health-check statement for the session's dialect.

    Oracle binds get ``SELECT 1 FROM DUAL``; everything else, including a
    session without a bind, gets ``SELECT 1``.
    """
    targets_oracle = session.bind and session.bind.dialect.name == "oracle"
    return text("SELECT 1 FROM DUAL") if targets_oracle else text("SELECT 1")


def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True) -> ModelT:
    """Attach detached instance to the session.

    Args:
        model: The instance to be attached to the session.
        strategy: How the instance should be attached.
            - "add": New instance added to session
            - "merge": Instance merged with existing, or new one added.
        load: Only used by the "merge" strategy, where it is forwarded to
            ``session.merge``. When False, merge switches into a "high
            performance" mode which causes it to forego emitting history
            events as well as all database access.

    Raises:
        ValueError: If `strategy` is not one of the expected values.

    Returns:
        Instance attached to the session - if `"merge"` strategy, may not be same instance
        that was provided.
    """
    if strategy == "merge":
        return self.session.merge(model, load=load)
    if strategy != "add":
        msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`"  # type: ignore[unreachable]
        raise ValueError(msg)
    self.session.add(model)
    return model


def _execute(
    self,
    statement: Select[Any],
    uniquify: bool = False,
) -> Result[Any]:
    """Execute ``statement`` on the session.

    ``unique()`` is applied to the result when either ``uniquify`` or
    ``self._uniquify`` is truthy.
    """
    outcome = self.session.execute(statement)
    return outcome.unique() if (uniquify or self._uniquify) else outcome
class SQLAlchemySyncSlugRepository(
    SQLAlchemySyncRepository[ModelT],
    SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
    """Repository extension adding slug-based lookup and slug generation."""

    def get_by_slug(
        self,
        slug: str,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]:
        """Select record by slug value.

        Args:
            slug: Slug to look up.
            error_messages: An optional dictionary of templates to use
                for friendlier error messages to clients
            load: Set relationships to be loaded
            execution_options: Set default execution options
            uniquify: Optionally apply the ``unique()`` method to results before returning.
            **kwargs: Accepted but not forwarded to the lookup.

        Returns:
            The model instance or None if not found.
        """
        lookup_options = {
            "load": load,
            "execution_options": execution_options,
            "error_messages": error_messages,
            "uniquify": uniquify,
        }
        return self.get_one_or_none(slug=slug, **lookup_options)

    def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Get a unique slug for the supplied value.

        If the slugified value is already taken, a hyphen followed by four
        random lowercase-alphanumeric characters is appended. The suffixed
        slug is returned without a further uniqueness check. Override this
        method to change the default behavior.

        Args:
            value_to_slugify (str): A string that should be converted to a unique slug.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            str: a unique slug for the supplied value.  This is safe for URLs and other unique identifiers.
        """
        candidate = slugify(value_to_slugify)
        if not self._is_slug_unique(candidate):
            alphabet = string.ascii_lowercase + string.digits
            suffix = "".join(random.choices(alphabet, k=4))  # noqa: S311
            candidate = f"{candidate}-{suffix}"
        return candidate

    def _is_slug_unique(
        self,
        slug: str,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` when ``self.exists`` reports no record for ``slug``.

        Extra ``kwargs`` are forwarded to ``self.exists`` as additional filters.
        """
        slug_taken = self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs)
        return slug_taken is False
class SQLAlchemySyncQueryRepository:
    """SQLAlchemy Query Repository.

    This is a loosely typed helper to query for when you need to select data in ways that don't align to the normal repository pattern.
    """

    # Error message templates handed to ``wrap_sqlalchemy_exception``.
    error_messages: Optional[ErrorMessages] = None
    # Forwarded to ``wrap_sqlalchemy_exception`` as ``wrap_exceptions``.
    wrap_exceptions: bool = True

    def __init__(
        self,
        *,
        session: Union[Session, scoped_session[Session]],
        error_messages: Optional[ErrorMessages] = None,
        wrap_exceptions: bool = True,
        **kwargs: Any,
    ) -> None:
        """Repository pattern for SQLAlchemy models.

        Args:
            session: Session managing the unit-of-work for the operation.
            error_messages: A set of error messages to use for operations.
            wrap_exceptions: Whether to wrap exceptions in a SQLAlchemy exception.
            **kwargs: Additional arguments (ignored).

        """
        self.session = session
        self.error_messages = error_messages
        self.wrap_exceptions = wrap_exceptions
        # Prefer the session's direct bind; otherwise ask the session to resolve one.
        bind = self.session.bind if self.session.bind is not None else self.session.get_bind()
        self._dialect = bind.dialect

    def get_one(
        self,
        statement: Select[tuple[Any]],
        **kwargs: Any,
    ) -> Row[Any]:
        """Get instance identified by ``kwargs``.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Raises:
            NotFoundError: If the query yields no row (or a ``None`` scalar).

        Returns:
            The retrieved instance. The value comes from
            ``scalar_one_or_none()``, i.e. it is the first column of the row.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            instance = self.execute(statement).scalar_one_or_none()
            return self.check_not_found(instance)

    def get_one_or_none(
        self,
        statement: Select[Any],
        **kwargs: Any,
    ) -> Optional[Row[Any]]:
        """Get instance identified by ``kwargs`` or None if not found.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Returns:
            The retrieved instance or None. Falsy scalars such as ``0`` or
            ``""`` are returned as-is rather than being collapsed to ``None``.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            # ``scalar_one_or_none`` already yields None for "no row"; an
            # ``or None`` here would also swallow legitimate falsy values.
            return self.execute(statement).scalar_one_or_none()

    def count(self, statement: Select[Any], **kwargs: Any) -> int:
        """Get the count of records returned by a query.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Returns:
            Count of records returned by query, ignoring pagination.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            # LIMIT/OFFSET are cleared so pagination on the incoming statement
            # cannot truncate or skip the single aggregate row.
            statement = (
                statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
                .limit(None)
                .offset(None)
                .order_by(None)
            )
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            results = self.execute(statement)
            return results.scalar_one()  # type: ignore

    def list_and_count(
        self,
        statement: Select[Any],
        count_with_window_function: Optional[bool] = None,
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count.

        Args:
            statement: To facilitate customization of the underlying select query.
            count_with_window_function: Force list and count to use two queries instead of an analytical window function.
            **kwargs: Instance attribute value filters.

        Returns:
            A ``(rows, count)`` tuple; ``count`` ignores pagination.
        """
        # NOTE(review): a truthy flag selects the TWO-query path here, the
        # opposite of what the name suggests and of how the model repository's
        # ``list_and_count`` treats its flag. Kept as-is because the docstring
        # documents it - confirm the intended semantics before changing.
        if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function:
            return self._list_and_count_basic(statement=statement, **kwargs)
        return self._list_and_count_window(statement=statement, **kwargs)

    def _list_and_count_window(
        self,
        statement: Select[Any],
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count.

        Each result row is unpacked as ``(instance, count)``, so ``statement``
        is expected to select exactly one column/entity.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Returns:
            Count of records returned by query using an analytical window function, ignoring pagination.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = statement.add_columns(over(sql_func.count(text("1"))))
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = self.execute(statement)
            count: int = 0
            instances: list[Row[Any]] = []
            for i, (instance, count_value) in enumerate(result):
                instances.append(instance)
                # The windowed total is identical on every row; read it once.
                if i == 0:
                    count = count_value
            return instances, count

    @staticmethod
    def _get_count_stmt(statement: Select[Any]) -> Select[Any]:
        """Derive a ``COUNT(1)`` statement with LIMIT, OFFSET and ORDER BY cleared."""
        return (
            statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
            .limit(None)
            .offset(None)
            .order_by(None)
        )  # pyright: ignore[reportUnknownVariable]

    def _list_and_count_basic(
        self,
        statement: Select[Any],
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count.

        Each result row is unpacked as a 1-tuple, so ``statement`` is expected
        to select exactly one column/entity.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Returns:
            Count of records returned by query using 2 queries, ignoring pagination.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            # Route through ``self.execute`` like every other query in this class.
            count = self.execute(self._get_count_stmt(statement)).scalar_one()
            result = self.execute(statement)
            instances: list[Row[Any]] = [instance for (instance,) in result]
            return instances, count

    def list(self, statement: Select[Any], **kwargs: Any) -> list[Row[Any]]:
        """Get a list of instances, optionally filtered.

        Args:
            statement: To facilitate customization of the underlying select query.
            **kwargs: Instance attribute value filters.

        Returns:
            The list of instances, after filtering applied.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = self.execute(statement)
            return list(result.all())

    def _filter_statement_by_kwargs(
        self,
        statement: Select[Any],
        /,
        **kwargs: Any,
    ) -> Select[Any]:
        """Filter the collection by kwargs.

        Args:
            statement: statement to filter
            **kwargs: key/value pairs such that objects remaining in the statement after filtering
                have the property that their attribute named `key` has value equal to `value`.

        Returns:
            The filtered statement.
        """
        # Honour the instance's ``wrap_exceptions`` setting, as the public methods do.
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            return statement.filter_by(**kwargs)

    # the following is all sqlalchemy implementation detail, and shouldn't be directly accessed

    @staticmethod
    def check_not_found(item_or_none: Optional[T]) -> T:
        """Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``

        Returns:
            The item, if it exists. Falsy values other than ``None`` pass through.
        """
        if item_or_none is None:
            msg = "No item found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    def execute(
        self,
        statement: Union[
            ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
        ],
    ) -> Result[Any]:
        """Execute ``statement`` on the repository's session and return the result."""
        return self.session.execute(statement)