Coverage for src/ensembl/utils/database/unittestdb.py: 81%
70 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
"""Unit testing database handler.

This module provides the main class to create and drop testing databases, populating them from
preexisting dumps (if supplied). The name of every database created is prefixed by the current
username.

Examples:

    >>> from ensembl.utils.database import UnitTestDB
    >>> # For more safety use the context manager (automatically drops the database even if things go wrong):
    >>> with UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") as test_db:
    ...     dbc = test_db.dbc

    >>> # If you know what you are doing you can also control when the test_db is dropped:
    >>> test_db = UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db")
    >>> # You can access the database via test_db.dbc, for instance:
    >>> dbc = test_db.dbc
    >>> # At the end do not forget to drop the database
    >>> test_db.drop()

Note that `dump_dir`, `name`, `metadata` and `tmp_path` are keyword-only arguments of `UnitTestDB`.

"""
36from __future__ import annotations
# Public API of this module: only the test database handler is exported.
__all__ = ["UnitTestDB"]
42import os
43from pathlib import Path
44import subprocess
45from typing import Any
47import sqlalchemy
48from sqlalchemy import text
49from sqlalchemy.engine import make_url
50from sqlalchemy.schema import MetaData
51from sqlalchemy_utils.functions import create_database, database_exists, drop_database
53from ensembl.utils import StrPath
54from ensembl.utils.database.dbconnection import DBConnection, StrURL
# Username prepended to every test database name; "pytestuser" is used when $USER is not set.
TEST_USERNAME = os.getenv("USER", "pytestuser")
class UnitTestDB:
    """Creates and connects to a new test database, applying the schema and importing the data.

    Args:
        server_url: URL of the server hosting the database.
        metadata: Use this metadata to create the schema instead of using an SQL schema file.
        dump_dir: Directory path with the database schema in `table.sql` (mandatory) and one TSV data
            file (without headers) per table following the convention `<table_name>.txt` (optional).
        name: Name to give to the new database. If not provided, the last directory name of `dump_dir`
            will be used instead (or "testdb" if `dump_dir` is not provided either). In every case,
            the new database name will be prefixed by the username.
        tmp_path: Temp dir where the test db is created if using SQLite (otherwise use current dir).

    Attributes:
        dbc: Database connection handler.

    Raises:
        FileNotFoundError: If `table.sql` is not found.

    """

    # SQLAlchemy dialect names that understand MySQL-specific statements ("LOAD DATA LOCAL INFILE",
    # "SET FOREIGN_KEY_CHECKS", ...). MariaDB URLs report their own "mariadb" dialect name.
    _MYSQL_DIALECTS = ("mysql", "mariadb")

    def __init__(
        self,
        server_url: StrURL,
        *,
        dump_dir: StrPath | None = None,
        name: str | None = None,
        metadata: MetaData | None = None,
        tmp_path: StrPath | None = None,
    ) -> None:
        db_url = make_url(server_url)
        if not name:
            name = Path(dump_dir).name if dump_dir else "testdb"
        db_name = f"{TEST_USERNAME}_{name}"
        dialect = db_url.get_dialect().name
        # Add the database name to the URL (SQLite databases are files, hence the ".db" suffix)
        if dialect == "sqlite":
            db_path = Path(tmp_path) / db_name if tmp_path else db_name
            db_url = db_url.set(database=f"{db_path}.db")
        else:
            db_url = db_url.set(database=db_name)
        # Enable "local_infile" variable for MySQL databases to allow importing data from files
        connect_args = {}
        if dialect in self._MYSQL_DIALECTS:
            connect_args["local_infile"] = 1
        # Create the database, dropping it beforehand if it already exists
        if database_exists(db_url):
            drop_database(db_url)
        create_database(db_url)
        # Establish the connection to the database, load the schema and import the data
        dbc = None
        try:
            dbc = DBConnection(db_url, connect_args=connect_args, reflect=False)
            self.dbc = dbc
            self._load_schema_and_data(dump_dir, metadata)
            # Update the loaded metadata information of the database
            self.dbc.load_metadata()
        except BaseException:
            # Close our pooled connections first (an open connection may prevent the drop), then
            # make sure the database is deleted before re-raising the original exception
            if dbc is not None:
                dbc.dispose()
            drop_database(db_url)
            raise

    def _load_schema_and_data(
        self, dump_dir: StrPath | None = None, metadata: MetaData | None = None
    ) -> None:
        """Creates the schema and imports the table data within a single `begin()` block.

        Args:
            dump_dir: Directory with `table.sql` and optional `<table_name>.txt` TSV data files.
            metadata: If provided, used to create the schema instead of `table.sql`.

        Raises:
            FileNotFoundError: If `metadata` is not provided and `dump_dir` has no `table.sql`.

        """
        is_mysql = self.dbc.dialect in self._MYSQL_DIALECTS
        with self.dbc.begin() as conn:
            # Set InnoDB engine as default and disable foreign key checks for MySQL databases so
            # tables can be created and populated in any order
            if is_mysql:
                conn.execute(text("SET default_storage_engine=InnoDB"))
                conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))
            # Load the schema
            if metadata:
                metadata.create_all(conn)
            elif dump_dir:
                schema_sql = Path(dump_dir, "table.sql").read_text()
                # NOTE: naive split; a ";" inside a string literal or comment would break a statement
                for query in schema_sql.split(";"):
                    if query.strip():
                        conn.execute(text(query))
            # And import any available data for each table (sorted for a deterministic load order)
            if dump_dir:
                for tsv_file in sorted(Path(dump_dir).glob("*.txt")):
                    self._load_data(conn, tsv_file.stem, tsv_file)
            # Re-enable foreign key checks for MySQL databases
            if is_mysql:
                conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.dbc.url!r})"

    def drop(self) -> None:
        """Drops the database.

        The connection pool is disposed first so that no connection held by this object may keep
        the database in use while it is being dropped.
        """
        url = self.dbc.url
        self.dbc.dispose()
        drop_database(url)

    @staticmethod
    def _sqlite_import_command(src: StrPath, table: str) -> str:
        """Returns the sqlite3 shell ".import" dot-command that loads `src` into `table`.

        The sqlite3 shell splits dot-command arguments on whitespace, so a path containing
        whitespace is wrapped in single quotes (no escape processing is applied inside them).
        Paths without whitespace are passed through unchanged.
        """
        path = str(src)
        if any(char.isspace() for char in path):
            path = f"'{path}'"
        return f".import {path} {table}"

    @staticmethod
    def _load_data_query(dialect: str, table: str, src: StrPath) -> str:
        """Returns the bulk-load SQL statement importing the TSV file `src` into `table`.

        Args:
            dialect: SQLAlchemy dialect name of the target database (not used for SQLite).
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if dialect == "postgresql":
            return f"COPY {table} FROM '{src}'"
        # SQLAlchemy names the SQL Server dialect "mssql"; "sqlserver" is kept for compatibility
        if dialect in ("mssql", "sqlserver"):
            return f"BULK INSERT {table} FROM '{src}'"
        return f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}"

    def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None:
        """Loads the table data from the given file.

        Args:
            conn: Open connection to the database.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if self.dbc.dialect == "sqlite":
            # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead
            subprocess.run(
                ["sqlite3", self.dbc.db_name, ".mode tabs", self._sqlite_import_command(src, table)],
                check=True,
            )
        else:
            conn.execute(text(self._load_data_query(self.dbc.dialect, table, src)))

    def __enter__(self) -> UnitTestDB:
        return self

    def __exit__(self, *args: Any) -> None:
        self.drop()