from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.plugins import SerializationPluginProtocol
from sqlalchemy.orm import DeclarativeBase
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base
if TYPE_CHECKING:
from litestar.typing import FieldDefinition
[docs]
class SQLAlchemySerializationPlugin(SerializationPluginProtocol, _slots_base.SlotsBase):
    """Serialization plugin that provides ``SQLAlchemyDTO`` types for ``DeclarativeBase`` models.

    DTO types are created on demand and remembered on the plugin instance, so
    repeated requests for the same annotation return the same DTO type.
    """

    def __init__(self) -> None:
        """Set up the plugin with an empty annotation-to-DTO cache."""
        # Cache of DTO types already created, keyed by the model annotation.
        self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}

    def supports_type(self, field_definition: FieldDefinition) -> bool:
        """Report whether this plugin handles the given field.

        Args:
            field_definition: The field definition to inspect.

        Returns:
            ``True`` when the field is a collection for which
            ``has_inner_subclass_of(DeclarativeBase)`` holds, or when the field
            itself is a subclass of ``DeclarativeBase``.
        """
        if field_definition.is_collection and field_definition.has_inner_subclass_of(DeclarativeBase):
            return True
        return field_definition.is_subclass_of(DeclarativeBase)

    def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
        """Return the DTO type for the model referenced by ``field_definition``.

        The first inner type that is a ``DeclarativeBase`` subclass is used as
        the model annotation; when there is none, the field's own annotation is
        used. The code assumes the field is either a container of SQLAlchemy
        models or a single SQLAlchemy model.

        Args:
            field_definition: The field definition to build a DTO for.

        Returns:
            The cached DTO type for the selected annotation, creating and
            caching it on first use.
        """
        # Start from the field's own annotation; an inner model type overrides it.
        model_annotation = field_definition.annotation
        for candidate in field_definition.inner_types:
            if candidate.is_subclass_of(DeclarativeBase):
                model_annotation = candidate.annotation
                break

        # Cached values are DTO classes, so ``None`` reliably means "not cached".
        cached_dto = self._type_dto_map.get(model_annotation)
        if cached_dto is not None:
            return cached_dto

        dto_type = SQLAlchemyDTO[model_annotation]  # type:ignore[valid-type]
        self._type_dto_map[model_annotation] = dto_type
        return dto_type