Coverage for src/ensembl/utils/argparse.py: 99%
98 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-06 14:10 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
"""Provide an extended version of `argparse.ArgumentParser` with additional functionality.

Examples:

    >>> from pathlib import Path
    >>> from ensembl.utils.argparse import ArgumentParser
    >>> parser = ArgumentParser(description="Tool description")
    >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file")
    >>> parser.add_server_arguments(help="Server to connect to")
    >>> args = parser.parse_args()
    >>> args
    Namespace(host='localhost', password=None, port=3826, src_file=PosixPath('/path/to/src_file.txt'),
    url=URL('mysql://username@localhost:3826'), user='username')

"""
31from __future__ import annotations
# Names exported by `from ensembl.utils.argparse import *`
__all__ = ["ArgumentError", "ArgumentParser"]
38import argparse
39import os
40from pathlib import Path
41import re
42from typing import Any, Callable
44from sqlalchemy.engine import make_url, URL
46from ensembl.utils import StrPath
class ArgumentError(Exception):
    """Raised when an argument (optional or positional) cannot be created."""
class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    The default behaviour of the help text will be to display the default values on every non-required
    argument, i.e. optional arguments with `required=False`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default.

        `argparse.ArgumentDefaultsHelpFormatter` is used as help formatter unless the caller provides a
        `formatter_class` of their own.

        """
        super().__init__(*args, **kwargs)
        # `formatter_class` is the sixth parameter of `argparse.ArgumentParser`: respect it when it has been
        # given (by keyword or positionally) instead of silently overriding it
        if "formatter_class" not in kwargs and len(args) < 6:
            self.formatter_class = argparse.ArgumentDefaultsHelpFormatter

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if it exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if os.access(dst_path, os.W_OK):
                if exists_ok:
                    return dst_path
                self.error(f"'{dst_path}' already exists")
            else:
                self.error(f"'{dst_path}' is not writable")
        # The path does not exist yet: check if the first parent directory that exists is writable
        for parent_path in dst_path.parents:
            if parent_path.exists():
                if not os.access(parent_path, os.W_OK):
                    self.error(f"'{dst_path}' is not writable")
                break
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            # Not every callable has a `__name__` (e.g. `functools.partial` objects)
            type_name = getattr(value_type, "__name__", repr(value_type))
            self.error(f"invalid {type_name} value: {value}")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action:  # type: ignore[override]
        """Extends the parent function by excluding the default value in the help text when not provided.

        When `required=True` is passed without a `default`, the default is set to `argparse.SUPPRESS`, so
        the help text does not display a meaningless "(default: None)". Positional arguments do not accept
        the `required` keyword, so they are left untouched.

        Returns:
            The `argparse.Action` created by the parent function.

        """
        if kwargs.get("required", False):
            kwargs.setdefault("default", argparse.SUPPRESS)
        return super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        Returns:
            The `argparse.Action` created for the argument.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        return self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> argparse.Action:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined it is added with "PATH" as value to improve help text readability.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        Returns:
            The `argparse.Action` created for the argument.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        return self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Adds `sqlalchemy.engine.URL` argument.

        If "metavar" is not defined it is added with "URI" as value to improve help text readability.

        Returns:
            The `argparse.Action` created for the argument.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = make_url
        return self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> argparse.Action:
        """Adds a numeric argument with constraints on its type and its minimum or maximum value.

        Note that argparse only applies the `type` conversion (and therefore these checks) to a default
        value when that default is a string and the argument is not provided in the command line;
        non-string default values are never checked.

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        Returns:
            The `argparse.Action` created for the argument.

        Raises:
            ArgumentError: If both `min_value` and `max_value` are defined and `min_value > max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constraints when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        return self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(
        self, prefix: str = "", include_database: bool = False, help: str | None = None
    ) -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Args:
            prefix: Prefix to add to each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        # Argument groups use their own `add_argument()` (not the override above), hence the explicit
        # `default=argparse.SUPPRESS` on every required argument
        group.add_argument(
            f"--{prefix}host", required=True, metavar="HOST", default=argparse.SUPPRESS, help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            type=int,
            metavar="PORT",
            default=argparse.SUPPRESS,
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, metavar="USER", default=argparse.SUPPRESS, help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                metavar="NAME",
                default=argparse.SUPPRESS,
                help="database name",
            )

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug` or `--log LEVEL`.

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 3 mutually exclusive options to set the logging level
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        A server group is detected by a `<prefix>host` argument that comes together with `<prefix>user`,
        `<prefix>password` and `<prefix>port` arguments (e.g. the ones added by `add_server_arguments()`).
        For each of them, a `<prefix>url` argument of type `sqlalchemy.engine.URL` is added, which also
        includes `<prefix>database` when present. An error is raised through the parser if a server group
        already has a `<prefix>url` argument.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Build and add an sqlalchemy.engine.URL object for every server group added
        pattern = re.compile(r"([\w-]*)host$")
        server_prefixes = [x.group(1) for x in map(pattern.match, vars(arguments)) if x]
        for prefix in server_prefixes:
            # Keep the guarded section to the lookups that identify a server group
            try:
                user = getattr(arguments, f"{prefix}user")
                password = getattr(arguments, f"{prefix}password")
                host = getattr(arguments, f"{prefix}host")
                port = getattr(arguments, f"{prefix}port")
            except AttributeError:
                # Not a database server host argument
                continue
            # Raise an error rather than overwriting when the URL argument is already present
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            database = getattr(arguments, f"{prefix}database", None)
            server_url = URL.create("mysql", user, password, host, port, database)
            setattr(arguments, f"{prefix}url", server_url)
        return arguments