Coverage for src / ensembl / utils / argparse.py: 100%

100 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:45 +0000

1# See the NOTICE file distributed with this work for additional information 

2# regarding copyright ownership. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""Provide an extended version of `argparse.ArgumentParser` with additional functionality. 

16 

17Examples: 

18 

19 >>> from pathlib import Path 

20 >>> from ensembl.utils.argparse import ArgumentParser

21 >>> parser = ArgumentParser(description="Tool description") 

22 >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file") 

23 >>> parser.add_server_arguments(help="Server to connect to") 

24 >>> args = parser.parse_args() 

25 >>> args 

26 Namespace(host='localhost', port=3826, src_file=PosixPath('/path/to/src_file.txt'), 

27 url=URL('mysql://username@localhost:3826'), user='username') 

28 

29""" 

30 

31from __future__ import annotations 

32 

33__all__ = [ 

34 "ArgumentError", 

35 "ArgumentParser", 

36] 

37 

38import argparse 

39import os 

40from pathlib import Path 

41import re 

42from typing import Any, Callable 

43 

44from sqlalchemy.engine import make_url, URL 

45 

46from ensembl.utils import StrPath 

47 

48 

class ArgumentError(Exception):
    """Raised when an argument (optional or positional) cannot be created."""

51 

52 

class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    Unless a different formatter class is requested, the help text will display the default values on
    every non-required argument, i.e. optional arguments with `required=False`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default.

        `argparse.ArgumentDefaultsHelpFormatter` is only used as the default formatter: a
        `formatter_class` explicitly provided by the caller is respected.

        """
        # "formatter_class" is the sixth positional parameter of argparse.ArgumentParser, so only
        # inject the default when it has not been provided positionally
        if len(args) < 6:
            kwargs.setdefault("formatter_class", argparse.ArgumentDefaultsHelpFormatter)
        super().__init__(*args, **kwargs)

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if os.access(dst_path, os.W_OK):
                if exists_ok:
                    return dst_path
                self.error(f"'{dst_path}' already exists")
            else:
                self.error(f"'{dst_path}' is not writable")
        # Check if the first parent directory that exists is writable
        for parent_path in dst_path.parents:
            if parent_path.exists():
                if not os.access(parent_path, os.W_OK):
                    self.error(f"'{dst_path}' is not writable")
                break
        else:
            # None of the parent directories exists, so the path cannot be created
            self.error(f"'{dst_path}' is not writable")
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Any failure is reported through the parser's `error()` method.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            # Not every callable has a "__name__" (e.g. functools.partial objects)
            type_name = getattr(value_type, "__name__", repr(value_type))
            self.error(f"invalid {type_name} value: {value}")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        """Extends the parent function by excluding the default value in the help text when not provided.

        Only applied to arguments added with `required=True` and no explicit default value: their
        default is set to `argparse.SUPPRESS`.

        """
        if kwargs.get("required", False):
            kwargs.setdefault("default", argparse.SUPPRESS)
        super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.
        Any "type" provided is replaced by the source path validation.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined it is added with "PATH" as value to improve help text readability.
        Any "type" provided is replaced by the destination path validation.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> None:
        """Adds `sqlalchemy.engine.URL` argument.

        If "metavar" is not defined it is added with "URI" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = make_url
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> None:
        """Adds a numeric argument with constraints on its type and its minimum or maximum value.

        Note that the default value (if defined) is not checked unless the argument is an optional argument
        and no value is provided in the command line.

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        Raises:
            ArgumentError: If both limits are defined and `min_value` is greater than `max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constraints when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(
        self, prefix: str = "", include_database: bool = False, help: str | None = None
    ) -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Warning:
            Avoid passing ``--password`` directly on the command line as it will be visible in the
            process list and shell history. Use an environment variable or an interactive prompt via
            ``getpass`` instead.

        Args:
            prefix: Prefix to add to each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        # The argument group does not go through this class' add_argument(), so the default value of
        # the required arguments is suppressed explicitly
        group.add_argument(
            f"--{prefix}host", required=True, metavar="HOST", default=argparse.SUPPRESS, help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            type=int,
            metavar="PORT",
            default=argparse.SUPPRESS,
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, metavar="USER", default=argparse.SUPPRESS, help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                metavar="NAME",
                default=argparse.SUPPRESS,
                help="database name",
            )

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug`, `--quiet` or `--log LEVEL`. All of them store their value in `log_level`, which
        defaults to "WARNING".

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 4 mutually exclusive options to set the logging level
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--quiet",
            action="store_const",
            const="CRITICAL",
            dest="log_level",
            help="quiet mode, i.e. 'CRITICAL' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        A server group is detected by an argument named `<prefix>host` that is accompanied by the
        corresponding `<prefix>user`, `<prefix>password` and `<prefix>port` arguments. For each group
        found, a `<prefix>url` attribute of type `sqlalchemy.engine.URL` is added to the returned
        namespace. An error is raised through the parser if `<prefix>url` is already present for a
        server group.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Build and add an sqlalchemy.engine.URL object for every server group added
        pattern = re.compile(r"([\w-]*)host$")
        server_prefixes = [x.group(1) for x in map(pattern.match, vars(arguments)) if x]
        for prefix in server_prefixes:
            # Only the attribute lookups are guarded, so an unrelated AttributeError is never hidden
            try:
                user = getattr(arguments, f"{prefix}user")
                password = getattr(arguments, f"{prefix}password")
                host = getattr(arguments, f"{prefix}host")
                port = getattr(arguments, f"{prefix}port")
            except AttributeError:
                # Not a database server host argument
                continue
            # Raise an error rather than overwriting when the URL argument is already present. This is
            # checked only for actual server groups, so unrelated "*host" and "*url" arguments can coexist
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            database = getattr(arguments, f"{prefix}database", None)
            server_url = URL.create("mysql", user, password, host, port, database)
            setattr(arguments, f"{prefix}url", server_url)
        return arguments