# ruff: noqa: TC004
"""Application ORM configuration."""
from __future__ import annotations
import contextlib
import re
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, cast, runtime_checkable
from uuid import UUID
from sqlalchemy import Date, MetaData, String
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import (
DeclarativeBase,
Mapper,
declared_attr,
)
from sqlalchemy.orm import (
registry as SQLAlchemyRegistry, # noqa: N812
)
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
from typing_extensions import Self, TypeVar
from advanced_alchemy.mixins import (
AuditColumns as _AuditColumns,
)
from advanced_alchemy.mixins import (
BigIntPrimaryKey as _BigIntPrimaryKey,
)
from advanced_alchemy.mixins import (
NanoIDPrimaryKey as _NanoIDPrimaryKey,
)
from advanced_alchemy.mixins import (
UUIDPrimaryKey as _UUIDPrimaryKey,
)
from advanced_alchemy.mixins import (
UUIDv6PrimaryKey as _UUIDv6PrimaryKey,
)
from advanced_alchemy.mixins import (
UUIDv7PrimaryKey as _UUIDv7PrimaryKey,
)
from advanced_alchemy.types import GUID, DateTimeUTC, JsonB
from advanced_alchemy.utils.dataclass import DataclassProtocol
from advanced_alchemy.utils.deprecation import warn_deprecation
if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
from sqlalchemy.sql.schema import (
_NamingSchemaParameter as NamingSchemaParameter, # pyright: ignore[reportPrivateUsage]
)
from sqlalchemy.types import TypeEngine
# Deprecated re-exports, imported here for type checkers only. At runtime they are resolved lazily (with a deprecation warning) by the module-level ``__getattr__``.
from advanced_alchemy.mixins import (
AuditColumns,
BigIntPrimaryKey,
NanoIDPrimaryKey,
SlugKey,
UUIDPrimaryKey,
UUIDv6PrimaryKey,
UUIDv7PrimaryKey,
)
# Public API of this module. ``AuditColumns``, ``SlugKey`` and the ``*PrimaryKey``
# mixins are deprecated re-exports: they are resolved lazily (with a deprecation
# warning) by the module-level ``__getattr__``.
__all__ = (
    "AdvancedDeclarativeBase",
    "AuditColumns",
    "BasicAttributes",
    "BigIntAuditBase",
    "BigIntBase",
    "BigIntBaseT",
    "BigIntPrimaryKey",
    "CommonTableAttributes",
    "DefaultBase",
    "ModelProtocol",
    "NanoIDAuditBase",
    "NanoIDBase",
    "NanoIDBaseT",
    "NanoIDPrimaryKey",
    "SQLQuery",
    "SlugKey",
    "TableArgsType",
    "UUIDAuditBase",
    "UUIDBase",
    "UUIDBaseT",
    "UUIDPrimaryKey",
    "UUIDv6AuditBase",
    "UUIDv6Base",
    "UUIDv6BaseT",
    "UUIDv6PrimaryKey",
    "UUIDv7AuditBase",
    "UUIDv7Base",
    "UUIDv7BaseT",
    "UUIDv7PrimaryKey",
    "convention",
    "create_registry",
    "merge_table_arguments",
    "metadata_registry",
    "orm_registry",
    "table_name_regexp",
)
def __getattr__(attr_name: str) -> object:
    """Resolve deprecated re-exports lazily (module-level ``__getattr__``, PEP 562).

    Python only calls this hook when ordinary attribute lookup on the module fails.
    Mixins that used to be importable from this module are fetched from
    :mod:`advanced_alchemy.mixins`. The result is stored in the module globals, so
    later lookups no longer reach this hook and the deprecation warning is emitted
    once per name.

    Args:
        attr_name: The attribute name that was not found on the module.

    Returns:
        The object exported by :mod:`advanced_alchemy.mixins` under ``attr_name``.

    Raises:
        AttributeError: If ``attr_name`` is not one of the deprecated re-exports.
    """
    _deprecated_attrs = {
        "SlugKey",
        "AuditColumns",
        "NanoIDPrimaryKey",
        "BigIntPrimaryKey",
        "UUIDPrimaryKey",
        "UUIDv6PrimaryKey",
        "UUIDv7PrimaryKey",
    }
    if attr_name in _deprecated_attrs:
        from advanced_alchemy import mixins

        module = "advanced_alchemy.mixins"
        # Cache in globals so the module dict satisfies subsequent lookups directly.
        value = globals()[attr_name] = getattr(mixins, attr_name)
        warn_deprecation(
            deprecated_name=f"advanced_alchemy.base.{attr_name}",
            version="0.26.0",
            kind="import",
            removal_in="1.0.0",
            info=f"importing {attr_name} from 'advanced_alchemy.base' is deprecated, please import it from '{module}' instead",
        )
        return value
    # Every non-deprecated name in ``__all__`` is bound at module level in this file,
    # so reaching this point means the name does not exist (yet) - e.g. it is being
    # imported while this module is still initialising. The former fallback read
    # ``locals()[attr_name]``, but ``locals()`` only holds this function's own
    # variables, so it raised ``KeyError`` instead of the ``AttributeError`` that
    # ``getattr``/``hasattr``/``from ... import`` rely on.
    msg = f"module {__name__!r} has no attribute {attr_name!r}"
    raise AttributeError(msg)
UUIDBaseT = TypeVar("UUIDBaseT", bound="UUIDBase")
"""Generic placeholder bound to :class:`UUIDBase` and its subclasses."""
BigIntBaseT = TypeVar("BigIntBaseT", bound="BigIntBase")
"""Generic placeholder bound to :class:`BigIntBase` and its subclasses."""
UUIDv6BaseT = TypeVar("UUIDv6BaseT", bound="UUIDv6Base")
"""Generic placeholder bound to :class:`UUIDv6Base` and its subclasses."""
UUIDv7BaseT = TypeVar("UUIDv7BaseT", bound="UUIDv7Base")
"""Generic placeholder bound to :class:`UUIDv7Base` and its subclasses."""
NanoIDBaseT = TypeVar("NanoIDBaseT", bound="NanoIDBase")
"""Generic placeholder bound to :class:`NanoIDBase` and its subclasses."""

convention: NamingSchemaParameter = {
    "ix": "ix_%(column_0_label)s",  # indexes
    "uq": "uq_%(table_name)s_%(column_0_name)s",  # unique constraints
    "ck": "ck_%(table_name)s_%(constraint_name)s",  # check constraints
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",  # foreign keys
    "pk": "pk_%(table_name)s",  # primary keys
}
"""Naming-convention templates passed to :class:`sqlalchemy.MetaData` so generated constraints get deterministic names."""

table_name_regexp = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
"""Matches the capital letters that begin a new word in a CamelCase name (e.g. the ``A`` in ``UserAccount``); used to derive snake_case table names."""
def merge_table_arguments(cls: type[DeclarativeBase], table_args: TableArgsType | None = None) -> TableArgsType:
    """Combine ``__table_args__`` inherited through mixins with extra table arguments.

    Mixins that declare their own ``__table_args__`` make it awkward to add
    model-specific options such as a table comment. This helper gathers the
    ``__table_args__`` resolved via ``super(base, cls)`` for every direct base of
    ``cls``, then ``table_args``, and folds them into a single value. Tuple entries
    contribute their positional items, and a trailing ``dict`` inside a tuple
    contributes keyword options. A bare ``dict`` entry contributes keyword options
    too. Later entries win on duplicate keys; empty or ``None`` entries are ignored.

    Args:
        cls: :class:`sqlalchemy.orm.DeclarativeBase` The model that will receive the merged table args.
        table_args: :class:`TableArgsType` Additional arguments, merged after the mixin ones.

    Returns:
        :class:`TableArgsType`: a tuple of positional arguments, ending with the
        keyword ``dict`` when there are keyword options; or only the keyword ``dict``
        when there are no positional arguments.
    """
    # NOTE(review): ``super(base, cls)`` starts the lookup *after* ``base`` in
    # ``cls.__mro__``, so each base's own ``__table_args__`` is reached via the base
    # that precedes it in the MRO - confirm this is the intended resolution order.
    candidates: list[Any] = [
        getattr(super(base, cls), "__table_args__", None)  # pyright: ignore[reportUnknownParameter,reportUnknownArgumentType,reportArgumentType]
        for base in cls.__bases__
    ]
    candidates.append(table_args)

    positional: list[Any] = []
    keyword: dict[str, Any] = {}
    for candidate in candidates:
        if not candidate:
            continue
        if not isinstance(candidate, tuple):
            keyword.update(candidate)
            continue
        *leading, last = candidate  # pyright: ignore[reportUnknownVariableType]
        positional.extend(leading)  # pyright: ignore[reportUnknownArgumentType]
        if isinstance(last, dict):
            keyword.update(last)  # pyright: ignore[reportUnknownArgumentType]
        else:
            positional.append(last)

    if not positional:
        return keyword
    if keyword:
        return (*positional, keyword)
    return tuple(positional)
@runtime_checkable
class ModelProtocol(Protocol):
    """Structural type describing a mapped SQLAlchemy model.

    The protocol is ``runtime_checkable``, so ``isinstance`` accepts any object
    that exposes a ``to_dict`` attribute. The table/mapper attributes below are
    declared for static type checkers only.
    """

    if TYPE_CHECKING:
        __name__: str
        __table__: FromClause
        __mapper__: Mapper[Any]

    def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
        """Return the model's column values as a dictionary.

        Args:
            exclude: Optional set of attribute names to leave out of the result.

        Returns:
            Dict[str, Any]: A dict representation of the model
        """
        ...
class BasicAttributes:
    """Helpers shared by mapped SQLAlchemy tables and query models."""

    if TYPE_CHECKING:
        __name__: str
        __table__: FromClause
        __mapper__: Mapper[Any]

    def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
        """Serialise the instance's mapped column attributes into a dictionary.

        The following are skipped:

        - attributes reported as unloaded by the instance state;
        - the ``sa_orm_sentinel`` / ``_sentinel`` names;
        - anything listed in ``exclude``.

        Args:
            exclude: Optional set of attribute names to omit.

        Returns:
            Dict[str, Any]: A dict representation of the model
        """
        skipped = {"sa_orm_sentinel", "_sentinel"}.union(
            self._sa_instance_state.unloaded,  # type: ignore[attr-defined]
            exclude or [],
        )
        result: dict[str, Any] = {}
        # ``.keys()`` is deliberate (hence the noqa): iterating the mapper's column
        # collection directly does not yield attribute names.
        for name in self.__mapper__.columns.keys():  # noqa: SIM118
            if name not in skipped:
                result[name] = getattr(self, name)
        return result
class CommonTableAttributes(BasicAttributes):
    """Table conventions shared by the declarative bases in this module.

    On top of :class:`BasicAttributes`, this provides an automatically derived
    ``__tablename__``.

    .. seealso::
        :class:`BasicAttributes`
    """

    if TYPE_CHECKING:
        __tablename__: str
    else:

        @declared_attr.directive
        def __tablename__(cls) -> str:
            """Derive the table name from the class name in snake_case (``UserAccount`` -> ``user_account``)."""
            underscored = table_name_regexp.sub(r"_\1", cls.__name__)
            return underscored.lower()
def create_registry(
    custom_annotation_map: dict[Any, type[TypeEngine[Any]] | TypeEngine[Any]] | None = None,
) -> SQLAlchemyRegistry:
    """Build a SQLAlchemy registry preconfigured with this package's type mappings.

    The registry's metadata uses :data:`convention` for constraint naming. Default
    annotation mappings:

    - UUIDs -> :class:`GUID`;
    - ``datetime`` -> :class:`DateTimeUTC`;
    - ``date`` -> :class:`sqlalchemy.Date`;
    - ``dict`` and ``DataclassProtocol`` -> :class:`JsonB`.

    When ``pydantic`` is importable, its URL/e-mail/IP types map to
    :class:`sqlalchemy.String` and ``Json`` to :class:`JsonB`. When ``msgspec`` is
    importable, ``Struct`` maps to :class:`JsonB`. Entries from
    ``custom_annotation_map`` are applied last, so they override the defaults.

    Args:
        custom_annotation_map: Extra annotation-to-column-type entries for the registry.

    Returns:
        :class:`sqlalchemy.orm.registry`
    """
    import uuid as core_uuid

    annotation_map: dict[Any, type[TypeEngine[Any]] | TypeEngine[Any]] = {
        # NOTE(review): ``UUID`` comes from the stdlib ``uuid`` module at the top of
        # this file, so these two keys currently refer to the same class.
        UUID: GUID,
        core_uuid.UUID: GUID,
        datetime: DateTimeUTC,
        date: Date,
        dict: JsonB,
        DataclassProtocol: JsonB,
    }

    try:
        from pydantic import AnyHttpUrl, AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, Json
    except ImportError:
        pass  # pydantic is optional
    else:
        annotation_map.update(
            {
                EmailStr: String,
                AnyUrl: String,
                AnyHttpUrl: String,
                Json: JsonB,
                IPvAnyAddress: String,
                IPvAnyInterface: String,
                IPvAnyNetwork: String,
            }
        )

    try:
        from msgspec import Struct
    except ImportError:
        pass  # msgspec is optional
    else:
        annotation_map[Struct] = JsonB

    if custom_annotation_map is not None:
        annotation_map.update(custom_annotation_map)

    metadata = MetaData(naming_convention=convention)
    return SQLAlchemyRegistry(metadata=metadata, type_annotation_map=annotation_map)
# Default registry; ``AdvancedDeclarativeBase`` uses it as its ``registry``.
orm_registry = create_registry()


class MetadataRegistry:
    """Shared mapping of bind keys to :class:`sqlalchemy.MetaData` objects.

    The class is a singleton: every instantiation returns the same object, and the
    mapping itself is a class attribute. The ``None`` key is pre-populated with the
    metadata of ``orm_registry``. Any other bind key gets a fresh
    :class:`~sqlalchemy.MetaData` using :data:`convention` the first time it is
    requested via :meth:`get`.
    """

    _instance: MetadataRegistry | None = None
    _registry: dict[str | None, MetaData] = {None: orm_registry.metadata}

    def __new__(cls) -> Self:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cast(Self, cls._instance)

    def get(self, bind_key: str | None = None) -> MetaData:
        """Get the metadata for the given bind key, creating and storing it if absent.

        Args:
            bind_key: The bind key to look up; ``None`` selects the default metadata.

        Returns:
            The :class:`~sqlalchemy.MetaData` registered for ``bind_key``.
        """
        # Only construct a MetaData when the key is missing. The previous
        # ``setdefault(bind_key, MetaData(...))`` built (and discarded) one on
        # every lookup, even for keys that were already registered.
        if bind_key not in self._registry:
            self._registry[bind_key] = MetaData(naming_convention=convention)
        return self._registry[bind_key]

    def set(self, bind_key: str | None, metadata: MetaData) -> None:
        """Set the metadata for the given bind key, replacing any existing entry."""
        self._registry[bind_key] = metadata

    def __iter__(self) -> Iterator[str | None]:
        """Iterate over the registered bind keys."""
        return iter(self._registry)

    def __getitem__(self, bind_key: str | None) -> MetaData:
        """Return the metadata for ``bind_key``; raises ``KeyError`` if it is not registered."""
        return self._registry[bind_key]

    def __setitem__(self, bind_key: str | None, metadata: MetaData) -> None:
        """Register ``metadata`` under ``bind_key`` (same as :meth:`set`)."""
        self._registry[bind_key] = metadata

    def __contains__(self, bind_key: str | None) -> bool:
        """Return whether ``bind_key`` has registered metadata."""
        return bind_key in self._registry


metadata_registry = MetadataRegistry()
class AdvancedDeclarativeBase(DeclarativeBase):
    """Declarative base whose registry and per-bind metadata can be overridden.

    Subclasses may set ``__bind_key__``. Such models take their ``metadata`` from
    ``__metadata_registry__.get(bind_key)``, so models sharing a bind key share one
    :class:`~sqlalchemy.MetaData`. Models without a bind key keep the metadata they
    inherit. If the metadata registry has no default (``None``) entry yet, that
    metadata is recorded as the default.

    .. seealso::
        :class:`sqlalchemy.orm.DeclarativeBase`
    """

    registry = orm_registry
    __abstract__ = True
    __metadata_registry__: MetadataRegistry = MetadataRegistry()
    __bind_key__: Optional[str] = None  # noqa: UP007

    def __init_subclass__(cls, **kwargs: Any) -> None:
        metadata_map = cls.__metadata_registry__
        bind_key = getattr(cls, "__bind_key__", None)
        if bind_key is None:
            if None not in metadata_map and getattr(cls, "metadata", None) is not None:
                metadata_map[None] = cls.metadata
        else:
            # Assigned before delegating to ``super().__init_subclass__``, presumably
            # so SQLAlchemy's declarative setup picks up the bind-specific metadata.
            cls.metadata = metadata_map.get(bind_key)
        super().__init_subclass__(**kwargs)
class UUIDBase(_UUIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models keyed by a UUID v4 primary key.

    Combines the UUID v4 primary-key mixin, the shared table conventions, the
    registry-aware declarative base and SQLAlchemy's ``AsyncAttrs`` mixin.

    .. seealso::
        :class:`advanced_alchemy.mixins.UUIDPrimaryKey`
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class UUIDAuditBase(CommonTableAttributes, _UUIDPrimaryKey, _AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models with a UUID v4 primary key plus audit columns.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`advanced_alchemy.mixins.UUIDPrimaryKey`
        :class:`advanced_alchemy.mixins.AuditColumns`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class UUIDv6Base(_UUIDv6PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models keyed by a UUID v6 primary key.

    .. seealso::
        :class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class UUIDv6AuditBase(CommonTableAttributes, _UUIDv6PrimaryKey, _AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models with a UUID v6 primary key plus audit columns.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`advanced_alchemy.mixins.UUIDv6PrimaryKey`
        :class:`advanced_alchemy.mixins.AuditColumns`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class UUIDv7Base(_UUIDv7PrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models keyed by a UUID v7 primary key.

    .. seealso::
        :class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class UUIDv7AuditBase(CommonTableAttributes, _UUIDv7PrimaryKey, _AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models with a UUID v7 primary key plus audit columns.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`advanced_alchemy.mixins.UUIDv7PrimaryKey`
        :class:`advanced_alchemy.mixins.AuditColumns`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class NanoIDBase(_NanoIDPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models keyed by a Nano ID primary key.

    .. seealso::
        :class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class NanoIDAuditBase(CommonTableAttributes, _NanoIDPrimaryKey, _AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models with a Nano ID primary key plus audit columns.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`advanced_alchemy.mixins.NanoIDPrimaryKey`
        :class:`advanced_alchemy.mixins.AuditColumns`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class BigIntBase(_BigIntPrimaryKey, CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models keyed by a BigInt primary key.

    .. seealso::
        :class:`advanced_alchemy.mixins.BigIntPrimaryKey`
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class BigIntAuditBase(CommonTableAttributes, _BigIntPrimaryKey, _AuditColumns, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base for models with a BigInt primary key plus audit columns.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`advanced_alchemy.mixins.BigIntPrimaryKey`
        :class:`advanced_alchemy.mixins.AuditColumns`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class DefaultBase(CommonTableAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract declarative base that adds no primary key; subclasses declare their own.

    .. seealso::
        :class:`CommonTableAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
class SQLQuery(BasicAttributes, AdvancedDeclarativeBase, AsyncAttrs):
    """Abstract base for custom mapped objects such as query/selectable models.

    Unlike the table bases it does not derive ``__tablename__``. It sets
    ``__allow_unmapped__`` so subclasses may use annotations that are not wrapped
    in ``Mapped[]``.

    .. seealso::
        :class:`BasicAttributes`
        :class:`AdvancedDeclarativeBase`
        :class:`AsyncAttrs`
    """

    __abstract__ = True
    __allow_unmapped__ = True