Source code for advanced_alchemy.extensions.litestar.plugins.serialization

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from litestar.plugins import SerializationPluginProtocol
from sqlalchemy.orm import DeclarativeBase

from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base

if TYPE_CHECKING:
    from litestar.typing import FieldDefinition


class SQLAlchemySerializationPlugin(SerializationPluginProtocol, _slots_base.SlotsBase):
    """Serialization plugin that maps SQLAlchemy model annotations to ``SQLAlchemyDTO`` types.

    Generated DTO classes are memoized per model class in ``self._type_dto_map``, so
    repeated requests for the same model yield the identical DTO class object.
    """

    def __init__(self) -> None:
        """Initialise the plugin with an empty model -> DTO cache."""
        self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}

    def supports_type(self, field_definition: FieldDefinition) -> bool:
        """Report whether this plugin can handle ``field_definition``.

        Args:
            field_definition: The annotation under inspection.

        Returns:
            ``True`` when the annotation is a collection whose inner type subclasses
            ``DeclarativeBase``, or is itself a ``DeclarativeBase`` subclass; otherwise
            ``False``.
        """
        # The collection case is tested first, preserving the original short-circuit order.
        if field_definition.is_collection and field_definition.has_inner_subclass_of(DeclarativeBase):
            return True
        return field_definition.is_subclass_of(DeclarativeBase)

    def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
        """Return the ``SQLAlchemyDTO`` subclass for the model referenced by ``field_definition``.

        The annotation is assumed to be either a single SQLAlchemy model or a container of
        SQLAlchemy models. No validation is done here: if no inner type subclasses
        ``DeclarativeBase``, the annotation itself is used as the DTO's model.

        Args:
            field_definition: The annotation to build a DTO for.

        Returns:
            ``SQLAlchemyDTO[model]``, where ``model`` is the first inner type that subclasses
            ``DeclarativeBase`` or, failing that, the annotation itself. The same class object
            is returned on every call for a given model.
        """
        # Default to the outer annotation; prefer the first inner SQLAlchemy model if present.
        model = field_definition.annotation
        for inner_type in field_definition.inner_types:
            if inner_type.is_subclass_of(DeclarativeBase):
                model = inner_type.annotation
                break

        # Stored DTO types are never None, so ``get`` doubles as a membership test.
        cached = self._type_dto_map.get(model)
        if cached is not None:
            return cached

        dto_type = SQLAlchemyDTO[model]  # type:ignore[valid-type]
        self._type_dto_map[model] = dto_type
        return dto_type