Coverage for src / ensembl / utils / argparse.py: 100%
100 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:45 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:45 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Provide an extended version of `argparse.ArgumentParser` with additional functionality.
17Examples:
19 >>> from pathlib import Path
20 >>> from ensembl.utils.argparse import ArgumentParser
21 >>> parser = ArgumentParser(description="Tool description")
22 >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file")
23 >>> parser.add_server_arguments(help="Server to connect to")
24 >>> args = parser.parse_args()
25 >>> args
26 Namespace(host='localhost', port=3826, src_file=PosixPath('/path/to/src_file.txt'),
27 url=URL('mysql://username@localhost:3826'), user='username')
29"""
31from __future__ import annotations
# Public names exported by this module (used by `from <module> import *`).
__all__ = [
    "ArgumentError",
    "ArgumentParser",
]
import argparse
import math
import os
from pathlib import Path
import re
from typing import Any, Callable

from sqlalchemy.engine import make_url, URL
from sqlalchemy.exc import ArgumentError as SQLAlchemyArgumentError

from ensembl.utils import StrPath
class ArgumentError(Exception):
    """Raised when an argument (optional or positional) cannot be created as requested."""
class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    The default behaviour of the help text will be to display the default values on every non-required
    argument, i.e. optional arguments with `required=False`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default."""
        super().__init__(*args, **kwargs)
        # Assigned after the parent constructor, so it replaces any `formatter_class` passed by the caller
        self.formatter_class = argparse.ArgumentDefaultsHelpFormatter

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if os.access(dst_path, os.W_OK):
                if exists_ok:
                    return dst_path
                self.error(f"'{dst_path}' already exists")
            else:
                self.error(f"'{dst_path}' is not writable")
        # Check if the first parent directory that exists is writable
        for parent_path in dst_path.parents:
            if parent_path.exists():
                if not os.access(parent_path, os.W_OK):
                    self.error(f"'{dst_path}' is not writable")
                break
        else:
            # None of the parent directories exists, so the path cannot be created
            self.error(f"'{dst_path}' is not writable")
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Any failed check is reported through the parser's `error()` method.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            # Not every callable has a `__name__` (e.g. `functools.partial`), so fall back to its repr
            type_name = getattr(value_type, "__name__", repr(value_type))
            self.error(f"invalid {type_name} value: {value}")
        # NaN compares false against any bound, so it has to be rejected explicitly when a range is set
        has_range = (min_value is not None) or (max_value is not None)
        if has_range and isinstance(result, float) and math.isnan(result):
            self.error(f"{value} is not within the valid range")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def _validate_url(self, url: str) -> URL:
        """Returns the parsed `sqlalchemy.engine.URL`, raises an error through the parser otherwise.

        `make_url()` raises `sqlalchemy.exc.ArgumentError`, which `argparse` does not handle as a type
        conversion error, so it is converted here into a regular parser error.

        Args:
            url: String representation of the URL to parse.

        """
        try:
            result = make_url(url)
        except SQLAlchemyArgumentError:
            # The raw value is not echoed back since it may include a password
            self.error("invalid URL: unable to parse it as an SQLAlchemy URL")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Extends the parent function by excluding the default value in the help text when not provided.

        Only applied to required arguments without a default value, i.e. positional arguments or optional
        arguments with `required=True`.

        Returns:
            The action object created by the parent method, as in `argparse.ArgumentParser.add_argument()`.

        """
        if kwargs.get("required", False):
            kwargs.setdefault("default", argparse.SUPPRESS)
        return super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined it is added with "PATH" as value to improve help text readability.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> None:
        """Adds `sqlalchemy.engine.URL` argument.

        If "metavar" is not defined it is added with "URI" as value to improve help text readability.
        A value that cannot be parsed as a URL is reported as a parser error at parsing time.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = self._validate_url
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> None:
        """Adds a numeric argument with constraints on its type and its minimum or maximum value.

        Note that the default value (if defined) is not checked unless the argument is an optional argument
        and no value is provided in the command line.

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        Raises:
            ArgumentError: If both limits are defined and `min_value` is greater than `max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constraints when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(
        self, prefix: str = "", include_database: bool = False, help: str | None = None
    ) -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Warning:
            Avoid passing ``--password`` directly on the command line as it will be visible in the
            process list and shell history. Use an environment variable or an interactive prompt via
            ``getpass`` instead.

        Args:
            prefix: Prefix to add to each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        # The group's `add_argument()` is not this class' override, so the default value is suppressed
        # explicitly for every required argument
        group.add_argument(
            f"--{prefix}host", required=True, metavar="HOST", default=argparse.SUPPRESS, help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            type=int,
            metavar="PORT",
            default=argparse.SUPPRESS,
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, metavar="USER", default=argparse.SUPPRESS, help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                metavar="NAME",
                default=argparse.SUPPRESS,
                help="database name",
            )

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug`, `--quiet` or `--log LEVEL`.

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 4 mutually exclusive options to set the logging level, all stored in "log_level"
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--quiet",
            action="store_const",
            const="CRITICAL",
            dest="log_level",
            help="quiet mode, i.e. 'CRITICAL' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        # Called after adding the options so the shared "log_level" default is the same for all of them
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        The type of this new argument will be `sqlalchemy.engine.URL`. A server group is detected by the
        presence of the `<prefix>host`, `<prefix>port`, `<prefix>user` and `<prefix>password` arguments.
        An error is raised through the parser if `<prefix>url` is already present for a server group.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Build and add an sqlalchemy.engine.URL object for every server group added
        pattern = re.compile(r"([\w-]*)host$")
        server_prefixes = [x.group(1) for x in map(pattern.match, vars(arguments)) if x]
        for prefix in server_prefixes:
            # Only the attribute lookups are guarded, so errors raised while building the URL are not hidden
            try:
                user = getattr(arguments, f"{prefix}user")
                password = getattr(arguments, f"{prefix}password")
                port = getattr(arguments, f"{prefix}port")
            except AttributeError:
                # Not a database server host argument
                continue
            # Raise an error rather than overwriting when the URL argument is already present
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            server_url = URL.create(
                "mysql",
                user,
                password,
                getattr(arguments, f"{prefix}host"),
                port,
                getattr(arguments, f"{prefix}database", None),
            )
            setattr(arguments, f"{prefix}url", server_url)
        return arguments