Coverage for src / ensembl / utils / database / unittestdb.py: 81%

70 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:45 +0000

1# See the NOTICE file distributed with this work for additional information 

2# regarding copyright ownership. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""Unit testing database handler. 

16 

17This module provides the main class to create and drop testing databases, populating them from 

18preexisting dumps (if supplied). 

19 

20Examples: 

21 

22 >>> from ensembl.utils.database import UnitTestDB 

23 >>> # For more safety use the context manager (automatically drops the database even if things go wrong): 

24 >>> with UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") as test_db: 

25 >>> dbc = test_db.dbc 

26 

27 >>> # If you know what you are doing you can also control when the test_db is dropped: 

28 >>> test_db = UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") 

29 >>> # You can access the database via test_db.dbc, for instance: 

30 >>> dbc = test_db.dbc 

31 >>> # At the end do not forget to drop the database 

32 >>> test_db.drop() 

33 

34""" 

35 

36from __future__ import annotations 

37 

38__all__ = [ 

39 "UnitTestDB", 

40] 

41 

42import os 

43from pathlib import Path 

44import subprocess 

45from typing import Any 

46 

47import sqlalchemy 

48from sqlalchemy import text 

49from sqlalchemy.engine import make_url 

50from sqlalchemy.schema import MetaData 

51from sqlalchemy_utils.functions import create_database, database_exists, drop_database 

52 

53from ensembl.utils import StrPath 

54from ensembl.utils.database.dbconnection import DBConnection, StrURL 

55 

# Username used as prefix for the names of the test databases: taken from the USER environment
# variable, falling back to "pytestuser" when that variable is not set.
TEST_USERNAME = os.getenv("USER", "pytestuser")

57 

58 

class UnitTestDB:
    """Creates and connects to a new test database, applying the schema and importing the data.

    If a database with the same name already exists in the server, it is dropped and created again.
    The object can be used as a context manager, in which case the database is dropped when leaving
    the `with` block.

    Args:
        server_url: URL of the server hosting the database.
        dump_dir: Directory path with the database schema in `table.sql` (mandatory unless `metadata`
            is provided) and one TSV data file (without headers) per table following the convention
            `<table_name>.txt` (optional). Keyword-only.
        name: Name to give to the new database. If not provided, the last directory name of `dump_dir`
            will be used instead, or `testdb` if `dump_dir` is not provided either. In every case, the
            new database name will be prefixed by the username. Keyword-only.
        metadata: Use this metadata to create the schema instead of using an SQL schema file.
            Keyword-only.
        tmp_path: Temp dir where the test db is created if using SQLite (otherwise use current dir).
            Keyword-only.

    Attributes:
        dbc: Database connection handler.

    Raises:
        FileNotFoundError: If `table.sql` is not found.

    """

    def __init__(
        self,
        server_url: StrURL,
        *,
        dump_dir: StrPath | None = None,
        name: str | None = None,
        metadata: MetaData | None = None,
        tmp_path: StrPath | None = None,
    ) -> None:
        db_url = make_url(server_url)
        if not name:
            name = Path(dump_dir).name if dump_dir else "testdb"
        db_name = f"{TEST_USERNAME}_{name}"

        # Add the database name to the URL
        if db_url.get_dialect().name == "sqlite":
            # SQLite databases are files: place the file under tmp_path if provided
            db_path = Path(tmp_path) / db_name if tmp_path else db_name
            db_url = db_url.set(database=f"{db_path}.db")
        else:
            db_url = db_url.set(database=db_name)
        # Enable "local_infile" variable for MySQL databases to allow importing data from files
        connect_args: dict[str, Any] = {}
        if db_url.get_dialect().name == "mysql":
            connect_args["local_infile"] = 1
        # Create the database, dropping it beforehand if it already exists
        if database_exists(db_url):
            drop_database(db_url)
        create_database(db_url)
        # Establish the connection to the database, load the schema and import the data
        try:
            self.dbc = DBConnection(db_url, connect_args=connect_args, reflect=False)
            self._load_schema_and_data(dump_dir, metadata)
        except BaseException:
            # Make sure the database is deleted before raising the exception. BaseException is
            # caught on purpose so the clean-up also happens on KeyboardInterrupt/SystemExit; the
            # exception is always re-raised.
            drop_database(db_url)
            raise
        # Update the loaded metadata information of the database
        self.dbc.load_metadata()

    def _load_schema_and_data(
        self, dump_dir: StrPath | None = None, metadata: MetaData | None = None
    ) -> None:
        """Loads the database schema and imports the data of every table dump available.

        The schema is created from `metadata` if provided; otherwise from the `table.sql` file in
        `dump_dir` (if provided). Every `*.txt` file found in `dump_dir` is then loaded into the
        table named after the file stem.

        Args:
            dump_dir: Directory path with the schema file `table.sql` and the TSV data files.
            metadata: Metadata to create the schema from instead of using the SQL schema file.

        Raises:
            FileNotFoundError: If the schema has to be read from `dump_dir` and `table.sql` is not
                found.

        """
        is_mysql = self.dbc.dialect == "mysql"
        with self.dbc.begin() as conn:
            # Set InnoDB engine as default and disable foreign key checks for MySQL databases
            if is_mysql:
                conn.execute(text("SET default_storage_engine=InnoDB"))
                conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))

            # Load the schema
            if metadata:
                metadata.create_all(conn)
            elif dump_dir:
                # The schema file is split on ";" and each non-blank chunk run as one statement
                schema = Path(dump_dir, "table.sql").read_text()
                for query in schema.split(";"):
                    if query.strip():
                        conn.execute(text(query))

            # And import any available data for each table
            if dump_dir:
                for tsv_file in Path(dump_dir).glob("*.txt"):
                    self._load_data(conn, tsv_file.stem, tsv_file)

            # Re-enable foreign key checks for MySQL databases
            if is_mysql:
                conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.dbc.url!r})"

    def drop(self) -> None:
        """Drops the database."""
        drop_database(self.dbc.url)
        # Ensure the connection pool is properly closed and disposed
        self.dbc.dispose()

    @staticmethod
    def _bulk_load_query(dialect: str, table: str, src: StrPath) -> str:
        """Returns the SQL statement that bulk-loads a TSV file into a table for the given dialect.

        Args:
            dialect: Database dialect name.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if dialect == "postgresql":
            return f"COPY {table} FROM '{src}'"
        # SQLAlchemy names the Microsoft SQL Server dialect "mssql"; "sqlserver" is still accepted
        # for backwards compatibility with the previous check.
        if dialect in ("mssql", "sqlserver"):
            return f"BULK INSERT {table} FROM '{src}'"
        # Any other dialect is assumed to understand the MySQL syntax
        return f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}"

    def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None:
        """Loads the table data from the given file.

        Args:
            conn: Open connection to the database.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        Raises:
            subprocess.CalledProcessError: If the `sqlite3` command fails (SQLite databases only).

        """
        if self.dbc.dialect == "sqlite":
            # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead
            # NOTE(review): `src` and `table` are interpolated unquoted in the ".import" command;
            # presumably a path containing whitespace would be split into several arguments by the
            # sqlite3 shell -- TODO confirm.
            subprocess.run(["sqlite3", self.dbc.db_name, ".mode tabs", f".import {src} {table}"], check=True)
        else:
            conn.execute(text(self._bulk_load_query(self.dbc.dialect, table, src)))

    def __enter__(self) -> UnitTestDB:
        """Returns this object when entering the `with` block."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Drops the database when leaving the `with` block, without suppressing any exception."""
        self.drop()