Coverage for src/ensembl/utils/database/dbconnection.py: 90%
98 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
"""Database connection handler.

This module provides the main class to connect to and access databases. It is an ORM-less
connection, that is, the data can only be accessed via SQL queries (see examples below).

Examples:

    >>> from sqlalchemy import text
    >>> from ensembl.utils.database import DBConnection
    >>> dbc = DBConnection("mysql://ensro@mysql-server:4242/mydb")
    >>> # You can access the database data via SQL queries through a connection, for instance:
    >>> with dbc.connect() as conn:
    ...     results = conn.execute(text("SELECT * FROM my_table;")).all()
    >>> # Or via a connection in a transactional manner:
    >>> with dbc.begin() as conn:
    ...     results = conn.execute(text("SELECT * FROM my_table;")).all()

"""
32from __future__ import annotations
# Names exported by `from <this module> import *`: the type helpers and the connection handler
__all__ = ["Query", "StrURL", "DBConnection"]
40from contextlib import contextmanager
41from typing import Any, ContextManager, Generator, Optional, TypeVar
43import sqlalchemy
44from sqlalchemy import create_engine, event
45from sqlalchemy.orm import sessionmaker
46from sqlalchemy.schema import MetaData, Table
# Types accepted as an SQL query: a raw SQL string or an SQLAlchemy clause / text element
Query = TypeVar(
    "Query",
    str,
    sqlalchemy.sql.expression.ClauseElement,
    sqlalchemy.sql.expression.TextClause,
)
# Types accepted as a database URL: a plain string or an SQLAlchemy `URL` object
StrURL = TypeVar(
    "StrURL",
    str,
    sqlalchemy.engine.URL,
)
class DBConnection:
    """Database connection handler, providing also the database's schema and properties.

    Args:
        url: URL to the database, e.g. `mysql://user:passwd@host:port/my_db`.
        reflect: Reflect the database schema or not.
        **kwargs: Additional keyword arguments forwarded to `sqlalchemy.create_engine()`.

    """

    def __init__(self, url: StrURL, reflect: bool = True, **kwargs: Any) -> None:
        self._engine = create_engine(url, future=True, **kwargs)
        # Schema of the database; stays `None` until it is reflected or explicitly set
        self._metadata: MetaData | None = None
        if reflect:
            self.load_metadata()

    def __repr__(self) -> str:
        """Returns a string representation of this object."""
        # NOTE(review): `self.url` is rendered with `hide_password=False`, so this representation
        # contains the password whenever the URL has one; be careful when logging it.
        return f"{self.__class__.__name__}({self.url!r})"

    def load_metadata(self) -> None:
        """Loads the metadata information of the database.

        The current metadata is only replaced once the new reflection has completed, so a failed
        reflection leaves the previously loaded schema untouched.

        """
        # Note: Just reflect() is not enough as it would not delete tables that no longer exist,
        # hence a brand new MetaData object is reflected every time
        metadata = sqlalchemy.MetaData()
        metadata.reflect(bind=self._engine)
        self._metadata = metadata

    def create_all_tables(self, metadata: MetaData) -> None:
        """Create the tables from the metadata and set the metadata.

        This assumes the database is empty beforehand. If the tables already exist, they will be ignored.
        If there are other tables, you may need to run `self.load_metadata()` to update the metadata schema.

        Args:
            metadata: Schema whose tables will be created and which becomes this object's metadata.

        """
        metadata.create_all(self._engine)
        # Only adopt the metadata once its tables have been created successfully
        self._metadata = metadata

    def create_table(self, table: Table) -> None:
        """Create a table in the database and update the metadata. Do nothing if the table already exists.

        Args:
            table: Table definition to create.

        """
        # `checkfirst=True` is what makes "do nothing if the table already exists" hold: with
        # SQLAlchemy's default (`checkfirst=False`) CREATE TABLE is always emitted and fails
        table.create(self._engine, checkfirst=True)
        # We need to update the metadata to register the new table
        self.load_metadata()

    @property
    def url(self) -> str:
        """Returns the database URL (including the password, if any)."""
        return self._engine.url.render_as_string(hide_password=False)

    @property
    def db_name(self) -> Optional[str]:
        """Returns the database name."""
        return self._engine.url.database

    @property
    def host(self) -> Optional[str]:
        """Returns the database host."""
        return self._engine.url.host

    @property
    def port(self) -> Optional[int]:
        """Returns the port of the database host."""
        return self._engine.url.port

    @property
    def dialect(self) -> str:
        """Returns the SQLAlchemy database dialect name of the database host."""
        return self._engine.name

    @property
    def tables(self) -> dict[str, sqlalchemy.schema.Table]:
        """Returns the database tables keyed to their name, or an empty dict if no metadata was loaded."""
        if self._metadata is not None:
            return self._metadata.tables
        return {}

    def get_primary_key_columns(self, table: str) -> list[str]:
        """Returns the primary key column names for the given table.

        Args:
            table: Table name.

        Raises:
            KeyError: If `table` is not part of the loaded metadata.

        """
        return [col.name for col in self.tables[table].primary_key.columns.values()]

    def get_columns(self, table: str) -> list[str]:
        """Returns the column names for the given table.

        Args:
            table: Table name.

        Raises:
            KeyError: If `table` is not part of the loaded metadata.

        """
        return [col.name for col in self.tables[table].columns]

    def connect(self) -> sqlalchemy.engine.Connection:
        """Returns a new database connection."""
        return self._engine.connect()

    def begin(self, *args: Any) -> ContextManager[sqlalchemy.engine.Connection]:
        """Returns a context manager delivering a database connection with a transaction established."""
        return self._engine.begin(*args)

    def dispose(self) -> None:
        """Disposes of the connection pool."""
        self._engine.dispose()

    def _enable_sqlite_savepoints(self, engine: sqlalchemy.engine.Engine) -> None:
        """Enables SQLite SAVEPOINTS to allow session rollbacks.

        Args:
            engine: Engine on which the "connect" and "begin" event listeners are registered.

        """

        @event.listens_for(engine, "connect")
        def do_connect(
            dbapi_connection: Any,  # SQLAlchemy is not clear about the type of this argument
            connection_record: sqlalchemy.pool.ConnectionPoolEntry,  # pylint: disable=unused-argument
        ) -> None:
            """Disables emitting the BEGIN statement entirely, as well as COMMIT before any DDL."""
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn: sqlalchemy.engine.Connection) -> None:
            """Emits a custom own BEGIN."""
            conn.exec_driver_sql("BEGIN")

    @contextmanager
    def session_scope(self) -> Generator[sqlalchemy.orm.Session, None, None]:
        """Provides a transactional scope around a series of operations with rollback in case of failure.

        The session is committed if the managed block completes, rolled back if it raises, and always
        closed. The dedicated engine created for the session is disposed of on exit.

        Bear in mind MySQL's storage engine MyISAM does not support rollback transactions, so all
        the modifications performed to the database will persist.

        """
        # Create a dedicated engine for this session
        engine = create_engine(self._engine.url)
        if self.dialect == "sqlite":
            self._enable_sqlite_savepoints(engine)
        Session = sessionmaker(future=True)
        session = Session(bind=engine, autoflush=False)
        try:
            yield session
            session.commit()
        except BaseException:
            # Rollback to ensure no changes are made to the database (on any interruption, too)
            session.rollback()
            raise
        finally:
            # Whatever happens, make sure the session is closed
            session.close()
            # Release the dedicated engine's connection pool; otherwise its pooled connections
            # remain open until the engine happens to be garbage collected
            engine.dispose()

    @contextmanager
    def test_session_scope(self) -> Generator[sqlalchemy.orm.Session, None, None]:
        """Provides a transactional scope around a series of operations that will be rolled back at the end.

        The dedicated engine created for the session is disposed of on exit.

        Bear in mind MySQL's storage engine MyISAM does not support rollback transactions, so all
        the modifications performed to the database will persist.

        """
        # Create a dedicated engine for this session
        engine = create_engine(self._engine.url)
        if self.dialect == "sqlite":
            self._enable_sqlite_savepoints(engine)
        # Connect to the database
        connection = engine.connect()
        # Begin a non-ORM transaction
        transaction = connection.begin()
        # Bind an individual session to the connection
        Session = sessionmaker(future=True)
        try:
            # Running on SQLAlchemy 2.0+: the session manages its own SAVEPOINTs inside the
            # external transaction
            session = Session(bind=connection, join_transaction_mode="create_savepoint")
        except TypeError:
            # Running on SQLAlchemy 1.4 (no `join_transaction_mode` argument)
            session = Session(bind=connection)
            # If the database supports SAVEPOINT, starting a savepoint will allow to also use rollback
            connection.begin_nested()

            # Restart the savepoint every time the session's transaction ends
            @event.listens_for(session, "after_transaction_end")
            def end_savepoint(
                session: sqlalchemy.orm.Session,  # pylint: disable=unused-argument
                transaction: sqlalchemy.orm.SessionTransaction,  # pylint: disable=unused-argument
            ) -> None:
                if not connection.in_nested_transaction():
                    connection.begin_nested()

        try:
            yield session
        finally:
            # Whatever happens, make sure the session and connection are closed, rolling back
            # everything done with the session (including calls to commit())
            session.close()
            transaction.rollback()
            connection.close()
            # Release the dedicated engine's connection pool
            engine.dispose()