Coverage for src / ensembl / utils / database / unittestdb.py: 81%
70 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:45 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:45 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Unit testing database handler.
17This module provides the main class to create and drop testing databases, populating them from
18preexisting dumps (if supplied).
20Examples:
22 >>> from ensembl.utils.database import UnitTestDB
23 >>> # For more safety use the context manager (automatically drops the database even if things go wrong):
24 >>> with UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") as test_db:
25 ...     dbc = test_db.dbc
27 >>> # If you know what you are doing you can also control when the test_db is dropped:
28 >>> test_db = UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db")
29 >>> # You can access the database via test_db.dbc, for instance:
30 >>> dbc = test_db.dbc
31 >>> # At the end do not forget to drop the database
32 >>> test_db.drop()
34"""
36from __future__ import annotations
38__all__ = [
39 "UnitTestDB",
40]
42import os
43from pathlib import Path
44import subprocess
45from typing import Any
47import sqlalchemy
48from sqlalchemy import text
49from sqlalchemy.engine import make_url
50from sqlalchemy.schema import MetaData
51from sqlalchemy_utils.functions import create_database, database_exists, drop_database
53from ensembl.utils import StrPath
54from ensembl.utils.database.dbconnection import DBConnection, StrURL
56TEST_USERNAME = os.environ.get("USER", "pytestuser")
class UnitTestDB:
    """Creates and connects to a new test database, applying the schema and importing the data.

    Args:
        server_url: URL of the server hosting the database.
        metadata: Use this metadata to create the schema instead of using an SQL schema file.
        dump_dir: Directory path with the database schema in `table.sql` (mandatory unless `metadata`
            is provided) and one TSV data file (without headers) per table following the convention
            `<table_name>.txt` (optional).
        name: Name to give to the new database. If not provided, the last directory name of `dump_dir`
            will be used instead (or `testdb` if `dump_dir` is not provided either). In either case,
            the new database name will be prefixed by the username.
        tmp_path: Temp dir where the test db is created if using SQLite (otherwise use current dir).

    Attributes:
        dbc: Database connection handler.

    Raises:
        FileNotFoundError: If `dump_dir` is provided without `metadata` and `table.sql` is not found.

    """

    def __init__(
        self,
        server_url: StrURL,
        *,
        dump_dir: StrPath | None = None,
        name: str | None = None,
        metadata: MetaData | None = None,
        tmp_path: StrPath | None = None,
    ) -> None:
        """Creates the database (replacing any previous one), then loads its schema and data."""
        db_url = make_url(server_url)
        if not name:
            name = Path(dump_dir).name if dump_dir else "testdb"
        db_name = f"{TEST_USERNAME}_{name}"

        # Add the database name to the URL (SQLite databases are files, optionally under tmp_path)
        dialect = db_url.get_dialect().name
        if dialect == "sqlite":
            db_path = Path(tmp_path) / db_name if tmp_path else db_name
            db_url = db_url.set(database=f"{db_path}.db")
        else:
            db_url = db_url.set(database=db_name)
        # Enable "local_infile" variable for MySQL databases to allow importing data from files
        connect_args: dict[str, Any] = {}
        if dialect == "mysql":
            connect_args["local_infile"] = 1
        # Create the database, dropping it beforehand if it already exists
        if database_exists(db_url):
            drop_database(db_url)
        create_database(db_url)
        # Establish the connection to the database, load the schema and import the data
        dbc: DBConnection | None = None
        try:
            dbc = DBConnection(db_url, connect_args=connect_args, reflect=False)
            self.dbc = dbc
            self._load_schema_and_data(dump_dir, metadata)
        except BaseException:
            # Make sure the database is deleted before raising the exception. BaseException is
            # caught on purpose so the cleanup also happens on KeyboardInterrupt/SystemExit.
            drop_database(db_url)
            # Release the connection pool as well (same order as in drop()), if it was created
            if dbc is not None:
                dbc.dispose()
            raise
        # Update the loaded metadata information of the database
        self.dbc.load_metadata()

    def _load_schema_and_data(
        self, dump_dir: StrPath | None = None, metadata: MetaData | None = None
    ) -> None:
        """Creates the schema and imports the table data within a single `self.dbc.begin()` block.

        Args:
            dump_dir: Directory with `table.sql` (only read if `metadata` is not provided) and the
                `<table_name>.txt` data files to import.
            metadata: If provided, the schema is created from it instead of from `table.sql`.

        Raises:
            FileNotFoundError: If `table.sql` is needed but not found in `dump_dir`.

        """
        is_mysql = self.dbc.dialect == "mysql"
        with self.dbc.begin() as conn:
            # Set InnoDB engine as default and disable foreign key checks for MySQL databases
            if is_mysql:
                conn.execute(text("SET default_storage_engine=InnoDB"))
                conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))

            # Load the schema
            if metadata:
                metadata.create_all(conn)
            elif dump_dir:
                # Naive statement split: assumes no ";" appears inside a statement of table.sql
                schema_sql = Path(dump_dir, "table.sql").read_text()
                for query in schema_sql.split(";"):
                    if query.strip():
                        conn.execute(text(query))

            # And import any available data for each table (sorted for a reproducible load order)
            if dump_dir:
                for tsv_file in sorted(Path(dump_dir).glob("*.txt")):
                    self._load_data(conn, tsv_file.stem, tsv_file)

            # Re-enable foreign key checks for MySQL databases
            if is_mysql:
                conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.dbc.url!r})"

    def drop(self) -> None:
        """Drops the database."""
        drop_database(self.dbc.url)
        # Ensure the connection pool is properly closed and disposed
        self.dbc.dispose()

    @staticmethod
    def _bulk_load_statement(dialect: str, table: str, src: StrPath) -> str:
        """Returns the SQL statement that bulk-loads the file `src` into `table`.

        Args:
            dialect: Database dialect name. Any dialect not handled explicitly gets the MySQL
                `LOAD DATA LOCAL INFILE` syntax.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if dialect == "postgresql":
            return f"COPY {table} FROM '{src}'"
        # "mssql" is SQLAlchemy's name for the SQL Server dialect; "sqlserver" is kept for
        # backward compatibility with the name previously checked here
        if dialect in ("mssql", "sqlserver"):
            return f"BULK INSERT {table} FROM '{src}'"
        return f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}"

    def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None:
        """Loads the table data from the given file.

        Args:
            conn: Open connection to the database.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        Raises:
            subprocess.CalledProcessError: If the `sqlite3` command fails (SQLite only).

        """
        if self.dbc.dialect == "sqlite":
            # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead
            subprocess.run(["sqlite3", self.dbc.db_name, ".mode tabs", f".import {src} {table}"], check=True)
        else:
            conn.execute(text(self._bulk_load_statement(self.dbc.dialect, table, src)))

    def __enter__(self) -> UnitTestDB:
        """Returns this object so it can be bound by the `with` statement."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Drops the database when leaving the `with` block, even if an exception was raised."""
        self.drop()