Source code for advanced_alchemy.types.guid

from __future__ import annotations

from base64 import b64decode
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID

from sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER as MSSQL_UNIQUEIDENTIFIER
from sqlalchemy.dialects.oracle import RAW as ORA_RAW
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.types import BINARY, CHAR, TypeDecorator
from typing_extensions import Buffer

if TYPE_CHECKING:
    from sqlalchemy.engine import Dialect

UUID_UTILS_INSTALLED = find_spec("uuid_utils")
NANOID_INSTALLED = find_spec("fastnanoid")


[docs]class GUID(TypeDecorator[UUID]): """Platform-independent GUID type. Uses PostgreSQL's UUID type (Postgres, DuckDB, Cockroach), MSSQL's UNIQUEIDENTIFIER type, Oracle's RAW(16) type, otherwise uses BINARY(16) or CHAR(32), storing as stringified hex values. Will accept stringified UUIDs as a hexstring or an actual UUID """ impl = BINARY(16) cache_ok = True @property def python_type(self) -> type[UUID]: return UUID
[docs] def __init__(self, *args: Any, binary: bool = True, **kwargs: Any) -> None: self.binary = binary
[docs] def load_dialect_impl(self, dialect: Dialect) -> Any: if dialect.name in {"postgresql", "duckdb", "cockroachdb"}: return dialect.type_descriptor(PG_UUID()) if dialect.name == "oracle": return dialect.type_descriptor(ORA_RAW(16)) if dialect.name == "mssql": return dialect.type_descriptor(MSSQL_UNIQUEIDENTIFIER()) if self.binary: return dialect.type_descriptor(BINARY(16)) return dialect.type_descriptor(CHAR(32))
[docs] def process_bind_param( self, value: bytes | str | UUID | None, dialect: Dialect, ) -> bytes | str | None: if value is None: return value if dialect.name in {"postgresql", "duckdb", "cockroachdb", "mssql"}: return str(value) value = self.to_uuid(value) if value is None: return value if dialect.name in {"oracle", "spanner+spanner"}: return value.bytes return value.bytes if self.binary else value.hex
[docs] def process_result_value( self, value: bytes | str | UUID | None, dialect: Dialect, ) -> UUID | None: if value is None: return value if value.__class__.__name__ == "UUID": return cast("UUID", value) if dialect.name == "spanner+spanner": return UUID(bytes=b64decode(cast("str | Buffer", value))) if self.binary: return UUID(bytes=cast("bytes", value)) return UUID(hex=cast("str", value))
@staticmethod def to_uuid(value: Any) -> UUID | None: if value.__class__.__name__ == "UUID" or value is None: return cast("UUID | None", value) try: value = UUID(hex=value) except (TypeError, ValueError): value = UUID(bytes=value) return cast("UUID | None", value)
[docs] def compare_values(self, x: Any, y: Any) -> bool: """Compare two values for equality.""" if x.__class__.__name__ == "UUID" and y.__class__.__name__ == "UUID": return cast("bool", x.bytes == y.bytes) return cast("bool", x == y)