# Source code for advanced_alchemy.base

"""Application ORM configuration."""

from __future__ import annotations

import contextlib
import re
from datetime import date, datetime, timezone
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
from uuid import UUID

from sqlalchemy import Date, Index, MetaData, Sequence, String, UniqueConstraint
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Mapper,
    declarative_mixin,
    declared_attr,
    mapped_column,
    orm_insert_sentinel,
    registry,
    validates,
)
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType  # pyright: ignore[reportPrivateUsage]

from advanced_alchemy.types import GUID, NANOID_INSTALLED, UUID_UTILS_INSTALLED, BigIntIdentity, DateTimeUTC, JsonB

if UUID_UTILS_INSTALLED and not TYPE_CHECKING:
    from uuid_utils.compat import uuid4, uuid6, uuid7  # pyright: ignore[reportMissingImports]

else:
    from uuid import uuid4  # type: ignore[assignment]

    uuid6 = uuid4  # type: ignore[assignment]
    uuid7 = uuid4  # type: ignore[assignment]

if NANOID_INSTALLED and not TYPE_CHECKING:
    from fastnanoid import generate as nanoid  # pyright: ignore[reportMissingImports]

else:

    def nanoid() -> str:  # type: ignore[no-redef]
        """Return a string identifier when ``fastnanoid`` is unavailable.

        The Nano ID primary key column is declared as ``Mapped[str]``, so the
        fallback must produce a ``str``. Aliasing ``uuid4`` directly would hand
        a ``UUID`` object to a string column; the hex form of a fresh UUID is
        used instead.

        Returns:
            str: 32 lowercase hexadecimal characters.
        """
        return uuid4().hex

if TYPE_CHECKING:
    from sqlalchemy.sql import FromClause
    from sqlalchemy.sql.schema import (
        _NamingSchemaParameter as NamingSchemaParameter,  # pyright: ignore[reportPrivateUsage]
    )
    from sqlalchemy.types import TypeEngine


# Public API of this module. ``NanoIDPrimaryKey``, ``BasicAttributes`` and
# ``DefaultBase`` are listed alongside their sibling mixins/bases so that
# every public class defined here is exported consistently.
__all__ = (
    "AuditColumns",
    "BasicAttributes",
    "BigIntAuditBase",
    "BigIntBase",
    "BigIntPrimaryKey",
    "CommonTableAttributes",
    "create_registry",
    "DefaultBase",
    "ModelProtocol",
    "UUIDAuditBase",
    "UUIDBase",
    "UUIDv6AuditBase",
    "UUIDv6Base",
    "UUIDv7AuditBase",
    "UUIDv7Base",
    "NanoIDAuditBase",
    "NanoIDBase",
    "NanoIDPrimaryKey",
    "UUIDPrimaryKey",
    "UUIDv7PrimaryKey",
    "UUIDv6PrimaryKey",
    "SlugKey",
    "SQLQuery",
    "orm_registry",
    "merge_table_arguments",
    "TableArgsType",
)


# Type variables, each bound (by forward reference) to one of the declarative
# base classes defined later in this module.
UUIDBaseT = TypeVar("UUIDBaseT", bound="UUIDBase")
BigIntBaseT = TypeVar("BigIntBaseT", bound="BigIntBase")
UUIDv6BaseT = TypeVar("UUIDv6BaseT", bound="UUIDv6Base")
UUIDv7BaseT = TypeVar("UUIDv7BaseT", bound="UUIDv7Base")
NanoIDBaseT = TypeVar("NanoIDBaseT", bound="NanoIDBase")

convention: NamingSchemaParameter = {
    "ix": "ix_%(column_0_label)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}
"""Templates for automated constraint name generation."""
table_name_regexp = re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
"""Regular expression for table name"""


def merge_table_arguments(cls: type[DeclarativeBase], table_args: TableArgsType | None = None) -> TableArgsType:
    """Combine inherited ``__table_args__`` values with extra table arguments.

    Mixins that ship their own ``__table_args__`` make it awkward to add
    further table-level settings (a comment, for instance) on the model.
    This helper gathers the inherited values and folds ``table_args`` in last.

    For every direct base of ``cls`` the attribute is looked up through
    ``super(base, cls)``, i.e. on the classes that follow that base in the
    MRO of ``cls``. Positional entries are concatenated in encounter order;
    dictionary entries are merged with later values overriding earlier ones.

    Args:
        cls (DeclarativeBase): The model that will receive the merged arguments.
        table_args: Additional positional and/or keyword table arguments.

    Returns:
        tuple | dict: A tuple of positional arguments (with the keyword dict
        as its final element when any keywords exist), or just the keyword
        dict when there are no positional arguments.
    """
    positional: list[Any] = []
    options: dict[str, Any] = {}

    # Resolve every inherited value up front, before any merging happens.
    inherited = [getattr(super(parent, cls), "__table_args__", None) for parent in cls.__bases__]  # pyright: ignore[reportUnknownParameter,reportUnknownArgumentType,reportArgumentType]

    for candidate in [*inherited, table_args]:
        if not candidate:
            # Missing or empty table args contribute nothing.
            continue
        if not isinstance(candidate, tuple):
            # A bare mapping carries keyword options only.
            options.update(candidate)
            continue
        *leading, trailing = candidate
        positional.extend(leading)
        # By convention a trailing dict in the tuple holds keyword options.
        if isinstance(trailing, dict):
            options.update(trailing)  # pyright: ignore[reportUnknownArgumentType]
        else:
            positional.append(trailing)

    if not positional:
        return options
    return (*positional, options) if options else tuple(positional)
@runtime_checkable
class ModelProtocol(Protocol):
    """Structural type describing a SQLAlchemy model usable by this package.

    The mapper/table attributes are declared for type checkers only; at
    runtime the protocol body defines just :meth:`to_dict`.
    """

    if TYPE_CHECKING:
        __name__: str
        __mapper__: Mapper[Any]
        __table__: FromClause

    def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
        """Convert model to dictionary.

        Args:
            exclude: Optional set of field names to leave out.

        Returns:
            dict[str, Any]: A dict representation of the model
        """
        ...
class UUIDPrimaryKey:
    """Mixin contributing a UUID ``id`` primary key column."""

    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        default=uuid4,
    )
    """Primary key column; ``uuid4`` supplies the default value."""

    @declared_attr
    def _sentinel(cls) -> Mapped[int]:
        """Declare the ``sa_orm_sentinel`` insert-sentinel column."""
        return orm_insert_sentinel(name="sa_orm_sentinel")
class UUIDv6PrimaryKey:
    """Mixin contributing a UUID ``id`` primary key column defaulted by ``uuid6``."""

    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        default=uuid6,
    )
    """Primary key column; the module-level ``uuid6`` callable supplies the default value."""

    @declared_attr
    def _sentinel(cls) -> Mapped[int]:
        """Declare the ``sa_orm_sentinel`` insert-sentinel column."""
        return orm_insert_sentinel(name="sa_orm_sentinel")
class UUIDv7PrimaryKey:
    """Mixin contributing a UUID ``id`` primary key column defaulted by ``uuid7``."""

    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        default=uuid7,
    )
    """Primary key column; the module-level ``uuid7`` callable supplies the default value."""

    @declared_attr
    def _sentinel(cls) -> Mapped[int]:
        """Declare the ``sa_orm_sentinel`` insert-sentinel column."""
        return orm_insert_sentinel(name="sa_orm_sentinel")
class NanoIDPrimaryKey:
    """Mixin contributing a string ``id`` primary key column defaulted by ``nanoid``."""

    id: Mapped[str] = mapped_column(
        primary_key=True,
        default=nanoid,
    )
    """Primary key column; the module-level ``nanoid`` callable supplies the default value."""

    @declared_attr
    def _sentinel(cls) -> Mapped[int]:
        """Declare the ``sa_orm_sentinel`` insert-sentinel column."""
        return orm_insert_sentinel(name="sa_orm_sentinel")
class BigIntPrimaryKey:
    """Mixin contributing a ``BigIntIdentity`` ``id`` primary key column."""

    # noinspection PyMethodParameters
    @declared_attr
    def id(cls) -> Mapped[int]:
        """Build the primary key column with a per-table sequence.

        The sequence is named ``<tablename>_id_seq`` and created with
        ``optional=False``.
        """
        id_sequence = Sequence(f"{cls.__tablename__}_id_seq", optional=False)  # type: ignore[attr-defined]
        return mapped_column(BigIntIdentity, id_sequence, primary_key=True)
class AuditColumns:
    """Mixin contributing ``created_at`` / ``updated_at`` timestamp columns."""

    created_at: Mapped[datetime] = mapped_column(
        DateTimeUTC(timezone=True),
        default=lambda: datetime.now(tz=timezone.utc),
    )
    """Timestamp of instance creation; defaults to the current UTC time."""
    updated_at: Mapped[datetime] = mapped_column(
        DateTimeUTC(timezone=True),
        default=lambda: datetime.now(tz=timezone.utc),
        onupdate=lambda: datetime.now(tz=timezone.utc),
    )
    """Timestamp of the last update; defaults to, and is refreshed with, the current UTC time."""

    @validates("created_at", "updated_at")
    def validate_tz_info(self, _: str, value: datetime) -> datetime:
        """Attach UTC to naive datetimes assigned to the audit columns.

        Args:
            _: Name of the attribute being set (unused).
            value: The datetime being assigned.

        Returns:
            datetime: ``value`` unchanged when it already carries ``tzinfo``,
            otherwise a copy with ``tzinfo`` set to UTC.
        """
        if value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.utc)
class BasicAttributes:
    """Attributes and helpers shared by SQLAlchemy tables and query objects."""

    if TYPE_CHECKING:
        __name__: str
        __table__: FromClause
        __mapper__: Mapper[Any]

    def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]:
        """Convert model to dictionary.

        Sentinel columns and attributes reported as unloaded by the instance
        state are always skipped, in addition to anything named in ``exclude``.

        Args:
            exclude: Optional field names to leave out of the result.

        Returns:
            dict[str, Any]: A dict representation of the model
        """
        skipped = {"sa_orm_sentinel", "_sentinel"}
        skipped.update(self._sa_instance_state.unloaded, exclude or [])  # type: ignore[attr-defined]

        result: dict[str, Any] = {}
        for column_name in self.__mapper__.columns.keys():  # noqa: SIM118
            if column_name in skipped:
                continue
            result[column_name] = getattr(self, column_name)
        return result
class CommonTableAttributes(BasicAttributes):
    """Attributes shared by SQLAlchemy table models, including an inferred table name."""

    if TYPE_CHECKING:
        __tablename__: str
    else:

        @declared_attr.directive
        def __tablename__(cls) -> str:
            """Derive the table name from the class name.

            An underscore is inserted before each position matched by
            ``table_name_regexp`` and the result is lowercased.
            """
            with_underscores = table_name_regexp.sub(r"_\1", cls.__name__)
            return with_underscores.lower()
@declarative_mixin
class SlugKey:
    """Mixin contributing a unique, non-nullable ``slug`` column."""

    @declared_attr
    def slug(cls) -> Mapped[str]:
        """Slug column: a required string of at most 100 characters."""
        return mapped_column(String(length=100), nullable=False)

    @staticmethod
    def _create_unique_slug_index(*_args: Any, **kwargs: Any) -> bool:
        """DDL predicate: true when the dialect name starts with ``spanner``."""
        dialect_name = kwargs["dialect"].name
        return bool(dialect_name.startswith("spanner"))

    @staticmethod
    def _create_unique_slug_constraint(*_args: Any, **kwargs: Any) -> bool:
        """DDL predicate: true when the dialect name does not start with ``spanner``."""
        dialect_name = kwargs["dialect"].name
        return not dialect_name.startswith("spanner")

    @declared_attr.directive
    @classmethod
    def __table_args__(cls) -> TableArgsType:
        """Declare slug uniqueness as a constraint or a unique index.

        Both objects are declared; each is guarded by ``ddl_if`` with one of
        the two complementary dialect predicates above.
        """
        unique_constraint = UniqueConstraint(
            cls.slug,
            name=f"uq_{cls.__tablename__}_slug",  # type: ignore[attr-defined]
        ).ddl_if(callable_=cls._create_unique_slug_constraint)
        unique_index = Index(
            f"ix_{cls.__tablename__}_slug_unique",  # type: ignore[attr-defined]
            cls.slug,
            unique=True,
        ).ddl_if(callable_=cls._create_unique_slug_index)
        return (unique_constraint, unique_index)
def create_registry(
    custom_annotation_map: dict[Any, type[TypeEngine[Any]] | TypeEngine[Any]] | None = None,
) -> registry:
    """Create a new SQLAlchemy registry.

    The registry uses the module's naming ``convention`` and a type
    annotation map covering UUIDs, datetimes, dates and dicts. Mappings for
    pydantic and msgspec types are added only when those packages import
    successfully.

    Args:
        custom_annotation_map: Extra annotation-to-type entries, applied last
            so they override the defaults.

    Returns:
        registry: The configured registry.
    """
    import uuid as core_uuid

    annotations_to_types: dict[Any, type[TypeEngine[Any]] | TypeEngine[Any]] = {
        UUID: GUID,
        core_uuid.UUID: GUID,
        datetime: DateTimeUTC,
        date: Date,
        dict: JsonB,
    }

    # Optional integrations: skipped silently when the package is missing.
    try:
        from pydantic import AnyHttpUrl, AnyUrl, EmailStr, Json
    except ImportError:
        pass
    else:
        for string_like in (EmailStr, AnyUrl, AnyHttpUrl):
            annotations_to_types[string_like] = String
        annotations_to_types[Json] = JsonB

    try:
        from msgspec import Struct
    except ImportError:
        pass
    else:
        annotations_to_types[Struct] = JsonB

    if custom_annotation_map is not None:
        annotations_to_types.update(custom_annotation_map)

    return registry(
        metadata=MetaData(naming_convention=convention),
        type_annotation_map=annotations_to_types,
    )
# Default registry, built once at import time with no custom annotation map;
# every declarative base class defined below is bound to it.
orm_registry = create_registry()
class UUIDBase(UUIDPrimaryKey, CommonTableAttributes, DeclarativeBase):
    """Declarative base for models keyed by a UUID ``id``.

    Combines :class:`UUIDPrimaryKey` with :class:`CommonTableAttributes` and
    binds subclasses to the module-level ``orm_registry``.
    """

    registry = orm_registry
class UUIDAuditBase(CommonTableAttributes, UUIDPrimaryKey, AuditColumns, DeclarativeBase):
    """Declarative base for models keyed by a UUID ``id`` with audit timestamps.

    Combines :class:`CommonTableAttributes`, :class:`UUIDPrimaryKey` and
    :class:`AuditColumns`, bound to the module-level ``orm_registry``.
    """

    registry = orm_registry
class UUIDv6Base(UUIDv6PrimaryKey, CommonTableAttributes, DeclarativeBase):
    """Declarative base for models keyed by a UUID ``id`` defaulted by ``uuid6``.

    Combines :class:`UUIDv6PrimaryKey` with :class:`CommonTableAttributes`
    and binds subclasses to the module-level ``orm_registry``.
    """

    registry = orm_registry
class UUIDv6AuditBase(CommonTableAttributes, UUIDv6PrimaryKey, AuditColumns, DeclarativeBase):
    """Declarative base with a ``uuid6``-defaulted UUID ``id`` and audit timestamps.

    Combines :class:`CommonTableAttributes`, :class:`UUIDv6PrimaryKey` and
    :class:`AuditColumns`, bound to the module-level ``orm_registry``.
    """

    registry = orm_registry
class UUIDv7Base(UUIDv7PrimaryKey, CommonTableAttributes, DeclarativeBase):
    """Declarative base for models keyed by a UUID ``id`` defaulted by ``uuid7``.

    Combines :class:`UUIDv7PrimaryKey` with :class:`CommonTableAttributes`
    and binds subclasses to the module-level ``orm_registry``.
    """

    registry = orm_registry
class UUIDv7AuditBase(CommonTableAttributes, UUIDv7PrimaryKey, AuditColumns, DeclarativeBase):
    """Declarative base with a ``uuid7``-defaulted UUID ``id`` and audit timestamps.

    Combines :class:`CommonTableAttributes`, :class:`UUIDv7PrimaryKey` and
    :class:`AuditColumns`, bound to the module-level ``orm_registry``.
    """

    registry = orm_registry
class NanoIDBase(NanoIDPrimaryKey, CommonTableAttributes, DeclarativeBase):
    """Declarative base for models keyed by a Nano ID string ``id``.

    Combines :class:`NanoIDPrimaryKey` with :class:`CommonTableAttributes`
    and binds subclasses to the module-level ``orm_registry``.
    """

    registry = orm_registry
class NanoIDAuditBase(CommonTableAttributes, NanoIDPrimaryKey, AuditColumns, DeclarativeBase):
    """Declarative base with a Nano ID string ``id`` and audit timestamps.

    Combines :class:`CommonTableAttributes`, :class:`NanoIDPrimaryKey` and
    :class:`AuditColumns`, bound to the module-level ``orm_registry``.
    """

    registry = orm_registry
class BigIntBase(BigIntPrimaryKey, CommonTableAttributes, DeclarativeBase):
    """Declarative base for models keyed by a BigInt ``id``.

    Combines :class:`BigIntPrimaryKey` with :class:`CommonTableAttributes`
    and binds subclasses to the module-level ``orm_registry``.
    """

    registry = orm_registry
class BigIntAuditBase(CommonTableAttributes, BigIntPrimaryKey, AuditColumns, DeclarativeBase):
    """Declarative base with a BigInt ``id`` and audit timestamps.

    Combines :class:`CommonTableAttributes`, :class:`BigIntPrimaryKey` and
    :class:`AuditColumns`, bound to the module-level ``orm_registry``.
    """

    registry = orm_registry
class DefaultBase(CommonTableAttributes, DeclarativeBase):
    """Declarative base that adds no primary key of its own.

    Provides only :class:`CommonTableAttributes` and binds subclasses to the
    module-level ``orm_registry``.
    """

    registry = orm_registry
class SQLQuery(BasicAttributes, DeclarativeBase):
    """Declarative base for custom mapped objects.

    Provides only :class:`BasicAttributes` (no inferred table name), sets
    ``__allow_unmapped__`` and binds subclasses to the module-level
    ``orm_registry``.
    """

    __allow_unmapped__ = True
    registry = orm_registry