Coverage for src/ensembl/utils/argparse.py: 99%
94 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-05 15:47 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-05 15:47 +0000
1# See the NOTICE file distributed with this work for additional information
2# regarding copyright ownership.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
"""Provide an extended version of `argparse.ArgumentParser` with additional functionality.

Examples:

    >>> from pathlib import Path
    >>> from ensembl.utils.argparse import ArgumentParser
    >>> parser = ArgumentParser(description="Tool description")
    >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file")
    >>> parser.add_server_arguments(help="Server to connect to")
    >>> args = parser.parse_args()
    >>> args
    Namespace(host='myserver', password=None, port=3826, src_file=PosixPath('/path/to/src_file.txt'),
    url=URL('mysql://username@myserver:3826'), user='username')

"""
31from __future__ import annotations
# Public API of this module: the parser class and the error raised while defining its arguments
__all__ = ["ArgumentError", "ArgumentParser"]
38import argparse
39import os
40from pathlib import Path
41from typing import Any, Callable
43from sqlalchemy.engine import make_url, URL
45from ensembl.utils import StrPath
class ArgumentError(Exception):
    """Raised when an argument (optional or positional) cannot be created, e.g. due to invalid settings."""
class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    Unless the caller requests a different `formatter_class`, the help text displays the default value of
    every non-required argument, i.e. optional arguments with `required=False`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default.

        `argparse.ArgumentDefaultsHelpFormatter` is used unless the caller provides (as keyword or
        positional argument) a `formatter_class` other than the base `argparse.HelpFormatter`.

        """
        super().__init__(*args, **kwargs)
        # `argparse.HelpFormatter` is the base class default, so treat it as "no formatter requested"
        if self.formatter_class is argparse.HelpFormatter:
            self.formatter_class = argparse.ArgumentDefaultsHelpFormatter
        # Prefixes of the server argument groups added, used to build their URL in `parse_args()`
        self.__server_groups: list[str] = []

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if it exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        # `self.error()` exits the program, so the return is only reached for valid paths
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        A path that does not exist yet is accepted when its closest existing parent directory is writable.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if os.access(dst_path, os.W_OK):
                if exists_ok:
                    return dst_path
                self.error(f"'{dst_path}' already exists")
            else:
                self.error(f"'{dst_path}' is not writable")
        # Check if the first parent directory that exists is writable
        for parent_path in dst_path.parents:
            if parent_path.exists():
                if not os.access(parent_path, os.W_OK):
                    self.error(f"'{dst_path}' is not writable")
                break
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value (any callable converting a string into a number).
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            # Callables such as `functools.partial` objects have no `__name__` attribute
            type_name = getattr(value_type, "__name__", repr(value_type))
            self.error(f"invalid {type_name} value: {value}")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        """Extends the parent function by excluding the default value in the help text when not provided.

        Only applied to arguments added with `required=True` and no explicit `default`. Positional arguments
        do not need it: `argparse.ArgumentDefaultsHelpFormatter` already omits the default of positional
        arguments that must be provided.

        Note that arguments added through an argument group (`add_argument_group()`) do not go through this
        method.

        """
        if kwargs.get("required", False):
            kwargs.setdefault("default", argparse.SUPPRESS)
        super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> None:
        """Adds `sqlalchemy.engine.URL` argument, parsed with `sqlalchemy.engine.make_url`.

        If "metavar" is not defined, it is added with "URI" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = make_url
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> None:
        """Adds a numeric argument with constraints on its type and its minimum or maximum value.

        Note that the default value (if defined) is only checked when argparse converts it with `type`, i.e.
        when it is a string and no value is provided in the command line. Non-string defaults (e.g.
        `default=5`) are never checked against the constraints.

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        Raises:
            ArgumentError: If both `min_value` and `max_value` are defined and `min_value > max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constraints when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(self, prefix: str = "", include_database: bool = False, help: str = "") -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Args:
            prefix: Prefix to add to each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include (required) `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        # Argument groups bypass `self.add_argument()`, so suppress the default of the required arguments
        # explicitly to keep "(default: None)" out of the help text
        group.add_argument(
            f"--{prefix}host", required=True, default=argparse.SUPPRESS, metavar="HOST", help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            default=argparse.SUPPRESS,
            type=int,
            metavar="PORT",
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, default=argparse.SUPPRESS, metavar="USER", help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                default=argparse.SUPPRESS,
                metavar="NAME",
                help="database name",
            )
        # Remember the prefix so `parse_args()` can build the corresponding URL
        self.__server_groups.append(prefix)

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug` or `--log LEVEL` (stored in `log_level`, "WARNING" when none is provided).

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 3 mutually exclusive options to set the logging level
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        # The three actions share `log_level`: set the same default on all of them so the first one
        # (`--verbose`, whose own default would be `None`) does not win when no option is given
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        For every server group added with a given `prefix`, a `{prefix}url` attribute of type
        `sqlalchemy.engine.URL` (MySQL) is built from the parsed host, port, user, password and, when
        present, database values.

        Raises:
            SystemExit: Through `self.error()`, if the parsed arguments already contain a `{prefix}url`
                attribute for any of the server groups.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Build and add an sqlalchemy.engine.URL object for every server group added
        for prefix in self.__server_groups:
            # Raise an error rather than overwriting when the URL argument is already present
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            server_url = URL.create(
                "mysql",
                getattr(arguments, f"{prefix}user"),
                getattr(arguments, f"{prefix}password"),
                getattr(arguments, f"{prefix}host"),
                getattr(arguments, f"{prefix}port"),
                getattr(arguments, f"{prefix}database", None),
            )
            setattr(arguments, f"{prefix}url", server_url)
        return arguments