Source code for ensembl.utils.database.unittestdb

# See the NOTICE file distributed with this work for additional information
# regarding copyright ownership.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing database handler.

This module provides the main class to create and drop testing databases, populating them from
preexisting dumps (if supplied).

Examples:

    >>> from ensembl.utils.database import UnitTestDB
    >>> # For more safety use the context manager (automatically drops the database even if things go wrong):
    >>> with UnitTestDB("mysql://user:passwd@mysql-server:4242/", "path/to/dumps", "my_db") as test_db:
    >>>    dbc = test_db.dbc

    >>> # If you know what you are doing you can also control when the test_db is dropped:
    >>> test_db = UnitTestDB("mysql://user:passwd@mysql-server:4242/", "path/to/dumps", "my_db")
    >>> # You can access the database via test_db.dbc, for instance:
    >>> dbc = test_db.dbc
    >>> # At the end do not forget to drop the database
    >>> test_db.drop()

"""

from __future__ import annotations

__all__ = [
    "UnitTestDB",
]

import os
from pathlib import Path
import subprocess
from typing import Any

import sqlalchemy
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.schema import MetaData
from sqlalchemy_utils.functions import create_database, database_exists, drop_database

from ensembl.utils import StrPath
from ensembl.utils.database.dbconnection import DBConnection, StrURL

TEST_USERNAME = os.environ.get("USER", "pytestuser")


[docs] class UnitTestDB: """Creates and connects to a new test database, applying the schema and importing the data. Args: server_url: URL of the server hosting the database. metadata: Use this metadata to create the schema instead of using an SQL schema file. dump_dir: Directory path with the database schema in `table.sql` (mandatory) and one TSV data file (without headers) per table following the convention `<table_name>.txt` (optional). name: Name to give to the new database. If not provided, the last directory name of `dump_dir` will be used instead. In either case, the new database name will be prefixed by the username. tmp_path: Temp dir where the test db is created if using SQLite (otherwise use current dir). Attributes: dbc: Database connection handler. Raises: FileNotFoundError: If `table.sql` is not found. """ def __init__( self, server_url: StrURL, *, dump_dir: StrPath | None = None, name: str | None = None, metadata: MetaData | None = None, tmp_path: StrPath | None = None, ) -> None: db_url = make_url(server_url) if not name: name = Path(dump_dir).name if dump_dir else "testdb" db_name = f"{TEST_USERNAME}_{name}" # Add the database name to the URL if db_url.get_dialect().name == "sqlite": db_path = Path(tmp_path) / db_name if tmp_path else db_name db_url = db_url.set(database=f"{db_path}.db") else: db_url = db_url.set(database=db_name) # Enable "local_infile" variable for MySQL databases to allow importing data from files connect_args = {} if db_url.get_dialect().name == "mysql": connect_args["local_infile"] = 1 # Create the database, dropping it beforehand if it already exists if database_exists(db_url): drop_database(db_url) create_database(db_url) # Establish the connection to the database, load the schema and import the data try: self.dbc = DBConnection(db_url, connect_args=connect_args, reflect=False) self._load_schema_and_data(dump_dir, metadata) except: # Make sure the database is deleted before raising the exception drop_database(db_url) raise # Update the loaded metadata information of the database self.dbc.load_metadata() def _load_schema_and_data( self, dump_dir: StrPath | None = None, metadata: MetaData | None = None ) -> None: with self.dbc.begin() as conn: # Set InnoDB engine as default and disable foreign key checks for MySQL databases if self.dbc.dialect == "mysql": conn.execute(text("SET default_storage_engine=InnoDB")) conn.execute(text("SET FOREIGN_KEY_CHECKS=0")) # Load the schema if metadata: metadata.create_all(conn) elif dump_dir: with Path(dump_dir, "table.sql").open("r") as schema: for query in "".join(schema.readlines()).split(";"): if query.strip(): conn.execute(text(query)) # And import any available data for each table if dump_dir: for tsv_file in Path(dump_dir).glob("*.txt"): table = tsv_file.stem self._load_data(conn, table, tsv_file) # Re-enable foreign key checks for MySQL databases if self.dbc.dialect == "mysql": conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))
[docs] def __repr__(self) -> str: """Returns a string representation of this object.""" return f"{self.__class__.__name__}({self.dbc.url!r})"
[docs] def drop(self) -> None: """Drops the database.""" drop_database(self.dbc.url) # Ensure the connection pool is properly closed and disposed self.dbc.dispose()
def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None: """Loads the table data from the given file. Args: conn: Open connection to the database. table: Table name to load the data to. src: File path with the data in TSV format (without headers). """ if self.dbc.dialect == "sqlite": # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead subprocess.run(["sqlite3", self.dbc.db_name, ".mode tabs", f".import {src} {table}"], check=True) elif self.dbc.dialect == "postgresql": conn.execute(text(f"COPY {table} FROM '{src}'")) elif self.dbc.dialect == "sqlserver": conn.execute(text(f"BULK INSERT {table} FROM '{src}'")) else: conn.execute(text(f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}")) def __enter__(self) -> UnitTestDB: return self def __exit__(self, *args: Any) -> None: self.drop()