# Source code for advanced_alchemy.alembic.commands

import inspect  # Added import
import sys
from typing import TYPE_CHECKING, Any, Optional, TextIO, Union

from alembic.config import Config as _AlembicCommandConfig
from alembic.ddl.impl import DefaultImpl

from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from alembic import command as migration_command

if TYPE_CHECKING:
    import os
    from argparse import Namespace
    from collections.abc import Mapping
    from pathlib import Path

    from alembic.runtime.environment import ProcessRevisionDirectiveFn
    from alembic.script.base import Script
    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine

    from advanced_alchemy.config.sync import SQLAlchemySyncConfig


class AlembicSpannerImpl(DefaultImpl):
    """Alembic ``DefaultImpl`` subclass bound to the ``spanner+spanner`` dialect.

    No behaviour is overridden; only the dialect name is declared.
    NOTE(review): presumably the subclass exists so Alembic can resolve an
    implementation for this dialect name — confirm against Alembic's impl registry.
    """

    # Dialect name this implementation is associated with.
    __dialect__ = "spanner+spanner"
class AlembicDuckDBImpl(DefaultImpl):
    """Alembic ``DefaultImpl`` subclass bound to the ``duckdb`` dialect.

    No behaviour is overridden; only the dialect name is declared.
    NOTE(review): presumably the subclass exists so Alembic can resolve an
    implementation for this dialect name — confirm against Alembic's impl registry.
    """

    # Dialect name this implementation is associated with.
    __dialect__ = "duckdb"
class AlembicCommandConfig(_AlembicCommandConfig):
    """Alembic ``Config`` that additionally carries the engine and migration options.

    Besides the standard Alembic configuration, instances expose the engine,
    its rendered URL and version-table / rendering settings as plain attributes.
    """

    def __init__(
        self,
        engine: "Union[Engine, AsyncEngine]",
        version_table_name: str,
        bind_key: "Optional[str]" = None,
        file_: "Union[str, os.PathLike[str], None]" = None,
        toml_file: "Union[str, os.PathLike[str], None]" = None,
        ini_section: str = "alembic",
        output_buffer: "Optional[TextIO]" = None,
        stdout: "TextIO" = sys.stdout,
        cmd_opts: "Optional[Namespace]" = None,
        config_args: "Optional[Mapping[str, Any]]" = None,
        attributes: "Optional[dict[str, Any]]" = None,
        template_directory: "Optional[Path]" = None,
        version_table_schema: "Optional[str]" = None,
        render_as_batch: bool = True,
        compare_type: bool = False,
        user_module_prefix: "Optional[str]" = "sa.",
    ) -> None:
        """Store migration settings, then initialise the Alembic ``Config``.

        Args:
            engine: SQLAlchemy engine (sync or async) migrations run against.
            version_table_name: Name of the Alembic version table.
            bind_key: Optional bind key for the metadata.
            file_: Path to an alembic ``.ini`` file.
            toml_file: Path to an alembic ``pyproject.toml``; requires an Alembic
                whose ``Config.__init__`` accepts ``toml_file``.
            ini_section: Section name inside the ini file.
            output_buffer: Buffer receiving Alembic command output.
            stdout: Stream used as standard output.
            cmd_opts: Parsed command-line options.
            config_args: Extra configuration values; copied into a new dict.
            attributes: Extra attributes for the configuration.
            template_directory: Directory holding Alembic templates.
            version_table_schema: Schema of the version table.
            render_as_batch: Whether migrations render in batch mode.
            compare_type: Whether autogenerate compares column types.
            user_module_prefix: Prefix used for user-module types.

        Raises:
            ImproperConfigurationError: ``toml_file`` was given but the installed
                Alembic ``Config`` does not accept that parameter.
        """
        # Plain settings, kept on the instance for later use.
        self.template_directory = template_directory
        self.bind_key = bind_key
        self.version_table_name = version_table_name
        # The version table gets a primary key on every dialect except Spanner.
        self.version_table_pk = engine.dialect.name != "spanner+spanner"
        self.version_table_schema = version_table_schema
        self.render_as_batch = render_as_batch
        self.user_module_prefix = user_module_prefix
        self.compare_type = compare_type
        self.engine = engine
        # Full URL including the password.
        self.db_url = engine.url.render_as_string(hide_password=False)

        forwarded: dict[str, Any] = dict(
            file_=file_,
            ini_section=ini_section,
            output_buffer=output_buffer,
            stdout=stdout,
            cmd_opts=cmd_opts,
            # Always hand Alembic a fresh, mutable dict.
            config_args=dict(config_args) if config_args is not None else {},
            attributes=attributes,
        )

        # Older Alembic releases have no ``toml_file`` parameter; detect it at runtime.
        accepts_toml_file = "toml_file" in inspect.signature(super().__init__).parameters
        if accepts_toml_file:
            forwarded["toml_file"] = toml_file
        elif toml_file is not None:
            msg = (
                "The 'toml_file' parameter is not supported by your current Alembic version. "
                "Please upgrade Alembic to 1.16.0 or later to use this feature, "
                "or remove the 'toml_file' argument from AlembicCommandConfig."
            )
            raise ImproperConfigurationError(msg)

        super().__init__(**forwarded)

    def get_template_directory(self) -> str:
        """Return the directory Alembic's ``init`` / ``list_templates`` read templates from.

        Uses ``template_directory`` when one was configured, otherwise defers to
        Alembic's default.
        """
        if self.template_directory is None:
            return super().get_template_directory()
        return str(self.template_directory)
class AlembicCommands:
    """Wrappers around ``alembic.command`` bound to one SQLAlchemy configuration.

    Every public method forwards to the same-named function in
    ``alembic.command``, passing the :class:`AlembicCommandConfig` built from
    ``sqlalchemy_config`` as ``config``.
    """

    def __init__(self, sqlalchemy_config: "Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]") -> None:
        """Initialize the AlembicCommands.

        Args:
            sqlalchemy_config: The SQLAlchemy configuration. Its ``alembic_config``
                supplies the script location, version table name and the optional
                toml / ini / template paths.
        """
        self.sqlalchemy_config = sqlalchemy_config
        # ``_get_alembic_command_config`` also stores the result on ``self.config``;
        # the explicit assignment keeps the attribute visible here.
        self.config = self._get_alembic_command_config()

    def upgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: "Optional[str]" = None,
    ) -> None:
        """Upgrade the database to a specified revision.

        Args:
            revision: The target revision to upgrade to.
            sql: If True, generate SQL script instead of applying changes.
            tag: An optional tag to apply to the migration.
        """
        return migration_command.upgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def downgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: "Optional[str]" = None,
    ) -> None:
        """Downgrade the database to a specified revision.

        NOTE(review): the default target ``"head"`` is kept for backward
        compatibility; callers normally pass an explicit target (e.g. ``"-1"``
        or ``"base"``).

        Args:
            revision: The target revision to downgrade to.
            sql: If True, generate SQL script instead of applying changes.
            tag: An optional tag to apply to the migration.
        """
        return migration_command.downgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def check(self) -> None:
        """Check whether there are pending upgrade operations to apply."""
        return migration_command.check(config=self.config)

    def current(self, verbose: bool = False) -> None:
        """Display the current revision of the database.

        Args:
            verbose: If True, display detailed information.
        """
        # Keyword ``config=`` for consistency with every other wrapper.
        return migration_command.current(config=self.config, verbose=verbose)

    def edit(self, revision: str) -> None:
        """Edit the revision script using the system editor.

        Args:
            revision: The revision identifier to edit.
        """
        return migration_command.edit(config=self.config, rev=revision)

    def ensure_version(self, sql: bool = False) -> None:
        """Ensure the alembic version table exists.

        Args:
            sql: If True, generate SQL script instead of applying changes.
        """
        return migration_command.ensure_version(config=self.config, sql=sql)

    def heads(self, verbose: bool = False, resolve_dependencies: bool = False) -> None:
        """Show current available heads in the script directory.

        Args:
            verbose: If True, display detailed information.
            resolve_dependencies: If True, resolve dependencies between heads.
        """
        return migration_command.heads(config=self.config, verbose=verbose, resolve_dependencies=resolve_dependencies)

    def history(
        self,
        rev_range: "Optional[str]" = None,
        verbose: bool = False,
        indicate_current: bool = False,
    ) -> None:
        """List changeset scripts in chronological order.

        Args:
            rev_range: The revision range to display.
            verbose: If True, display detailed information.
            indicate_current: If True, indicate the current revision.
        """
        return migration_command.history(
            config=self.config,
            rev_range=rev_range,
            verbose=verbose,
            indicate_current=indicate_current,
        )

    def merge(
        self,
        revisions: str,
        message: "Optional[str]" = None,
        branch_label: "Optional[str]" = None,
        rev_id: "Optional[str]" = None,
    ) -> "Union[Script, None]":
        """Merge two revisions together.

        Args:
            revisions: The revisions to merge.
            message: The commit message for the merge.
            branch_label: The branch label for the merge.
            rev_id: The revision ID for the merge.

        Returns:
            The resulting script from the merge, or None.
        """
        return migration_command.merge(
            config=self.config,
            revisions=revisions,
            message=message,
            branch_label=branch_label,
            rev_id=rev_id,
        )

    def revision(
        self,
        message: "Optional[str]" = None,
        autogenerate: bool = False,
        sql: bool = False,
        head: str = "head",
        splice: bool = False,
        branch_label: "Optional[str]" = None,
        version_path: "Optional[str]" = None,
        rev_id: "Optional[str]" = None,
        depends_on: "Optional[str]" = None,
        process_revision_directives: "Optional[ProcessRevisionDirectiveFn]" = None,
    ) -> "Union[Script, list[Optional[Script]], None]":
        """Create a new revision file.

        Args:
            message: The commit message for the revision.
            autogenerate: If True, autogenerate the revision script.
            sql: If True, generate SQL script instead of applying changes.
            head: The head revision to base the new revision on.
            splice: If True, create a splice revision.
            branch_label: The branch label for the revision.
            version_path: The path for the version file.
            rev_id: The revision ID for the new revision.
            depends_on: The revisions this revision depends on.
            process_revision_directives: A function to process revision directives.

        Returns:
            The resulting script(s) from the revision.
        """
        return migration_command.revision(
            config=self.config,
            message=message,
            autogenerate=autogenerate,
            sql=sql,
            head=head,
            splice=splice,
            branch_label=branch_label,
            version_path=version_path,
            rev_id=rev_id,
            depends_on=depends_on,
            process_revision_directives=process_revision_directives,
        )

    def show(
        self,
        rev: Any,
    ) -> None:
        """Show the revision(s) denoted by the given symbol.

        Args:
            rev: The revision symbol to display.
        """
        return migration_command.show(config=self.config, rev=rev)

    def init(
        self,
        directory: str,
        package: bool = False,
        multidb: bool = False,
    ) -> None:
        """Initialize a new scripts directory.

        The template is ``"asyncio"`` for an async SQLAlchemy configuration and
        ``"sync"`` otherwise.

        Args:
            directory: The directory to initialize.
            package: If True, create a package.
            multidb: Multi-database layouts are not supported; must be False.

        Raises:
            NotImplementedError: If ``multidb`` is True.
        """
        # Reject unsupported multi-db setups up front instead of first building
        # a ``-multidb`` template name that is never used.
        if multidb:
            msg = "Multi database Alembic configurations are not currently supported."
            raise NotImplementedError(msg)
        template = "asyncio" if isinstance(self.sqlalchemy_config, SQLAlchemyAsyncConfig) else "sync"
        return migration_command.init(
            config=self.config,
            directory=directory,
            template=template,
            package=package,
        )

    def list_templates(self) -> None:
        """List the templates available for alembic initialization."""
        return migration_command.list_templates(config=self.config)

    def stamp(
        self,
        revision: str,
        sql: bool = False,
        tag: "Optional[str]" = None,
        purge: bool = False,
    ) -> None:
        """Stamp the revision table with the given revision.

        Args:
            revision: The revision to stamp.
            sql: If True, generate SQL script instead of applying changes.
            tag: An optional tag to apply to the migration.
            purge: If True, purge the revision history.
        """
        return migration_command.stamp(config=self.config, revision=revision, sql=sql, tag=tag, purge=purge)

    def _get_alembic_command_config(self) -> "AlembicCommandConfig":
        """Build the Alembic command configuration and store it on ``self.config``.

        Returns:
            The configuration used by every command wrapper.
        """
        alembic_config = self.sqlalchemy_config.alembic_config
        kwargs: dict[str, Any] = {}
        # Forward optional paths only when set, so AlembicCommandConfig's own
        # defaults apply otherwise.
        if alembic_config.toml_file:
            kwargs["toml_file"] = alembic_config.toml_file
        if alembic_config.script_config:
            kwargs["file_"] = alembic_config.script_config
        if alembic_config.template_path:
            kwargs["template_directory"] = alembic_config.template_path
        kwargs["engine"] = self.sqlalchemy_config.get_engine()
        kwargs["version_table_name"] = alembic_config.version_table_name
        self.config = AlembicCommandConfig(**kwargs)
        self.config.set_main_option("script_location", alembic_config.script_location)
        return self.config