Source code for advanced_alchemy.alembic.commands

from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Mapping, TextIO

from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from alembic import command as migration_command
from alembic.config import Config as _AlembicCommandConfig
from alembic.ddl.impl import DefaultImpl

if TYPE_CHECKING:
    import os
    from argparse import Namespace
    from pathlib import Path

    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine

    from advanced_alchemy.config.sync import SQLAlchemySyncConfig
    from alembic.runtime.environment import ProcessRevisionDirectiveFn
    from alembic.script.base import Script


class AlembicSpannerImpl(DefaultImpl):
    """Alembic DDL implementation for Spanner.

    Adds no behaviour of its own: everything is inherited from
    ``DefaultImpl``. The class exists only to declare the dialect name it
    applies to via ``__dialect__``.
    """

    # Dialect name this implementation is declared for.
    __dialect__ = "spanner+spanner"
class AlembicDuckDBImpl(DefaultImpl):
    """Alembic DDL implementation for DuckDB.

    Adds no behaviour of its own: everything is inherited from
    ``DefaultImpl``. The class exists only to declare the dialect name it
    applies to via ``__dialect__``.
    """

    # Dialect name this implementation is declared for.
    __dialect__ = "duckdb"
class AlembicCommandConfig(_AlembicCommandConfig):
    """Alembic ``Config`` that also carries the engine and migration settings.

    On top of the regular Alembic configuration, instances expose the
    SQLAlchemy engine, its rendered URL, the version-table settings and the
    autogenerate/rendering flags as plain attributes.
    """

    def __init__(
        self,
        engine: Engine | AsyncEngine,
        version_table_name: str,
        bind_key: str | None = None,
        file_: str | os.PathLike[str] | None = None,
        ini_section: str = "alembic",
        output_buffer: TextIO | None = None,
        stdout: TextIO = sys.stdout,
        cmd_opts: Namespace | None = None,
        config_args: Mapping[str, Any] | None = None,
        attributes: dict[str, Any] | None = None,
        template_directory: Path | None = None,
        version_table_schema: str | None = None,
        render_as_batch: bool = True,
        compare_type: bool = False,
        user_module_prefix: str | None = "sa.",
    ) -> None:
        """Initialize the AlembicCommandConfig.

        Args:
            engine (sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine):
                The SQLAlchemy engine instance.
            version_table_name (str): The name of the version table.
            bind_key (str | None): The bind key for the metadata.
            file_ (str | os.PathLike[str] | None): The file path for the alembic
                configuration.
            ini_section (str): The ini section name.
            output_buffer (typing.TextIO | None): The output buffer for alembic
                commands.
            stdout (typing.TextIO): The standard output stream.
            cmd_opts (argparse.Namespace | None): Command line options.
            config_args (typing.Mapping[str, typing.Any] | None): Additional
                configuration arguments. ``None`` is replaced by an empty dict.
            attributes (dict[str, typing.Any] | None): Additional attributes for
                the configuration.
            template_directory (pathlib.Path | None): The directory for alembic
                templates.
            version_table_schema (str | None): The schema for the version table.
            render_as_batch (bool): Whether to render migrations as batch.
            compare_type (bool): Whether to compare types during migrations.
            user_module_prefix (str | None): The prefix for user modules.
        """
        self.template_directory = template_directory
        self.bind_key = bind_key
        self.version_table_name = version_table_name
        # The flag is turned off only when the engine reports the
        # "spanner+spanner" dialect name; every other dialect gets True.
        self.version_table_pk = engine.dialect.name != "spanner+spanner"
        self.version_table_schema = version_table_schema
        self.render_as_batch = render_as_batch
        self.user_module_prefix = user_module_prefix
        self.compare_type = compare_type
        self.engine = engine
        # Rendered with the password visible (hide_password=False), so this
        # value should not be written to logs.
        self.db_url = engine.url.render_as_string(hide_password=False)
        if config_args is None:
            config_args = {}
        # Pass everything by keyword so the call does not depend on the
        # positional order of the parent constructor's parameters.
        # NOTE(review): newer Alembic releases appear to insert a ``toml_file``
        # parameter right after ``file_``; a positional call would then shift
        # every following argument. Confirm against the installed version.
        super().__init__(
            file_=file_,
            ini_section=ini_section,
            output_buffer=output_buffer,
            stdout=stdout,
            cmd_opts=cmd_opts,
            config_args=config_args,
            attributes=attributes,
        )

    def get_template_directory(self) -> str:
        """Return the directory where Alembic setup templates are found.

        This method is used by the alembic ``init`` and ``list_templates``
        commands.

        Returns:
            str: ``template_directory`` as a string when one was supplied to
            the constructor, otherwise whatever the parent class returns.
        """
        if self.template_directory is not None:
            return str(self.template_directory)
        return super().get_template_directory()
class AlembicCommands:
    """Thin wrapper that runs ``alembic.command`` functions with a prepared config.

    The Alembic configuration is built once from the supplied SQLAlchemy
    configuration and stored on ``self.config``; every public method forwards
    its arguments to the Alembic command of the same name.
    """

    def __init__(self, sqlalchemy_config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None:
        """Initialize the AlembicCommands.

        Args:
            sqlalchemy_config (SQLAlchemyAsyncConfig | SQLAlchemySyncConfig):
                The SQLAlchemy configuration.
        """
        self.sqlalchemy_config = sqlalchemy_config
        self.config = self._get_alembic_command_config()

    def upgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: str | None = None,
    ) -> None:
        """Upgrade the database to a specified revision.

        Args:
            revision (str): The target revision to upgrade to.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
        """
        return migration_command.upgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def downgrade(
        self,
        revision: str = "head",
        sql: bool = False,
        tag: str | None = None,
    ) -> None:
        """Downgrade the database to a specified revision.

        NOTE(review): the default target is ``"head"``, the same default as
        ``upgrade``; callers will normally want to pass an explicit revision.
        Confirm the default is intentional.

        Args:
            revision (str): The target revision to downgrade to.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
        """
        return migration_command.downgrade(config=self.config, revision=revision, tag=tag, sql=sql)

    def check(self) -> None:
        """Check for pending upgrade operations.

        This method checks if there are any pending upgrade operations that
        need to be applied to the database.
        """
        return migration_command.check(config=self.config)

    def current(self, verbose: bool = False) -> None:
        """Display the current revision of the database.

        Args:
            verbose (bool): If True, display detailed information.
        """
        return migration_command.current(self.config, verbose=verbose)

    def edit(self, revision: str) -> None:
        """Edit the revision script using the system editor.

        Args:
            revision (str): The revision identifier to edit.
        """
        return migration_command.edit(config=self.config, rev=revision)

    def ensure_version(self, sql: bool = False) -> None:
        """Ensure the alembic version table exists.

        Args:
            sql (bool): If True, generate SQL script instead of applying changes.
        """
        return migration_command.ensure_version(config=self.config, sql=sql)

    def heads(self, verbose: bool = False, resolve_dependencies: bool = False) -> None:
        """Show current available heads in the script directory.

        Args:
            verbose (bool): If True, display detailed information.
            resolve_dependencies (bool): If True, resolve dependencies between heads.
        """
        return migration_command.heads(
            config=self.config,
            verbose=verbose,
            resolve_dependencies=resolve_dependencies,
        )

    def history(
        self,
        rev_range: str | None = None,
        verbose: bool = False,
        indicate_current: bool = False,
    ) -> None:
        """List changeset scripts in chronological order.

        Args:
            rev_range (str | None): The revision range to display.
            verbose (bool): If True, display detailed information.
            indicate_current (bool): If True, indicate the current revision.
        """
        return migration_command.history(
            config=self.config,
            rev_range=rev_range,
            verbose=verbose,
            indicate_current=indicate_current,
        )

    def merge(
        self,
        revisions: str,
        message: str | None = None,
        branch_label: str | None = None,
        rev_id: str | None = None,
    ) -> Script | None:
        """Merge two revisions together.

        Args:
            revisions (str): The revisions to merge.
            message (str | None): The commit message for the merge.
            branch_label (str | None): The branch label for the merge.
            rev_id (str | None): The revision ID for the merge.

        Returns:
            Script | None: The resulting script from the merge.
        """
        return migration_command.merge(
            config=self.config,
            revisions=revisions,
            message=message,
            branch_label=branch_label,
            rev_id=rev_id,
        )

    def revision(
        self,
        message: str | None = None,
        autogenerate: bool = False,
        sql: bool = False,
        head: str = "head",
        splice: bool = False,
        branch_label: str | None = None,
        version_path: str | None = None,
        rev_id: str | None = None,
        depends_on: str | None = None,
        process_revision_directives: ProcessRevisionDirectiveFn | None = None,
    ) -> Script | list[Script | None] | None:
        """Create a new revision file.

        Args:
            message (str | None): The commit message for the revision.
            autogenerate (bool): If True, autogenerate the revision script.
            sql (bool): If True, generate SQL script instead of applying changes.
            head (str): The head revision to base the new revision on.
            splice (bool): If True, create a splice revision.
            branch_label (str | None): The branch label for the revision.
            version_path (str | None): The path for the version file.
            rev_id (str | None): The revision ID for the new revision.
            depends_on (str | None): The revisions this revision depends on.
            process_revision_directives (ProcessRevisionDirectiveFn | None):
                A function to process revision directives.

        Returns:
            Script | List[Script | None] | None: The resulting script(s) from
            the revision.
        """
        return migration_command.revision(
            config=self.config,
            message=message,
            autogenerate=autogenerate,
            sql=sql,
            head=head,
            splice=splice,
            branch_label=branch_label,
            version_path=version_path,
            rev_id=rev_id,
            depends_on=depends_on,
            process_revision_directives=process_revision_directives,
        )

    def show(
        self,
        rev: Any,
    ) -> None:
        """Show the revision(s) denoted by the given symbol.

        Args:
            rev (Any): The revision symbol to display.
        """
        return migration_command.show(config=self.config, rev=rev)

    def init(
        self,
        directory: str,
        package: bool = False,
        multidb: bool = False,
    ) -> None:
        """Initialize a new scripts directory.

        The template is ``"asyncio"`` when the stored SQLAlchemy configuration
        is a ``SQLAlchemyAsyncConfig`` and ``"sync"`` otherwise.

        Args:
            directory (str): The directory to initialize.
            package (bool): If True, create a package.
            multidb (bool): If True, initialize for multiple databases.
                Not supported yet; always raises.

        Raises:
            NotImplementedError: If ``multidb`` is True.
        """
        # Reject the unsupported option up front, before any template
        # selection, so nothing is computed that would be thrown away.
        if multidb:
            msg = "Multi database Alembic configurations are not currently supported."
            raise NotImplementedError(msg)
        template = "asyncio" if isinstance(self.sqlalchemy_config, SQLAlchemyAsyncConfig) else "sync"
        return migration_command.init(
            config=self.config,
            directory=directory,
            template=template,
            package=package,
        )

    def list_templates(self) -> None:
        """List available templates.

        This method lists all available templates for alembic initialization.
        """
        return migration_command.list_templates(config=self.config)

    def stamp(
        self,
        revision: str,
        sql: bool = False,
        tag: str | None = None,
        purge: bool = False,
    ) -> None:
        """Stamp the revision table with the given revision.

        Args:
            revision (str): The revision to stamp.
            sql (bool): If True, generate SQL script instead of applying changes.
            tag (str | None): An optional tag to apply to the migration.
            purge (bool): If True, purge the revision history.
        """
        return migration_command.stamp(config=self.config, revision=revision, sql=sql, tag=tag, purge=purge)

    def _get_alembic_command_config(self) -> AlembicCommandConfig:
        """Get the Alembic command configuration.

        Builds an ``AlembicCommandConfig`` from the stored SQLAlchemy
        configuration, sets its ``script_location`` main option, and stores it
        on ``self.config`` as well as returning it.

        Returns:
            AlembicCommandConfig: The configuration for Alembic commands.
        """
        alembic_config = self.sqlalchemy_config.alembic_config
        kwargs: dict[str, Any] = {}
        # Optional settings are forwarded only when set, so the
        # AlembicCommandConfig defaults apply otherwise.
        if alembic_config.script_config:
            kwargs["file_"] = alembic_config.script_config
        if alembic_config.template_path:
            kwargs["template_directory"] = alembic_config.template_path
        kwargs["engine"] = self.sqlalchemy_config.get_engine()
        kwargs["version_table_name"] = alembic_config.version_table_name
        self.config = AlembicCommandConfig(**kwargs)
        self.config.set_main_option("script_location", alembic_config.script_location)
        return self.config