Source code for advanced_alchemy.operations

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from sqlalchemy import ClauseElement, ColumnElement, UpdateBase
from sqlalchemy.ext.compiler import compiles

if TYPE_CHECKING:
    from typing import Literal

    from sqlalchemy.sql.compiler import StrSQLCompiler


[docs]class MergeClause(ClauseElement): __visit_name__ = "merge_clause"
[docs] def __init__(self, command: Literal["INSERT", "UPDATE", "DELETE"]) -> None: self.on_sets: dict[str, ColumnElement[Any]] = {} self.predicate: ColumnElement[bool] | None = None self.command = command
def values(self, **kwargs: ColumnElement[Any]) -> MergeClause: self.on_sets = kwargs return self def where(self, expr: ColumnElement[bool]) -> MergeClause: self.predicate = expr return self
@compiles(MergeClause) # type: ignore[no-untyped-call, misc] def visit_merge_clause(element: MergeClause, compiler: StrSQLCompiler, **kw: Any) -> str: case_predicate = "" if element.predicate is not None: case_predicate = f" AND {element.predicate._compiler_dispatch(compiler, **kw)!s}" # noqa: SLF001 if element.command == "INSERT": sets, sets_tos = list(element.on_sets), list(element.on_sets.values()) if kw.get("deterministic", False): sorted_on_sets = dict(sorted(element.on_sets.items(), key=lambda x: x[0])) sets, sets_tos = list(sorted_on_sets), list(sorted_on_sets.values()) merge_insert = ", ".join(sets) values = ", ".join(e._compiler_dispatch(compiler, **kw) for e in sets_tos) # noqa: SLF001 return f"WHEN NOT MATCHED{case_predicate} THEN {element.command} ({merge_insert}) VALUES ({values})" set_list = list(element.on_sets.items()) if kw.get("deterministic", False): set_list.sort(key=lambda x: x[0]) # merge update or merge delete merge_action = "" values = "" if element.on_sets: values = ", ".join( f"{name} = {column._compiler_dispatch(compiler, **kw)}" for name, column in set_list # noqa: SLF001 ) merge_action = f" SET {values}" return f"WHEN MATCHED{case_predicate} THEN {element.command}{merge_action}"
[docs]class Merge(UpdateBase): __visit_name__ = "merge" _bind = None inherit_cache = True
[docs] def __init__(self, into: Any, using: Any, on: Any) -> None: self.into = into self.using = using self.on = on self.clauses: list[ClauseElement] = []
def when_matched(self, operations: set[Literal["UPDATE", "DELETE", "INSERT"]]) -> MergeClause: for op in operations: self.clauses.append(clause := MergeClause(op)) return clause
@compiles(Merge) # type: ignore[no-untyped-call, misc] def visit_merge(element: Merge, compiler: StrSQLCompiler, **kw: Any) -> str: clauses = " ".join(clause._compiler_dispatch(compiler, **kw) for clause in element.clauses) # noqa: SLF001 sql_text = f"MERGE INTO {element.into} USING {element.using} ON {element.on}" if clauses: sql_text += f" {clauses}" return sql_text