# Source code for advanced_alchemy.extensions.litestar.cli

from __future__ import annotations

from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING, Sequence, cast

from anyio import run
from click import Path as ClickPath
from click import argument, group, option
from litestar.cli._utils import LitestarGroup, console

if TYPE_CHECKING:
    from litestar import Litestar

    from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
    from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
    from advanced_alchemy.extensions.litestar.plugins.init.config.sync import SQLAlchemySyncConfig
    from alembic.migration import MigrationContext
    from alembic.operations.ops import MigrationScript, UpgradeOps


def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin:
    """Return the SQLAlchemy init plugin registered on a Litestar application.

    The plugin is looked up in ``app.plugins`` by the ``SQLAlchemyInitPlugin``
    type. A ``KeyError`` raised by that lookup is treated as "plugin not
    registered".

    Args:
        app: The Litestar application whose plugin registry is searched.

    Returns:
        The plugin returned by ``app.plugins.get(SQLAlchemyInitPlugin)``.

    Raises:
        ImproperConfigurationError: If the lookup raised ``KeyError``.
    """
    # Imported lazily, mirroring the other commands in this module.
    from advanced_alchemy.exceptions import ImproperConfigurationError
    from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin

    with suppress(KeyError):
        migration_plugin = app.plugins.get(SQLAlchemyInitPlugin)
        return migration_plugin

    # Only reached when the lookup above raised KeyError.
    msg = "Failed to initialize database migrations. The required plugin (SQLAlchemyPlugin or SQLAlchemyInitPlugin) is missing."
    raise ImproperConfigurationError(msg)
@group(cls=LitestarGroup, name="database")
def database_group() -> None:
    """Manage SQLAlchemy database components."""
    # Keep the docstring above unchanged: no ``help=`` is passed to ``@group``,
    # so Click uses the docstring as the help text of the ``database`` group.


@database_group.command(
    name="show-current-revision",
    help="Shows the current revision for the database.",
)
@option("--verbose", type=bool, help="Enable verbose output.", default=False, is_flag=True)
def show_database_revision(app: Litestar, verbose: bool) -> None:
    """Show the current database revision.

    Calls ``AlembicCommands.current`` for the first SQLAlchemy config of the
    registered migration plugin.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        verbose: Forwarded to ``AlembicCommands.current``.
    """
    from advanced_alchemy.alembic.commands import AlembicCommands

    console.rule("[yellow]Listing current revision[/]", align="left")
    config = get_database_migration_plugin(app).config
    # Only the first configured database is inspected.
    sqlalchemy_config = config[0]
    alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
    alembic_commands.current(verbose=verbose)


@database_group.command(
    name="downgrade",
    help="Downgrade database to a specific revision.",
)
@option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@option(
    "--tag",
    help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
    type=str,
    default=None,
)
@option(
    "--no-prompt",
    help="Do not prompt for confirmation before downgrading.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
@argument(
    "revision",
    type=str,
    default="-1",
)
def downgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None:
    """Downgrade the database to the given revision.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        revision: Target revision; the CLI default is ``"-1"``.
        sql: Forwarded to ``AlembicCommands.downgrade``.
        tag: Forwarded to ``AlembicCommands.downgrade``.
        no_prompt: Skip the interactive confirmation when true.
    """
    from rich.prompt import Confirm

    from advanced_alchemy.alembic.commands import AlembicCommands

    console.rule("[yellow]Starting database downgrade process[/]", align="left")
    input_confirmed = no_prompt or Confirm.ask(
        f"Are you sure you want to downgrade the database to the `{revision}` revision?",
    )
    if input_confirmed:
        config = get_database_migration_plugin(app).config
        # Only the first configured database is migrated.
        sqlalchemy_config = config[0]
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)


@database_group.command(
    name="upgrade",
    help="Upgrade database to a specific revision.",
)
@option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@option(
    "--tag",
    help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
    type=str,
    default=None,
)
@option(
    "--no-prompt",
    help="Do not prompt for confirmation before upgrading.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
@argument(
    "revision",
    type=str,
    default="head",
)
def upgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None:
    """Upgrade the database to the given revision.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        revision: Target revision; the CLI default is ``"head"``.
        sql: Forwarded to ``AlembicCommands.upgrade``.
        tag: Forwarded to ``AlembicCommands.upgrade``.
        no_prompt: Skip the interactive confirmation when true.
    """
    from rich.prompt import Confirm

    from advanced_alchemy.alembic.commands import AlembicCommands

    console.rule("[yellow]Starting database upgrade process[/]", align="left")
    input_confirmed = no_prompt or Confirm.ask(
        f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]",
    )
    if input_confirmed:
        config = get_database_migration_plugin(app).config
        # Only the first configured database is migrated.
        sqlalchemy_config = config[0]
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)


@database_group.command(
    name="init",
    help="Initialize migrations for the project.",
)
@argument("directory", default=None)
@option("--multidb", is_flag=True, default=False, help="Support multiple databases")
@option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
@option(
    "--no-prompt",
    help="Do not prompt for confirmation before initializing.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
def init_alembic(app: Litestar, directory: str | None, multidb: bool, package: bool, no_prompt: bool) -> None:
    """Initialize the migration environment for every configured database.

    ``AlembicCommands.init`` is called once per config of the migration plugin.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        directory: Target directory. When ``None``, each config's own
            ``alembic_config.script_location`` is used instead.
        multidb: Forwarded to ``AlembicCommands.init``.
        package: Forwarded to ``AlembicCommands.init``.
        no_prompt: Skip the interactive confirmation when true.
    """
    from rich.prompt import Confirm

    from advanced_alchemy.alembic.commands import AlembicCommands

    console.rule("[yellow]Initializing database migrations.", align="left")
    plugin = get_database_migration_plugin(app)
    # NOTE(review): when no directory argument is given this prompt renders
    # `None`; the message text is left unchanged here.
    input_confirmed = no_prompt or Confirm.ask(
        f"[bold]Are you sure you want initialize the project in `{directory}`?[/]",
    )
    if input_confirmed:
        for config in plugin.config:
            # Resolve the target per config. Rebinding ``directory`` itself
            # would make every config after the first reuse the first
            # config's script location.
            target_directory = config.alembic_config.script_location if directory is None else directory
            alembic_commands = AlembicCommands(sqlalchemy_config=config)
            alembic_commands.init(directory=target_directory, multidb=multidb, package=package)


@database_group.command(
    name="make-migrations",
    help="Create a new migration revision.",
)
@option("-m", "--message", default=None, help="Revision message")
@option("--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes")
@option("--sql", is_flag=True, default=False, help="Export to `.sql` instead of writing to the database.")
@option("--head", default="head", help="Specify head revision to use as base for new revision.")
@option("--splice", is_flag=True, default=False, help='Allow a non-head revision as the "head" to splice onto')
@option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
@option("--version-path", default=None, help="Specify specific path from config for version file")
@option("--rev-id", default=None, help="Specify a ID to use for revision.")
@option(
    "--no-prompt",
    help="Do not prompt for a migration message.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
def create_revision(
    app: Litestar,
    message: str | None,
    autogenerate: bool,
    sql: bool,
    head: str,
    splice: bool,
    branch_label: str | None,
    version_path: str | None,
    rev_id: str | None,
    no_prompt: bool,
) -> None:
    """Create a new database revision.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        message: Revision message. When ``None`` it becomes ``"autogenerated"``
            if ``no_prompt`` is set, otherwise the user is prompted for one.
        autogenerate: Forwarded to ``AlembicCommands.revision``; also enables
            skipping of empty autogenerated revisions.
        sql: Forwarded to ``AlembicCommands.revision``.
        head: Forwarded to ``AlembicCommands.revision``.
        splice: Forwarded to ``AlembicCommands.revision``.
        branch_label: Forwarded to ``AlembicCommands.revision``.
        version_path: Forwarded to ``AlembicCommands.revision``.
        rev_id: Forwarded to ``AlembicCommands.revision``.
        no_prompt: Do not prompt for a message when none was supplied.
    """
    from rich.prompt import Prompt

    from advanced_alchemy.alembic.commands import AlembicCommands

    def process_revision_directives(
        context: MigrationContext,  # noqa: ARG001
        revision: tuple[str],  # noqa: ARG001
        directives: list[MigrationScript],
    ) -> None:
        """Drop the pending revision when autogenerate produced no upgrade operations."""
        if autogenerate and cast("UpgradeOps", directives[0].upgrade_ops).is_empty():
            # Generate a revision file only if changes to the schema are detected
            console.rule(
                "[magenta]The generation of a migration file is being skipped because it would result in an empty file.",
                style="magenta",
                align="left",
            )
            console.rule(
                "[magenta]More information can be found here. https://alembic.sqlalchemy.org/en/latest/autogenerate.html#what-does-autogenerate-detect-and-what-does-it-not-detect",
                style="magenta",
                align="left",
            )
            console.rule(
                "[magenta]If you intend to create an empty migration file, use the --no-autogenerate option.",
                style="magenta",
                align="left",
            )
            # Emptying the list in place leaves nothing for the caller to write out.
            directives.clear()

    # NOTE(review): the banner says "upgrade" although this command creates a
    # revision; it looks copy-pasted and is left unchanged here.
    console.rule("[yellow]Starting database upgrade process[/]", align="left")
    if message is None:
        message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")

    config = get_database_migration_plugin(app).config
    # Only the first configured database is used.
    sqlalchemy_config = config[0]
    alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
    alembic_commands.revision(
        message=message,
        autogenerate=autogenerate,
        sql=sql,
        head=head,
        splice=splice,
        branch_label=branch_label,
        version_path=version_path,
        rev_id=rev_id,
        process_revision_directives=process_revision_directives,  # type: ignore[arg-type]
    )


@database_group.command(
    name="merge-migrations",
    help="Merge multiple revisions into a single new revision.",
)
@option("--revisions", default="head", help="Specify head revision to use as base for new revision.")
@option("-m", "--message", default=None, help="Revision message")
@option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
@option("--rev-id", default=None, help="Specify a ID to use for revision.")
@option(
    "--no-prompt",
    help="Do not prompt for a migration message.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
def merge_revisions(
    app: Litestar,
    revisions: str,
    message: str | None,
    branch_label: str | None,
    rev_id: str | None,
    no_prompt: bool,
) -> None:
    """Merge multiple revisions into a single new revision.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        revisions: Forwarded to ``AlembicCommands.merge``; CLI default ``"head"``.
        message: Revision message. When ``None`` it becomes ``"autogenerated"``
            if ``no_prompt`` is set, otherwise the user is prompted for one.
        branch_label: Forwarded to ``AlembicCommands.merge``.
        rev_id: Forwarded to ``AlembicCommands.merge``.
        no_prompt: Do not prompt for a message when none was supplied.
    """
    from rich.prompt import Prompt

    from advanced_alchemy.alembic.commands import AlembicCommands

    # NOTE(review): the banner says "upgrade" although this command merges
    # revisions; it looks copy-pasted and is left unchanged here.
    console.rule("[yellow]Starting database upgrade process[/]", align="left")
    if message is None:
        message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")

    config = get_database_migration_plugin(app).config
    # Only the first configured database is used.
    sqlalchemy_config = config[0]
    alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
    alembic_commands.merge(message=message, revisions=revisions, branch_label=branch_label, rev_id=rev_id)


@database_group.command(
    name="stamp-migration",
    help="Mark (Stamp) a specific revision as current without applying the migrations.",
)
@option(
    "--revision",
    type=str,
    help="Revision to stamp to",
    default="-1",
)
@option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@option(
    "--purge",
    type=bool,
    help="Delete existing records in the alembic version table before stamping.",
    default=False,
    is_flag=True,
)
@option(
    "--tag",
    help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
    type=str,
    default=None,
)
@option(
    "--no-prompt",
    help="Do not prompt for confirmation.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
def stamp_revision(app: Litestar, revision: str, sql: bool, tag: str | None, purge: bool, no_prompt: bool) -> None:
    """Stamp the given revision as current via ``AlembicCommands.stamp``.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        revision: Revision to stamp; the CLI default is ``"-1"``.
        sql: Forwarded to ``AlembicCommands.stamp``.
        tag: Forwarded to ``AlembicCommands.stamp``.
        purge: Forwarded to ``AlembicCommands.stamp``.
        no_prompt: Skip the interactive confirmation when true.
    """
    from rich.prompt import Confirm

    from advanced_alchemy.alembic.commands import AlembicCommands

    console.rule("[yellow]Stamping database revision as current[/]", align="left")
    input_confirmed = no_prompt or Confirm.ask("Are you sure you want to stamp revision as current?")
    if input_confirmed:
        config = get_database_migration_plugin(app).config
        # Only the first configured database is stamped.
        sqlalchemy_config = config[0]
        alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
        alembic_commands.stamp(sql=sql, revision=revision, tag=tag, purge=purge)


@database_group.command(name="drop-all", help="Drop all tables from the database.")
@option(
    "--no-prompt",
    help="Do not prompt for confirmation before upgrading.",
    type=bool,
    default=False,
    required=False,
    show_default=True,
    is_flag=True,
)
def drop_all(app: Litestar, no_prompt: bool) -> None:
    """Drop all tables for every configured database.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        no_prompt: Skip the interactive confirmation when true.
    """
    from rich.prompt import Confirm

    # This local import shadows the command function's own name inside this
    # body; the nested coroutine below calls the imported helper.
    from advanced_alchemy.alembic.utils import drop_all
    from advanced_alchemy.base import metadata_registry

    console.rule("[yellow]Dropping all tables from the database[/]", align="left")
    input_confirmed = no_prompt or Confirm.ask("[bold red]Are you sure you want to drop all tables from the database?")

    config = get_database_migration_plugin(app).config

    async def _drop_all(
        configs: Sequence[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig],
    ) -> None:
        """Await the ``drop_all`` helper once per config, using that config's engine."""
        for config in configs:
            engine = config.get_engine()
            await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))

    if input_confirmed:
        run(
            _drop_all,
            config,
        )


@database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
@option(
    "--table",
    "table_names",
    help="Name of the table to dump. Multiple tables can be specified. Use '*' to dump all tables.",
    type=str,
    required=True,
    multiple=True,
)
@option(
    "--dir",
    "dump_dir",
    help="Directory to save the JSON files. Defaults to WORKDIR/fixtures",
    type=ClickPath(path_type=Path),  # type: ignore[type-var,unused-ignore] # pyright: ignore[reportCallIssue, reportUntypedFunctionDecorator, reportArgumentType]
    # Evaluated once, when this module is imported.
    default=Path.cwd() / "fixtures",
    required=False,
)
def dump_table_data(app: Litestar, table_names: tuple[str, ...], dump_dir: Path) -> None:
    """Dump the selected tables of every configured database via ``dump_tables``.

    Args:
        app: The Litestar application (presumably injected by the CLI group).
        table_names: Table names to dump; ``"*"`` selects all tables and
            triggers an interactive confirmation.
        dump_dir: Directory passed to ``dump_tables`` as the output location.
    """
    from rich.prompt import Confirm

    all_tables = "*" in table_names

    if all_tables and not Confirm.ask(
        "[yellow bold]You have specified '*'. Are you sure you want to dump all tables from the database?",
    ):
        # user has decided not to dump all tables
        return console.rule("[red bold]No data was dumped.", style="red", align="left")

    from advanced_alchemy.alembic.utils import dump_tables

    # _TODO: Find a way to read from different registries
    from advanced_alchemy.base import metadata_registry, orm_registry

    configs = get_database_migration_plugin(app).config

    async def _dump_tables() -> None:
        """Resolve the target tables per config and dump the matching mapped models."""
        for config in configs:
            target_tables = set(metadata_registry.get(config.bind_key).tables)

            if not all_tables:
                # only consider tables specified by user
                for table_name in set(table_names) - target_tables:
                    console.rule(
                        f"[red bold]Skipping table '{table_name}' because it is not available in the default registry",
                        style="red",
                        align="left",
                    )
                target_tables.intersection_update(table_names)
            else:
                console.rule("[yellow bold]Dumping all tables", style="yellow", align="left")

            models = [
                mapper.class_ for mapper in orm_registry.mappers if mapper.class_.__table__.name in target_tables
            ]
            await dump_tables(dump_dir, config.get_session(), models)
            console.rule("[green bold]Data dump complete", align="left")

    return run(_dump_tables)