Coverage for src/ensembl/utils/database/unittestdb.py: 81%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-06 14:10 +0000

1# See the NOTICE file distributed with this work for additional information 

2# regarding copyright ownership. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""Unit testing database handler. 

16 

17This module provides the main class to create and drop testing databases, populating them from 

18preexisting dumps (if supplied). 

19 

20Examples: 

21 

22 >>> from ensembl.utils.database import UnitTestDB 

23 >>> # For more safety use the context manager (automatically drops the database even if things go wrong): 

24 >>> with UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") as test_db: 

25 >>> dbc = test_db.dbc 

26 

27 >>> # If you know what you are doing you can also control when the test_db is dropped: 

28 >>> test_db = UnitTestDB("mysql://user:passwd@mysql-server:4242/", dump_dir="path/to/dumps", name="my_db") 

29 >>> # You can access the database via test_db.dbc, for instance: 

30 >>> dbc = test_db.dbc 

31 >>> # At the end do not forget to drop the database 

32 >>> test_db.drop() 

33 

34""" 

35 

36from __future__ import annotations 

37 

38__all__ = [ 

39 "UnitTestDB", 

40] 

41 

42import os 

43from pathlib import Path 

44import subprocess 

45from typing import Any 

46 

47import sqlalchemy 

48from sqlalchemy import text 

49from sqlalchemy.engine import make_url 

50from sqlalchemy.schema import MetaData 

51from sqlalchemy_utils.functions import create_database, database_exists, drop_database 

52 

53from ensembl.utils import StrPath 

54from ensembl.utils.database.dbconnection import DBConnection, StrURL 

55 

56 

# Username prepended to every test database name; falls back to "pytestuser" when the USER
# environment variable is not set.
TEST_USERNAME = os.getenv("USER", "pytestuser")

58 

59 

class UnitTestDB:
    """Creates and connects to a new test database, applying the schema and importing the data.

    If a database with the same name already exists on the server, it is dropped and created anew.
    The schema is created from `metadata` if provided, otherwise from the `table.sql` file found in
    `dump_dir`. If neither is provided, the new database is left empty.

    The object can be used as a context manager, in which case the database is dropped on exit.

    Args:
        server_url: URL of the server hosting the database.
        dump_dir: Directory path with the database schema in `table.sql` (mandatory unless `metadata`
            is provided) and one TSV data file (without headers) per table following the convention
            `<table_name>.txt` (optional).
        name: Name to give to the new database. If not provided, the last directory name of `dump_dir`
            will be used instead (or "testdb" if `dump_dir` is not provided either). In either case,
            the new database name will be prefixed by the username.
        metadata: Use this metadata to create the schema instead of using an SQL schema file.
        tmp_path: Temp dir where the test db is created if using SQLite (otherwise use current dir).

    Attributes:
        dbc: Database connection handler.

    Raises:
        FileNotFoundError: If `table.sql` is not found.

    """

    def __init__(
        self,
        server_url: StrURL,
        *,
        dump_dir: StrPath | None = None,
        name: str | None = None,
        metadata: MetaData | None = None,
        tmp_path: StrPath | None = None,
    ) -> None:
        db_url = make_url(server_url)
        if not name:
            name = Path(dump_dir).name if dump_dir else "testdb"
        db_name = f"{TEST_USERNAME}_{name}"

        # Add the database name to the URL
        if db_url.get_dialect().name == "sqlite":
            db_path = Path(tmp_path) / db_name if tmp_path else db_name
            db_url = db_url.set(database=f"{db_path}.db")
        else:
            db_url = db_url.set(database=db_name)
        # Enable "local_infile" variable for MySQL databases to allow importing data from files
        connect_args: dict[str, Any] = {}
        if db_url.get_dialect().name == "mysql":
            connect_args["local_infile"] = 1
        # Create the database, dropping it beforehand if it already exists
        if database_exists(db_url):
            drop_database(db_url)
        create_database(db_url)
        # Establish the connection to the database, load the schema and import the data
        try:
            self.dbc = DBConnection(db_url, connect_args=connect_args, reflect=False)
            self._load_schema_and_data(dump_dir, metadata)
        except BaseException:
            # Make sure the database is deleted before raising the exception, and release the
            # connection handler (if it got created) so that no connection is left behind
            try:
                drop_database(db_url)
            finally:
                if hasattr(self, "dbc"):
                    self.dbc.dispose()
            raise
        # Update the loaded metadata information of the database
        self.dbc.load_metadata()

    def _load_schema_and_data(
        self, dump_dir: StrPath | None = None, metadata: MetaData | None = None
    ) -> None:
        """Creates the schema and imports the table data available in the dump directory.

        Args:
            dump_dir: Directory path with the schema in `table.sql` (only read if `metadata` is not
                provided) and, optionally, one `<table_name>.txt` TSV data file per table.
            metadata: Use this metadata to create the schema instead of `table.sql`.

        Raises:
            FileNotFoundError: If `metadata` is not provided and `table.sql` is not found in `dump_dir`.

        """
        with self.dbc.begin() as conn:
            # Set InnoDB engine as default and disable foreign key checks for MySQL databases
            if self.dbc.dialect == "mysql":
                conn.execute(text("SET default_storage_engine=InnoDB"))
                conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))

            # Load the schema
            if metadata:
                metadata.create_all(conn)
            elif dump_dir:
                # The schema file is split on ";" and every non-blank chunk is run as one statement
                schema = Path(dump_dir, "table.sql").read_text()
                for query in schema.split(";"):
                    if query.strip():
                        conn.execute(text(query))

            # And import any available data for each table
            if dump_dir:
                for tsv_file in Path(dump_dir).glob("*.txt"):
                    # The file name (without extension) is the name of the table to populate
                    self._load_data(conn, tsv_file.stem, tsv_file)

            # Re-enable foreign key checks for MySQL databases
            if self.dbc.dialect == "mysql":
                conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        return f"{self.__class__.__name__}({self.dbc.url!r})"

    def drop(self) -> None:
        """Drops the database."""
        drop_database(self.dbc.url)
        # Ensure the connection pool is properly closed and disposed
        self.dbc.dispose()

    @staticmethod
    def _import_statement(dialect: str, table: str, src: StrPath) -> str:
        """Returns the SQL statement that bulk-loads a TSV file into a table.

        Args:
            dialect: Name of the database dialect (anything but "sqlite", which has no SQL equivalent).
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if dialect == "postgresql":
            return f"COPY {table} FROM '{src}'"
        # SQLAlchemy names the Microsoft SQL Server dialect "mssql"; "sqlserver" is kept as well
        # for backwards compatibility
        if dialect in ("mssql", "sqlserver"):
            return f"BULK INSERT {table} FROM '{src}'"
        # Any other dialect is expected to understand the MySQL syntax
        return f"LOAD DATA LOCAL INFILE '{src}' INTO TABLE {table}"

    def _load_data(self, conn: sqlalchemy.engine.Connection, table: str, src: StrPath) -> None:
        """Loads the table data from the given file.

        Args:
            conn: Open connection to the database.
            table: Table name to load the data to.
            src: File path with the data in TSV format (without headers).

        """
        if self.dbc.dialect == "sqlite":
            # SQLite does not have an equivalent to "LOAD DATA": use its ".import" command instead
            subprocess.run(["sqlite3", self.dbc.db_name, ".mode tabs", f".import {src} {table}"], check=True)
        else:
            conn.execute(text(self._import_statement(self.dbc.dialect, table, src)))

    def __enter__(self) -> UnitTestDB:
        """Returns this object so it can be bound by the `with` statement."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Drops the database when leaving the `with` block (exceptions are not suppressed)."""
        self.drop()