Source code for litestar.serialization.msgspec_hooks

from __future__ import annotations

from collections import deque
from datetime import date, datetime, time
from decimal import Decimal
from functools import partial
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
)
from pathlib import Path, PurePath
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload
from uuid import UUID

import msgspec

from litestar.datastructures.secret_values import SecretBytes, SecretString
from litestar.exceptions import SerializationException
from litestar.types import Empty, EmptyType, Serializer, TypeDecodersSequence
from litestar.utils.typing import get_origin_or_inner_type

if TYPE_CHECKING:
    from litestar.types import TypeEncodersMap

__all__ = (
    "decode_json",
    "decode_msgpack",
    "default_deserializer",
    "default_serializer",
    "encode_json",
    "encode_msgpack",
    "get_serializer",
)

T = TypeVar("T")

DEFAULT_TYPE_ENCODERS: TypeEncodersMap = {
    Path: str,
    PurePath: str,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    datetime: lambda val: val.isoformat(),
    date: lambda val: val.isoformat(),
    time: lambda val: val.isoformat(),
    deque: list,
    Decimal: lambda val: int(val) if val.as_tuple().exponent >= 0 else float(val),
    Pattern: lambda val: val.pattern,
    SecretBytes: lambda val: val.get_obscured().decode("utf-8"),
    SecretString: lambda val: val.get_obscured(),
    # support subclasses of stdlib types, If no previous type matched, these will be
    # the last type in the mro, so we use this to (attempt to) convert a subclass into
    # its base class. # see https://github.com/jcrist/msgspec/issues/248
    # and https://github.com/litestar-org/litestar/issues/1003
    str: str,
    int: int,
    float: float,
    set: set,
    frozenset: frozenset,
    bytes: bytes,
}


def default_serializer(value: Any, type_encoders: Mapping[Any, Callable[[Any], Any]] | None = None) -> Any:
    """Transform values non-natively supported by ``msgspec``

    Args:
        value: A value to serialized
        type_encoders: Mapping of types to callables to transforming types
    Returns:
        A serialized value
    Raises:
        TypeError: if value is not supported
    """
    type_encoders = {**DEFAULT_TYPE_ENCODERS, **(type_encoders or {})}

    for base in value.__class__.__mro__[:-1]:
        try:
            encoder = type_encoders[base]
            return encoder(value)
        except KeyError:
            continue

    raise TypeError(f"Unsupported type: {type(value)!r}")


def default_deserializer(
    target_type: Any, value: Any, type_decoders: TypeDecodersSequence | None = None
) -> Any:  # pragma: no cover
    """Transform values non-natively supported by ``msgspec``

    Args:
        target_type: Encountered type
        value: Value to coerce
        type_decoders: Optional sequence of type decoders

    Returns:
        A ``msgspec``-supported type
    """

    from litestar.datastructures.state import ImmutableState

    try:
        if isinstance(value, target_type):
            return value
    except TypeError as exc:
        # we might get a TypeError here if target_type is a subscribed generic. For
        # performance reasons, we let this happen and only unwrap this when we're
        # certain this might be the case
        if (origin := get_origin_or_inner_type(target_type)) is not None:
            target_type = origin
            if isinstance(value, target_type):
                return value
        else:
            raise exc

    if type_decoders:
        for predicate, decoder in type_decoders:
            if predicate(target_type):
                return decoder(target_type, value)

    if issubclass(target_type, (Path, PurePath, ImmutableState, UUID)):
        return target_type(value)

    if issubclass(target_type, SecretBytes) and isinstance(value, (bytes, str)):
        return SecretBytes(value.encode("utf-8") if isinstance(value, str) else value)

    if issubclass(target_type, SecretString) and isinstance(value, str):
        return SecretString(value)

    raise TypeError(f"Unsupported type: {type(value)!r}")


_msgspec_json_encoder = msgspec.json.Encoder(enc_hook=default_serializer)
_msgspec_json_decoder = msgspec.json.Decoder(dec_hook=default_deserializer)
_msgspec_msgpack_encoder = msgspec.msgpack.Encoder(enc_hook=default_serializer)
_msgspec_msgpack_decoder = msgspec.msgpack.Decoder(dec_hook=default_deserializer)


def encode_json(value: Any, serializer: Callable[[Any], Any] | None = None) -> bytes:
    """Encode a value into JSON.

    Args:
        value: Value to encode
        serializer: Optional callable to support non-natively supported types.

    Returns:
        JSON as bytes

    Raises:
        SerializationException: If error encoding ``obj``.
    """
    try:
        return msgspec.json.encode(value, enc_hook=serializer) if serializer else _msgspec_json_encoder.encode(value)
    except (TypeError, msgspec.EncodeError) as msgspec_error:
        raise SerializationException(str(msgspec_error)) from msgspec_error


@overload
def decode_json(value: str | bytes, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, target_type: type[T], strict: bool = ...) -> T: ...


@overload
def decode_json(
    value: str | bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...


def decode_json(  # type: ignore[misc]
    value: str | bytes,
    target_type: type[T] | EmptyType = Empty,  # pyright: ignore
    type_decoders: TypeDecodersSequence | None = None,
    strict: bool = True,
) -> Any:
    """Decode a JSON string/bytes into an object.

    Args:
        value: Value to decode
        target_type: An optional type to decode the data into
        type_decoders: Optional sequence of type decoders
        strict: Whether type coercion rules should be strict. Setting to False enables
            a wider set of coercion rules from string to non-string types for all values

    Returns:
        An object

    Raises:
        SerializationException: If error decoding ``value``.
    """
    try:
        if target_type is Empty:
            return _msgspec_json_decoder.decode(value)
        return msgspec.json.decode(
            value,
            dec_hook=partial(
                default_deserializer,
                type_decoders=type_decoders,
            ),
            type=target_type,
            strict=strict,
        )
    except msgspec.DecodeError as msgspec_error:
        raise SerializationException(str(msgspec_error)) from msgspec_error


def encode_msgpack(value: Any, serializer: Callable[[Any], Any] | None = default_serializer) -> bytes:
    """Encode a value into MessagePack.

    Args:
        value: Value to encode
        serializer: Optional callable to support non-natively supported types

    Returns:
        MessagePack as bytes

    Raises:
        SerializationException: If error encoding ``obj``.
    """
    try:
        if serializer is None or serializer is default_serializer:
            return _msgspec_msgpack_encoder.encode(value)
        return msgspec.msgpack.encode(value, enc_hook=serializer)
    except (TypeError, msgspec.EncodeError) as msgspec_error:
        raise SerializationException(str(msgspec_error)) from msgspec_error


@overload
def decode_msgpack(value: bytes, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, target_type: type[T], strict: bool = ...) -> T: ...


@overload
def decode_msgpack(
    value: bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...


def decode_msgpack(  # type: ignore[misc]
    value: bytes,
    target_type: type[T] | EmptyType = Empty,  # pyright: ignore[reportInvalidTypeVarUse]
    type_decoders: TypeDecodersSequence | None = None,
    strict: bool = True,
) -> Any:
    """Decode a MessagePack string/bytes into an object.

    Args:
        value: Value to decode
        target_type: An optional type to decode the data into
        type_decoders: Optional sequence of type decoders
        strict: Whether type coercion rules should be strict. Setting to False enables
            a wider set of coercion rules from string to non-string types for all values

    Returns:
        An object

    Raises:
        SerializationException: If error decoding ``value``.
    """
    try:
        if target_type is Empty:
            return _msgspec_msgpack_decoder.decode(value)
        return msgspec.msgpack.decode(
            value,
            dec_hook=partial(default_deserializer, type_decoders=type_decoders),
            type=target_type,
            strict=strict,
        )
    except msgspec.DecodeError as msgspec_error:
        raise SerializationException(str(msgspec_error)) from msgspec_error


def get_serializer(type_encoders: TypeEncodersMap | None = None) -> Serializer:
    """Get the serializer for the given type encoders."""

    if type_encoders:
        return partial(default_serializer, type_encoders=type_encoders)

    return default_serializer