import datetime
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Final, Generic, Optional, TypeVar, Union, cast
from litestar.exceptions import ImproperlyConfiguredException
from litestar.middleware.session.server_side import ServerSideSessionBackend, ServerSideSessionConfig
from sqlalchemy import (
BooleanClauseList,
Dialect,
Index,
LargeBinary,
ScalarResult,
String,
UniqueConstraint,
delete,
func,
select,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, Session, declarative_mixin, declared_attr, mapped_column
from advanced_alchemy.base import UUIDv7Base
from advanced_alchemy.extensions.litestar.plugins.init import (
SQLAlchemyAsyncConfig,
SQLAlchemySyncConfig,
)
from advanced_alchemy.operations import OnConflictUpsert
from advanced_alchemy.utils.sync_tools import async_
if TYPE_CHECKING:
from litestar.stores.base import Store
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import BooleanClauseList
SQLAlchemyConfig = Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]
SQLAlchemyConfigT = TypeVar("SQLAlchemyConfigT", bound=SQLAlchemyConfig)
SessionModelT = TypeVar("SessionModelT", bound="SessionModelMixin")
# Session ID field limit as defined in the database schema
SESSION_ID_MAX_LENGTH = 255
# PostgreSQL version supporting MERGE (same as store.py)
_POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
# Temporary toggle to disable PostgreSQL MERGE due to locking concerns
_DISABLE_POSTGRES_MERGE: Final = True
@declarative_mixin
class SessionModelMixin(UUIDv7Base):
    """Mixin providing the columns and DDL constraints for session storage.

    Adds a unique ``session_id``, an opaque serialized ``data`` payload and an
    indexed ``expires_at`` timestamp to any mapped model that mixes it in.
    """

    __abstract__ = True

    @declared_attr.directive
    @classmethod
    def __table_args__(cls) -> "TableArgsType":
        # Spanner has no support for unique constraints, so it gets a unique
        # index instead; every other dialect gets the regular constraint.
        # ``ddl_if`` defers the decision until DDL emission, when the dialect
        # is known.
        return (
            UniqueConstraint(
                cls.session_id,
                name=f"uq_{cls.__tablename__}_session_id",
            ).ddl_if(callable_=cls._create_unique_session_id_constraint),
            Index(
                f"ix_{cls.__tablename__}_session_id_unique",
                cls.session_id,
                unique=True,
            ).ddl_if(callable_=cls._create_unique_session_id_index),
        )

    @declared_attr
    def session_id(cls) -> Mapped[str]:
        # Column length kept in sync with SESSION_ID_MAX_LENGTH, which the
        # backends use to truncate incoming session IDs before querying
        # (was a hard-coded 255 that could drift from the constant).
        return mapped_column(String(length=SESSION_ID_MAX_LENGTH), nullable=False)

    @declared_attr
    def data(cls) -> Mapped[bytes]:
        # Opaque serialized session payload; (de)serialization is the
        # backend's responsibility.
        return mapped_column(LargeBinary, nullable=False)

    @declared_attr
    def expires_at(cls) -> Mapped[datetime.datetime]:
        # Indexed so expired-session sweeps (``delete_expired``) can use it.
        return mapped_column(index=True)

    @classmethod
    def _create_unique_session_id_index(cls, *_: Any, **kwargs: Any) -> bool:
        """Return ``True`` when DDL should emit the unique index (Spanner only)."""
        dialect_name = kwargs["dialect"].name if "dialect" in kwargs else ""
        return "spanner" in dialect_name.lower()

    @classmethod
    def _create_unique_session_id_constraint(cls, *_: Any, **kwargs: Any) -> bool:
        """Return ``True`` when DDL should emit the unique constraint (non-Spanner)."""
        dialect_name = kwargs["dialect"].name if "dialect" in kwargs else ""
        return "spanner" not in dialect_name.lower()

    @hybrid_property
    def is_expired(self) -> bool:  # pyright: ignore
        """Boolean indicating if the session has expired.

        Returns:
            ``True`` if the session has expired, otherwise ``False``
        """
        return datetime.datetime.now(datetime.timezone.utc) > self.expires_at

    @is_expired.expression  # type: ignore[no-redef]
    def is_expired(cls) -> "BooleanClauseList":  # noqa: N805
        """SQL-Expression to check if the session has expired.

        Returns:
            SQL-Expression comparing the database clock against ``expires_at``.
        """
        return cast("BooleanClauseList", func.now() > cls.expires_at)
class SQLAlchemySessionBackendBase(ServerSideSessionBackend, ABC, Generic[SQLAlchemyConfigT]):
    """Session backend to store data in a database with SQLAlchemy. Works with both sync and async engines.

    Notes:
        - Requires ``sqlalchemy`` which needs to be installed separately, and a configured
          SQLAlchemyPlugin.
    """

    # Fixed: previously declared ("_model", "_session_maker") — "_session_maker"
    # is never assigned anywhere in this class, while "_config" and "_alchemy"
    # (which __init__ actually sets) were missing from the declaration.
    __slots__ = ("_alchemy", "_config", "_model")

    def __init__(
        self,
        config: "ServerSideSessionConfig",
        alchemy_config: "SQLAlchemyConfigT",
        model: "type[SessionModelMixin]",
    ) -> None:
        """Initialize `BaseSQLAlchemyBackend`.

        Args:
            config: An instance of `SQLAlchemyBackendConfig`
            alchemy_config: An instance of `SQLAlchemyConfig`
            model: A mapped model subclassing `SessionModelMixin`
        """
        self._model = model
        self._config = config
        self._alchemy = alchemy_config

    def __deepcopy__(self, memo: dict[int, Any]) -> "SQLAlchemySessionBackendBase[SQLAlchemyConfigT]":
        """Custom deepcopy implementation to handle unpicklable SQLAlchemy objects."""
        # Create a new instance without running __init__.
        cls = self.__class__
        new_obj = cls.__new__(cls)
        memo[id(self)] = new_obj
        # The ServerSideSessionConfig is normally serializable; fall back to a
        # shared reference if deep-copying it fails.
        try:
            new_obj._config = deepcopy(self.config, memo)  # noqa: SLF001
        except (TypeError, AttributeError):
            new_obj._config = self.config  # noqa: SLF001
        # Model classes are safe to reference directly.
        new_obj._model = self.model  # noqa: SLF001
        # SQLAlchemy config contains unpicklable objects (engines, pools), so we
        # reference the original; configs are typically shared and immutable.
        new_obj._alchemy = self.alchemy  # noqa: SLF001
        return new_obj

    def _select_session_obj(self, session_id: str) -> "Select[tuple[SessionModelMixin]]":
        """Build a SELECT for the session row matching ``session_id``."""
        return select(self._model).where(self._model.session_id == session_id)

    def _update_session_expiry(self, session_obj: "SessionModelMixin") -> None:
        """Push the session's expiry ``max_age`` seconds into the future."""
        session_obj.expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
            seconds=self.config.max_age
        )

    @staticmethod
    def supports_merge(dialect: "Optional[Dialect]" = None, force_disable_merge: bool = False) -> bool:
        """Check if the dialect supports MERGE statements for upserts.

        Args:
            dialect: The SQLAlchemy dialect to inspect, if any.
            force_disable_merge: When ``True``, always report no MERGE support.

        Returns:
            ``True`` for Oracle, or for PostgreSQL >= 15 (currently force-disabled
            via ``_DISABLE_POSTGRES_MERGE`` due to locking concerns).
        """
        return bool(
            dialect
            and (
                (
                    dialect.server_version_info is not None
                    and dialect.server_version_info[0] >= _POSTGRES_VERSION_SUPPORTING_MERGE
                    and dialect.name == "postgresql"
                    and not _DISABLE_POSTGRES_MERGE  # Temporary PostgreSQL MERGE disable
                )
                or dialect.name == "oracle"
            )
            and not force_disable_merge
        )

    @staticmethod
    def supports_upsert(dialect: "Optional[Dialect]" = None, force_disable_upsert: bool = False) -> bool:
        """Check if the dialect supports native upsert operations.

        Args:
            dialect: The SQLAlchemy dialect to inspect, if any.
            force_disable_upsert: When ``True``, always report no upsert support.

        Returns:
            ``True`` when the dialect has a native ON CONFLICT / ON DUPLICATE KEY form.
        """
        return bool(
            dialect
            and (dialect.name in {"postgresql", "cockroachdb", "sqlite", "mysql", "mariadb", "duckdb"})
            and not force_disable_upsert
        )

    @abstractmethod
    async def delete_expired(self) -> None:
        """Delete all expired sessions from the database."""

    @property
    def model(self) -> "type[SessionModelMixin]":
        """The mapped session model class."""
        return self._model

    @property
    def config(self) -> "ServerSideSessionConfig":
        """The server-side session configuration."""
        return self._config

    @config.setter
    def config(self, value: "ServerSideSessionConfig") -> None:
        self._config = value

    @property
    def alchemy(self) -> "SQLAlchemyConfigT":
        """The SQLAlchemy (sync or async) plugin configuration."""
        return self._alchemy

    @property
    def _backend_class(self) -> "type[Union[SQLAlchemySyncSessionBackend, SQLAlchemyAsyncSessionBackend]]":
        """Return either `SQLAlchemyBackend` or `AsyncSQLAlchemyBackend`, depending on the engine type configured in the
        `SQLAlchemyPlugin`
        """
        if isinstance(self.alchemy, SQLAlchemyAsyncConfig):
            return SQLAlchemyAsyncSessionBackend
        return SQLAlchemySyncSessionBackend
class SQLAlchemyAsyncSessionBackend(SQLAlchemySessionBackendBase[SQLAlchemyAsyncConfig]):
    """Asynchronous SQLAlchemy backend."""

    async def _get_session_obj(self, *, db_session: "AsyncSession", session_id: str) -> Optional[SessionModelMixin]:
        """Fetch the session row matching ``session_id``, or ``None`` if absent."""
        scalars = cast(
            "ScalarResult[Optional[SessionModelMixin]]",
            await db_session.scalars(self._select_session_obj(session_id)),
        )
        return scalars.one_or_none()

    async def get(self, /, session_id: str, store: "Store") -> Optional[bytes]:
        """Retrieve data associated with `session_id`.

        Args:
            session_id: The session-ID
            store: The store to get the session from (not used in this backend)

        Returns:
            The session data, if existing, otherwise `None`.
        """
        # Clamp to the schema's column length before querying.
        session_id = session_id[:SESSION_ID_MAX_LENGTH]
        async with self.alchemy.get_session() as db_session:
            record = await self._get_session_obj(db_session=db_session, session_id=session_id)
            if not record:
                return None
            if record.is_expired:
                # Expired rows are removed eagerly on read.
                await db_session.delete(record)
                await db_session.commit()
                return None
            payload = record.data
            # Reading a live session renews its expiry (sliding expiration).
            self._update_session_expiry(record)
            await db_session.commit()
            return payload

    async def set(self, /, session_id: str, data: bytes, store: "Store") -> None:
        """Store `data` under the `session_id` for later retrieval.

        If there is already data associated with `session_id`, replace
        it with `data` and reset its expiry time

        Args:
            session_id: The session-ID.
            data: Serialized session data
            store: The store to store the session in (not used in this backend)
        """
        session_id = session_id[:SESSION_ID_MAX_LENGTH]
        new_expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=self.config.max_age)
        async with self.alchemy.get_session() as db_session:
            if db_session.bind is None:  # pyright: ignore[reportUnnecessaryComparison]
                msg = "Database connection is not available"  # type: ignore[unreachable]
                raise ImproperlyConfiguredException(msg)
            dialect = db_session.bind.dialect
            row_values = {
                "session_id": session_id,
                "data": data,
                "expires_at": new_expiry,
            }
            # Prefer native upsert, then MERGE, then a read-modify-write fallback.
            if OnConflictUpsert.supports_native_upsert(dialect.name):
                stmt = OnConflictUpsert.create_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=row_values,
                    conflict_columns=["session_id"],
                    update_columns=["data", "expires_at"],
                    dialect_name=dialect.name,
                    validate_identifiers=False,
                )
                await db_session.execute(stmt)
            elif self.supports_merge(dialect):
                stmt, extra_params = OnConflictUpsert.create_merge_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=row_values,
                    conflict_columns=["session_id"],
                    update_columns=["data", "expires_at"],
                    dialect_name=dialect.name,
                    validate_identifiers=False,
                )
                # The MERGE form (e.g. Oracle) needs the original values plus
                # any additional generated bind parameters.
                await db_session.execute(stmt, {**row_values, **extra_params})
            else:
                # Fallback: check existence, then update or insert.
                record = await self._get_session_obj(db_session=db_session, session_id=session_id)
                if not record:
                    record = self._model(session_id=session_id)
                    db_session.add(record)
                record.data = data
                record.expires_at = new_expiry
            await db_session.commit()

    async def delete(self, /, session_id: str, store: "Store") -> None:
        """Delete the data associated with `session_id`. Fails silently if no such session-ID exists.

        Args:
            session_id: The session-ID
            store: The store to delete the session from (not used in this backend)
        """
        session_id = session_id[:SESSION_ID_MAX_LENGTH]
        stmt = delete(self._model).where(self._model.session_id == session_id)
        async with self.alchemy.get_session() as db_session:
            await db_session.execute(stmt)
            await db_session.commit()

    async def delete_all(self, /, store: "Store") -> None:
        """Delete all session data."""
        async with self.alchemy.get_session() as db_session:
            await db_session.execute(delete(self._model))
            await db_session.commit()

    async def delete_expired(self) -> None:
        """Delete all expired session from the database."""
        stmt = delete(self._model).where(self._model.is_expired)
        async with self.alchemy.get_session() as db_session:
            await db_session.execute(stmt)
            await db_session.commit()
class SQLAlchemySyncSessionBackend(SQLAlchemySessionBackendBase[SQLAlchemySyncConfig]):
    """Synchronous SQLAlchemy backend.

    The public methods are ``async`` (required by the litestar backend
    interface) and delegate to synchronous helpers executed via ``async_``.
    """

    def _get_session_obj(self, *, db_session: "Session", session_id: str) -> "Optional[SessionModelMixin]":
        """Fetch the session row matching ``session_id``, or ``None`` if absent."""
        return db_session.scalars(self._select_session_obj(session_id)).one_or_none()

    def _get_sync(self, session_id: str) -> Optional[bytes]:
        """Synchronous implementation of ``get``; removes expired rows on read."""
        session_id = session_id[:SESSION_ID_MAX_LENGTH] if len(session_id) > SESSION_ID_MAX_LENGTH else session_id
        with self.alchemy.get_session() as db_session:
            session_obj = self._get_session_obj(db_session=db_session, session_id=session_id)
            if session_obj:
                if not session_obj.is_expired:
                    data = session_obj.data
                    # Reading a live session renews its expiry (sliding expiration).
                    self._update_session_expiry(session_obj)
                    db_session.commit()
                    return data
                db_session.delete(session_obj)
                db_session.commit()
            return None

    async def get(self, /, session_id: str, store: "Store") -> Optional[bytes]:
        """Retrieve data associated with `session_id`.

        Args:
            session_id: The session-ID
            store: The store to get the session from

        Returns:
            The session data, if existing, otherwise `None`.
        """
        return await async_(self._get_sync)(session_id)

    def _set_sync(self, session_id: str, data: bytes) -> None:
        """Synchronous implementation of ``set``: upsert via native/merge/fallback."""
        session_id = session_id[:SESSION_ID_MAX_LENGTH] if len(session_id) > SESSION_ID_MAX_LENGTH else session_id
        expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=self.config.max_age)
        with self.alchemy.get_session() as db_session:
            if db_session.bind is None:
                msg = "Database connection is not available"
                raise ImproperlyConfiguredException(msg)
            dialect = db_session.bind.dialect
            dialect_name = dialect.name
            values = {
                "session_id": session_id,
                "data": data,
                "expires_at": expires_at,
            }
            conflict_columns = ["session_id"]
            update_columns = ["data", "expires_at"]
            # Prefer native upsert, then MERGE, then a read-modify-write fallback.
            if OnConflictUpsert.supports_native_upsert(dialect_name):
                upsert_stmt = OnConflictUpsert.create_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                db_session.execute(upsert_stmt)
            elif self.supports_merge(dialect):
                merge_stmt, additional_params = OnConflictUpsert.create_merge_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                # Merge additional Oracle parameters with original values
                merge_values = {**values, **additional_params}
                db_session.execute(merge_stmt, merge_values)
            else:
                # Fallback logic: Check existence, then update or insert
                session_obj = self._get_session_obj(db_session=db_session, session_id=session_id)
                if not session_obj:
                    session_obj = self._model(session_id=session_id)
                    db_session.add(session_obj)
                session_obj.data = data
                session_obj.expires_at = expires_at
            db_session.commit()

    async def set(self, /, session_id: str, data: bytes, store: "Store") -> None:
        """Store `data` under the `session_id` for later retrieval.

        If there is already data associated with `session_id`, replace
        it with `data` and reset its expiry time

        Args:
            session_id: The session-ID
            data: Serialized session data
            store: The store to store the session in
        """
        return await async_(self._set_sync)(session_id, data)

    def _delete_sync(self, session_id: str) -> None:
        """Synchronous implementation of ``delete``."""
        session_id = session_id[:SESSION_ID_MAX_LENGTH] if len(session_id) > SESSION_ID_MAX_LENGTH else session_id
        with self.alchemy.get_session() as db_session:
            db_session.execute(delete(self._model).where(self._model.session_id == session_id))
            db_session.commit()

    async def delete(self, /, session_id: str, store: "Store") -> None:
        """Delete the data associated with `session_id`. Fails silently if no such session-ID exists.

        Args:
            session_id: The session-ID
            store: The store to delete the session from
        """
        return await async_(self._delete_sync)(session_id)

    def _delete_all_sync(self) -> None:
        """Synchronous implementation of ``delete_all``."""
        with self.alchemy.get_session() as db_session:
            db_session.execute(delete(self._model))
            db_session.commit()

    async def delete_all(self, /, store: "Optional[Store]" = None) -> None:
        """Delete all session data.

        Args:
            store: The store to delete sessions from (not used in this backend).
                Accepted for signature parity with the async backend's
                ``delete_all(store)``; optional so existing zero-argument
                callers keep working.
        """
        return await async_(self._delete_all_sync)()

    def _delete_expired_sync(self) -> None:
        """Synchronous implementation of ``delete_expired``."""
        with self.alchemy.get_session() as db_session:
            db_session.execute(delete(self._model).where(self._model.is_expired))
            db_session.commit()

    async def delete_expired(self) -> None:
        """Delete all expired session from the database."""
        return await async_(self._delete_expired_sync)()