Coverage for src/ensembl/utils/argparse.py: 99%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-06 14:10 +0000

1# See the NOTICE file distributed with this work for additional information 

2# regarding copyright ownership. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""Provide an extended version of `argparse.ArgumentParser` with additional functionality. 

16 

17Examples: 

18 

19 >>> from pathlib import Path 

20 >>> from ensembl.utils.argparse import ArgumentParser 

21 >>> parser = ArgumentParser(description="Tool description") 

22 >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file") 

23 >>> parser.add_server_arguments(help="Server to connect to") 

24 >>> args = parser.parse_args() 

25 >>> args 

26 Namespace(host='localhost', password=None, port=3826, src_file=PosixPath('/path/to/src_file.txt'), 

27 url=URL('mysql://username@localhost:3826'), user='username') 

28 

29""" 

30 

31from __future__ import annotations 

32 

# Public API of this module
__all__ = ["ArgumentError", "ArgumentParser"]

37 

38import argparse 

39import os 

40from pathlib import Path 

41import re 

42from typing import Any, Callable 

43 

44from sqlalchemy.engine import make_url, URL 

45 

46from ensembl.utils import StrPath 

47 

48 

class ArgumentError(Exception):
    """Error raised while creating an argument (optional or positional), e.g. with inconsistent constraints."""

51 

52 

class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    The default behaviour of the help text will be to display the default values on every non-required
    argument, i.e. optional arguments with `required=False`. A different help formatter can still be
    selected by passing `formatter_class` explicitly.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default.

        `argparse.ArgumentDefaultsHelpFormatter` is used as help formatter unless the caller provides
        its own `formatter_class` (either as keyword or as positional argument).

        """
        # `formatter_class` is the sixth positional parameter of `argparse.ArgumentParser`: only inject
        # our default when the caller has not already supplied it positionally or by keyword
        if len(args) < 6:
            kwargs.setdefault("formatter_class", argparse.ArgumentDefaultsHelpFormatter)
        super().__init__(*args, **kwargs)

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if os.access(dst_path, os.W_OK):
                if exists_ok:
                    return dst_path
                self.error(f"'{dst_path}' already exists")
            else:
                self.error(f"'{dst_path}' is not writable")
        # Check if the first parent directory that exists is writable
        for parent_path in dst_path.parents:
            if parent_path.exists():
                if not os.access(parent_path, os.W_OK):
                    self.error(f"'{dst_path}' is not writable")
                break
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value (any callable converting a string).
            min_value: Minimum value constrain. If `None`, no minimum value constrain.
            max_value: Maximum value constrain. If `None`, no maximum value constrain.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            # Not every callable has `__name__` (e.g. `functools.partial`), so fall back to its repr
            # like argparse itself does
            type_name = getattr(value_type, "__name__", repr(value_type))
            self.error(f"invalid {type_name} value: {value}")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Extends the parent function by excluding the default value in the help text when not provided.

        Only applied to optional arguments with `required=True` and no explicit default value: their
        default is set to `argparse.SUPPRESS`. Positional arguments are left untouched.

        Returns:
            The `argparse.Action` created by the parent function.

        """
        if kwargs.get("required", False):
            kwargs.setdefault("default", argparse.SUPPRESS)
        return super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        Returns:
            The `argparse.Action` created.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        return self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> argparse.Action:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined it is added with "PATH" as value to improve help text readability.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        Returns:
            The `argparse.Action` created.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        return self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Adds `sqlalchemy.engine.URL` argument.

        If "metavar" is not defined it is added with "URI" as value to improve help text readability.

        Returns:
            The `argparse.Action` created.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = make_url
        return self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> argparse.Action:
        """Adds a numeric argument with constrains on its type and its minimum or maximum value.

        Note that the default value (if defined) is not checked unless the argument is an optional argument
        and no value is provided in the command line.

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constrain. If `None`, no minimum value constrain.
            max_value: Maximum value constrain. If `None`, no maximum value constrain.

        Returns:
            The `argparse.Action` created.

        Raises:
            ArgumentError: If both `min_value` and `max_value` are defined and `min_value > max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constrains when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        return self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(
        self, prefix: str = "", include_database: bool = False, help: str | None = None
    ) -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Args:
            prefix: Prefix to add the each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        group.add_argument(
            f"--{prefix}host", required=True, metavar="HOST", default=argparse.SUPPRESS, help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            type=int,
            metavar="PORT",
            default=argparse.SUPPRESS,
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, metavar="USER", default=argparse.SUPPRESS, help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                metavar="NAME",
                default=argparse.SUPPRESS,
                help="database name",
            )

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug` or `--log LEVEL`.

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 3 mutually exclusive options to set the logging level
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        # All three options share `log_level` as destination, so give it a single default
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        A server group is a set of parsed arguments `{prefix}host`, `{prefix}port`, `{prefix}user` and
        `{prefix}password` (plus `{prefix}database` if present), e.g. as added by
        `add_server_arguments()`. For every server group, a new `{prefix}url` argument of type
        `sqlalchemy.engine.URL` (MySQL driver) is added. An error is raised through the parser if a
        server group already has a `{prefix}url` argument. Any other argument whose name ends in
        "host" is ignored.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Every argument name ending in "host" is a candidate server group, identified by its prefix
        pattern = re.compile(r"([\w-]*)host$")
        server_prefixes = [match.group(1) for match in map(pattern.match, vars(arguments)) if match]
        for prefix in server_prefixes:
            try:
                user = getattr(arguments, f"{prefix}user")
                password = getattr(arguments, f"{prefix}password")
                host = getattr(arguments, f"{prefix}host")
                port = getattr(arguments, f"{prefix}port")
            except AttributeError:
                # Not a database server host argument: leave any "{prefix}url" argument untouched
                continue
            # Raise an error rather than overwriting when the URL argument is already present
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            database = getattr(arguments, f"{prefix}database", None)
            server_url = URL.create("mysql", user, password, host, port, database)
            setattr(arguments, f"{prefix}url", server_url)
        return arguments