Coverage for src/ensembl/utils/argparse.py: 99%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-05 15:47 +0000

1# See the NOTICE file distributed with this work for additional information 

2# regarding copyright ownership. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15"""Provide an extended version of `argparse.ArgumentParser` with additional functionality. 

16 

17Examples: 

18 

19 >>> from pathlib import Path 

20 >>> from ensembl.utils.argparse import ArgumentParser

21 >>> parser = ArgumentParser(description="Tool description") 

22 >>> parser.add_argument_src_path("--src_file", required=True, help="Path to source file") 

23 >>> parser.add_server_arguments(help="Server to connect to") 

24 >>> args = parser.parse_args() 

25 >>> args 

26 Namespace(host='myserver', port=3826, src_file=PosixPath('/path/to/src_file.txt'), 

27 url=URL('mysql://username@myserver:3826'), user='username') 

28 

29""" 

30 

31from __future__ import annotations 

32 

33__all__ = [ 

34 "ArgumentError", 

35 "ArgumentParser", 

36] 

37 

38import argparse 

39import os 

40from pathlib import Path 

41from typing import Any, Callable 

42 

43from sqlalchemy.engine import make_url, URL 

44 

45from ensembl.utils import StrPath 

46 

47 

class ArgumentError(Exception):
    """Raised when an argument (optional or positional) cannot be created."""

50 

51 

class ArgumentParser(argparse.ArgumentParser):
    """Extends `argparse.ArgumentParser` with additional methods and functionality.

    The default behaviour of the help text will be to display the default values on every non-required
    argument, i.e. optional arguments with `required=False`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Extends the base class to include the information about default argument values by default."""
        super().__init__(*args, **kwargs)
        self.formatter_class = argparse.ArgumentDefaultsHelpFormatter
        # Prefixes of every server argument group added via `add_server_arguments()`; `parse_args()`
        # uses them to build one URL per group
        self.__server_groups: list[str] = []

    def _validate_src_path(self, src_path: StrPath) -> Path:
        """Returns the path if exists and it is readable, raises an error through the parser otherwise.

        Args:
            src_path: File or directory path to check.

        """
        src_path = Path(src_path)
        if not src_path.exists():
            self.error(f"'{src_path}' not found")
        elif not os.access(src_path, os.R_OK):
            self.error(f"'{src_path}' not readable")
        return src_path

    def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path:
        """Returns the path if it is writable, raises an error through the parser otherwise.

        Args:
            dst_path: File or directory path to check.
            exists_ok: Do not raise an error during parsing if the destination path already exists.

        """
        dst_path = Path(dst_path)
        if dst_path.exists():
            if not os.access(dst_path, os.W_OK):
                self.error(f"'{dst_path}' is not writable")
            if exists_ok:
                return dst_path
            self.error(f"'{dst_path}' already exists")
        # The path does not exist yet: it can only be created if the closest existing ancestor is writable
        existing_parent = next((parent for parent in dst_path.parents if parent.exists()), None)
        if (existing_parent is not None) and not os.access(existing_parent, os.W_OK):
            self.error(f"'{dst_path}' is not writable")
        return dst_path

    def _validate_number(
        self,
        value: str,
        value_type: Callable[[str], int | float],
        min_value: int | float | None,
        max_value: int | float | None,
    ) -> int | float:
        """Returns the numeric value if it is of the expected type and it is within the specified range.

        Any failure is reported through the parser's `error()` method.

        Args:
            value: String representation of numeric value to check.
            value_type: Expected type of the numeric value.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        """
        # Check if the string representation can be converted to the expected type
        try:
            result = value_type(value)
        except (TypeError, ValueError):
            self.error(f"invalid {value_type.__name__} value: {value}")
        # Check if numeric value is within range
        if (min_value is not None) and (result < min_value):
            self.error(f"{value} is lower than minimum value ({min_value})")
        if (max_value is not None) and (result > max_value):
            self.error(f"{value} is greater than maximum value ({max_value})")
        return result

    def add_argument(self, *args: Any, **kwargs: Any) -> argparse.Action:
        """Extends the parent function by excluding the default value in the help text when not provided.

        Only applied to arguments added with `required=True` and without an explicit default value.

        Returns:
            The action object created by the parent function.

        """
        if kwargs.get("required", False):
            # A suppressed default is never rendered by `ArgumentDefaultsHelpFormatter`
            kwargs.setdefault("default", argparse.SUPPRESS)
        return super().add_argument(*args, **kwargs)

    def add_argument_src_path(self, *args: Any, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it exists and it is readable at parsing time.

        If "metavar" is not defined, it is added with "PATH" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = self._validate_src_path
        self.add_argument(*args, **kwargs)

    def add_argument_dst_path(self, *args: Any, exists_ok: bool = True, **kwargs: Any) -> None:
        """Adds `pathlib.Path` argument, checking if it is writable at parsing time.

        If "metavar" is not defined it is added with "PATH" as value to improve help text readability.

        Args:
            exists_ok: Do not raise an error if the destination path already exists.

        """
        kwargs.setdefault("metavar", "PATH")
        kwargs["type"] = lambda x: self._validate_dst_path(x, exists_ok)
        self.add_argument(*args, **kwargs)

    def add_argument_url(self, *args: Any, **kwargs: Any) -> None:
        """Adds `sqlalchemy.engine.URL` argument.

        If "metavar" is not defined it is added with "URI" as value to improve help text readability.

        """
        kwargs.setdefault("metavar", "URI")
        kwargs["type"] = make_url
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_numeric_argument(
        self,
        *args: Any,
        type: Callable[[str], int | float] = float,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
        **kwargs: Any,
    ) -> None:
        """Adds a numeric argument with constraints on its type and its minimum or maximum value.

        Note that the constraints are enforced by the argument's type function, so a default value is only
        checked when `argparse` passes it through that function (string defaults of arguments that are not
        provided in the command line).

        Args:
            type: Type to convert the argument value to when parsing.
            min_value: Minimum value constraint. If `None`, no minimum value constraint.
            max_value: Maximum value constraint. If `None`, no maximum value constraint.

        Raises:
            ArgumentError: If both limits are defined and `min_value` is greater than `max_value`.

        """
        # If both minimum and maximum values are defined, ensure min_value <= max_value
        if (min_value is not None) and (max_value is not None) and (min_value > max_value):
            raise ArgumentError("minimum value is greater than maximum value")
        # Add lambda function to check numeric constraints when parsing argument
        kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
        self.add_argument(*args, **kwargs)

    # pylint: disable=redefined-builtin
    def add_server_arguments(self, prefix: str = "", include_database: bool = False, help: str = "") -> None:
        """Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
        and `--password` (optional).

        Note that the parser will assume this is a MySQL server.

        Args:
            prefix: Prefix to add to each argument, e.g. if prefix is `src_`, the arguments will be
                `--src_host`, etc.
            include_database: Include `--database` argument.
            help: Description message to include for this set of arguments.

        """
        group = self.add_argument_group(f"{prefix}server connection arguments", description=help)
        # The group's `add_argument()` is not the one overridden by this class, so the default value of
        # the required arguments has to be suppressed explicitly to keep it out of the help text
        group.add_argument(
            f"--{prefix}host", required=True, default=argparse.SUPPRESS, metavar="HOST", help="host name"
        )
        group.add_argument(
            f"--{prefix}port",
            required=True,
            default=argparse.SUPPRESS,
            type=int,
            metavar="PORT",
            help="port number",
        )
        group.add_argument(
            f"--{prefix}user", required=True, default=argparse.SUPPRESS, metavar="USER", help="user name"
        )
        group.add_argument(f"--{prefix}password", metavar="PWD", help="host password")
        if include_database:
            group.add_argument(
                f"--{prefix}database",
                required=True,
                default=argparse.SUPPRESS,
                metavar="NAME",
                help="database name",
            )
        # Remember the prefix so `parse_args()` can build the URL for this group
        self.__server_groups.append(prefix)

    def add_log_arguments(self, add_log_file: bool = False) -> None:
        """Adds the usual set of arguments required to set and initialise a logging system.

        The current set includes a mutually exclusive group for the default logging level: `--verbose`,
        `--debug` or `--log LEVEL`. All three store their value in `log_level`.

        Args:
            add_log_file: Add arguments to allow storing messages into a file, i.e. `--log_file` and
                `--log_file_level`.

        """
        # Define the list of log levels available
        log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        # NOTE: from 3.11 this list can be changed to: logging.getLevelNamesMapping().keys()
        # Create logging arguments group
        group = self.add_argument_group("logging arguments")
        # Add 3 mutually exclusive options to set the logging level
        subgroup = group.add_mutually_exclusive_group()
        subgroup.add_argument(
            "-v",
            "--verbose",
            action="store_const",
            const="INFO",
            dest="log_level",
            help="verbose mode, i.e. 'INFO' log level",
        )
        subgroup.add_argument(
            "--debug",
            action="store_const",
            const="DEBUG",
            dest="log_level",
            help="debugging mode, i.e. 'DEBUG' log level",
        )
        subgroup.add_argument(
            "--log",
            choices=log_levels,
            type=str.upper,
            default="WARNING",
            metavar="LEVEL",
            dest="log_level",
            help="level of the events to track: %(choices)s",
        )
        subgroup.set_defaults(log_level="WARNING")
        if add_log_file:
            # Add log file-related arguments
            group.add_argument(
                "--log_file",
                type=lambda x: self._validate_dst_path(x, exists_ok=True),
                metavar="PATH",
                default=None,
                help="log file path",
            )
            group.add_argument(
                "--log_file_level",
                choices=log_levels,
                type=str.upper,
                default="DEBUG",
                metavar="LEVEL",
                help="level of the events to track in the log file: %(choices)s",
            )

    def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:  # type: ignore[override]
        """Extends the parent function by adding a new URL argument for every server group added.

        The new argument is named `<prefix>url` and its type is `sqlalchemy.engine.URL`. It is built from
        the user, password, host, port and (if present) database arguments of the server group, using
        "mysql" as the driver name. An error is raised through the parser if `<prefix>url` already exists
        among the parsed arguments.

        """
        arguments = super().parse_args(*args, **kwargs)
        # Build and add an sqlalchemy.engine.URL object for every server group added
        for prefix in self.__server_groups:
            # Raise an error rather than overwriting when the URL argument is already present
            if f"{prefix}url" in arguments:
                self.error(f"argument '{prefix}url' is already present")
            server_url = URL.create(
                "mysql",
                getattr(arguments, f"{prefix}user"),
                getattr(arguments, f"{prefix}password"),
                getattr(arguments, f"{prefix}host"),
                getattr(arguments, f"{prefix}port"),
                # The database argument only exists if the group was added with `include_database=True`
                getattr(arguments, f"{prefix}database", None),
            )
            setattr(arguments, f"{prefix}url", server_url)
        return arguments