Coverage for src/scrilla/files.py: 55% (217 statements)

coverage.py v6.4.2, created at 2022-07-18 18:14 +0000

# This file is part of scrilla: https://github.com/chinchalinchin/scrilla.

# scrilla is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3
# as published by the Free Software Foundation.

# scrilla is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with scrilla. If not, see <https://www.gnu.org/licenses/>
# or <https://github.com/chinchalinchin/scrilla/blob/develop/main/LICENSE>.


"""
`files` is in charge of all application file handling. In addition, this module handles requests for large csv files retrieved from external services. The metadata files from 'AlphaVantage' and 'Quandl' are returned as zipped csv files. The functions within this module perform all the tasks necessary for parsing these responses for the application.
"""

import os
import io
import json
import csv
import zipfile
from typing import Any, Dict, Union

import requests

from scrilla import settings
from scrilla.cloud import aws
from scrilla.static import keys, constants, formats
from scrilla.util import outputter, helper, errors


logger = outputter.Logger("scrilla.files", settings.LOG_LEVEL)

static_tickers_blob, static_econ_blob, static_crypto_blob = None, None, None


def memory_json_skeleton() -> dict:
    return {
        'static': False,
        'cache': {
            'sqlite': {
                'prices': False,
                'interest': False,
                'correlations': False,
                'profile': False
            },
            'dynamodb': {
                'prices': False,
                'interest': False,
                'correlations': False,
                'profile': False
            }
        }
    }


def save_memory_json(persist: Union[dict, None] = None):
    if persist is None or not isinstance(persist, dict):
        return
    save_file(persist, settings.MEMORY_FILE)


def get_memory_json():
    if os.path.isfile(settings.MEMORY_FILE):
        memory_json = load_file(settings.MEMORY_FILE)
        return memory_json
    return memory_json_skeleton()
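
# NOTE: a minimal sketch of the memory-file round trip, assuming `settings.MEMORY_FILE`
# points at a writable JSON path; not part of the original module:
#   memory = get_memory_json()        # skeleton on first run, cached file afterwards
#   memory['cache']['sqlite']['prices'] = True
#   save_memory_json(memory)          # persisted via save_file() below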


def load_file(file_name: str) -> Any:
    """
    Infers the file extension from the provided `file_name` and parses the file appropriately.
    """
    ext = file_name.split('.')[-1]
    with open(file_name, 'r') as infile:
        if ext == "json":
            return json.load(infile)
        return infile.read()
        # TODO: implement other file loading extensions


def save_file(file_to_save: Dict[str, Any], file_name: str) -> bool:
    ext = file_name.split('.')[-1]
    try:
        with open(file_name, 'w') as outfile:
            if ext == "json":
                json.dump(file_to_save, outfile)
            elif ext == "csv":
                # TODO: assume input is dict since all functions in library return dict.
                writer = csv.DictWriter(outfile, file_to_save.keys())
                writer.writeheader()
                # write the dict's values as the single row beneath the header;
                # without this the csv branch saved only the header
                writer.writerow(file_to_save)
            # TODO: implement other file saving extensions.
        return True
    except OSError as e:
        logger.error(e, 'save_file')
        return False
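
# NOTE: a hedged usage sketch (hypothetical file name), round-tripping a dict through
# save_file()/load_file(); not part of the original module:
#   payload = {'AAPL': 0.5, 'MSFT': 0.5}
#   save_file(file_to_save=payload, file_name='/tmp/allocation.json')
#   assert load_file('/tmp/allocation.json') == payload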


def set_credentials(value: str, which_key: str) -> bool:
    file_name = os.path.join(
        settings.COMMON_DIR, f'{which_key}.{settings.FILE_EXT}')
    if settings.FILE_EXT == 'json':
        key_dict = {which_key: value}
        return save_file(file_to_save=key_dict, file_name=file_name)


def get_credentials(which_key: str) -> str:
    file_name = os.path.join(
        settings.COMMON_DIR, f'{which_key}.{settings.FILE_EXT}')
    return load_file(file_name=file_name)
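
# NOTE: a minimal sketch of the credential helpers, assuming `settings.FILE_EXT == 'json'`
# and a hypothetical key name; not part of the original module:
#   set_credentials(value='my-api-key', which_key='ALPHA_VANTAGE_KEY')
#   creds = get_credentials('ALPHA_VANTAGE_KEY')   # -> {'ALPHA_VANTAGE_KEY': 'my-api-key'}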


def parse_csv_response_column(column: int, url: str, firstRowHeader: str = None, savefile: str = None, zipped: str = None):
    """
    Dumps a column from a CSV located at `url` into a JSON file located at `savefile`. The csv file may be zipped, in which case the function needs to be made aware of the filename within the zipfile through the parameter `zipped`.

    Parameters
    ----------
    1. **column**: ``int``
        Index of the column you wish to retrieve from the response.
    2. **url**: ``str``
        The url, already formatted with the appropriate query and key, that will respond with the csv file, zipped or unzipped (see the `zipped` argument for more info), you wish to parse.
    3. **firstRowHeader**: ``str``
        *Optional*. Name of the header for the column you wish to parse. If specified, the parsed response will ignore the row header. Do not include if you wish to have the row header in the return result.
    4. **savefile**: ``str``
        *Optional*. The absolute path of the file you wish to save the parsed response column to.
    5. **zipped**: ``str``
        *Optional*. If the response returns a zip file, this argument needs to be set equal to the name of the file within the zipped archive you wish to parse.
    """
    col, big_mother = [], []

    with requests.Session() as s:
        download = s.get(url)

        if zipped is not None:
            zipdata = io.BytesIO(download.content)
            unzipped = zipfile.ZipFile(zipdata)
            with unzipped.open(zipped, 'r') as f:
                for line in f:
                    big_mother.append(
                        helper.replace_troublesome_chars(line.decode("utf-8")))
            cr = csv.reader(big_mother, delimiter=',')

        else:
            decoded_content = download.content.decode('utf-8')
            cr = csv.reader(decoded_content.splitlines(), delimiter=',')

    # the `with` block closes the session; the explicit s.close() was redundant

    for row in cr:
        if row[column] != firstRowHeader:
            col.append(row[column])

    if savefile is not None:
        ext = savefile.split('.')[-1]
        with open(savefile, 'w') as outfile:
            if ext == "json":
                json.dump(col, outfile)

    return col
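
# NOTE: a hedged usage sketch with a hypothetical endpoint; pulls the first column of a
# remote csv, skipping its 'symbol' header row, and caches it as JSON; not part of the
# original module:
#   tickers = parse_csv_response_column(
#       column=0,
#       url='https://example.com/listing.csv',
#       firstRowHeader='symbol',
#       savefile='/tmp/tickers.json')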


def init_static_data():
    """
    Initializes the three static files defined in settings: `scrilla.settings.STATIC_TICKERS_FILE`, `scrilla.settings.STATIC_CRYPTO_FILE` and `scrilla.settings.STATIC_ECON_FILE`. The data for these files is retrieved from the service managers. While this function blurs the lines between file management and service management, it has been included in the `files.py` module rather than the `services.py` module due to the unique response types of static metadata. All metadata is returned as a csv or zipped csvs. These responses require specialized functions. Moreover, these files should only be initialized the first time the application executes. Subsequent executions will refer to their cached versions residing in the local filesystem.
    """

    memory = get_memory_json()

    if not memory['static']:
        global static_tickers_blob
        global static_econ_blob
        global static_crypto_blob

        # grab ticker symbols and store in STATIC_DIR
        if (
            settings.PRICE_MANAGER == "alpha_vantage" and
            not os.path.isfile(settings.STATIC_TICKERS_FILE)
        ):
            service_map = keys.keys["SERVICES"]["PRICES"]["ALPHA_VANTAGE"]["MAP"]
            logger.debug(
                f'Missing {settings.STATIC_TICKERS_FILE}, querying \'{settings.PRICE_MANAGER}\'', 'init_static_data')

            # TODO: services calls should be in services.py! need to put this and the helper method
            # into services.py in the future.
            query = f'{service_map["PARAMS"]["FUNCTION"]}={service_map["ARGUMENTS"]["EQUITY_LISTING"]}'
            url = f'{settings.AV_URL}?{query}&{service_map["PARAMS"]["KEY"]}={settings.av_key()}'
            static_tickers_blob = parse_csv_response_column(column=0, url=url, savefile=settings.STATIC_TICKERS_FILE,
                                                            firstRowHeader=service_map['KEYS']['EQUITY']['HEADER'])

        # grab crypto symbols and store in STATIC_DIR
        if (
            settings.PRICE_MANAGER == "alpha_vantage" and
            not os.path.isfile(settings.STATIC_CRYPTO_FILE)
        ):
            service_map = keys.keys["SERVICES"]["PRICES"]["ALPHA_VANTAGE"]["MAP"]
            logger.debug(
                f'Missing {settings.STATIC_CRYPTO_FILE}, querying \'{settings.PRICE_MANAGER}\'.', 'init_static_data')
            url = settings.AV_CRYPTO_LIST
            static_crypto_blob = parse_csv_response_column(column=0, url=url, savefile=settings.STATIC_CRYPTO_FILE,
                                                           firstRowHeader=service_map['KEYS']['CRYPTO']['HEADER'])

        # grab economic indicator symbols and store in STATIC_DIR
        if (
            settings.STAT_MANAGER == "quandl" and
            not os.path.isfile(settings.STATIC_ECON_FILE)
        ):
            service_map = keys.keys["SERVICES"]["STATISTICS"]["QUANDL"]["MAP"]

            logger.debug(
                f'Missing {settings.STATIC_ECON_FILE}, querying \'{settings.STAT_MANAGER}\'.', 'init_static_data')

            query = f'{service_map["PATHS"]["FRED"]}/{service_map["PARAMS"]["METADATA"]}'
            url = f'{settings.Q_META_URL}/{query}?{service_map["PARAMS"]["KEY"]}={settings.Q_KEY}'
            static_econ_blob = parse_csv_response_column(column=0, url=url, savefile=settings.STATIC_ECON_FILE,
                                                         firstRowHeader=service_map["KEYS"]["HEADER"],
                                                         zipped=service_map["KEYS"]["ZIPFILE"])

        memory['static'] = True
        save_memory_json(memory)

    else:
        logger.debug('Static data already initialized!', 'init_static_data')


def get_static_data(static_type):
    """
    Retrieves static data saved in the local file system.

    Parameters
    ----------
    1. **static_type**: ``str``
        A string corresponding to the type of static data to be retrieved. The types can be statically accessed through the `scrilla.static.keys.keys['ASSETS']` dictionary.
    """
    path, blob = None, None
    global static_crypto_blob
    global static_econ_blob
    global static_tickers_blob

    if static_type == keys.keys['ASSETS']['CRYPTO']:
        if static_crypto_blob is not None:
            blob = static_crypto_blob
        else:
            path = settings.STATIC_CRYPTO_FILE

    elif static_type == keys.keys['ASSETS']['EQUITY']:
        if static_tickers_blob:
            blob = static_tickers_blob
        else:
            path = settings.STATIC_TICKERS_FILE

    elif static_type == keys.keys['ASSETS']['STAT']:
        if static_econ_blob:
            blob = static_econ_blob
        else:
            path = settings.STATIC_ECON_FILE

    else:
        return None

    if blob is not None:
        logger.verbose(
            f'Found in-memory {static_type} symbols.', 'get_static_data')
        return blob

    if path is not None:
        if not os.path.isfile(path):
            init_static_data()
        logger.verbose(
            f'Loading in cached {static_type} symbols.', 'get_static_data')

        ext = path.split('.')[-1]

        with open(path, 'r') as infile:
            if ext == "json":
                symbols = json.load(infile)
            # TODO: implement other file loading exts

        if static_type == keys.keys['ASSETS']['CRYPTO']:
            static_crypto_blob = symbols
        elif static_type == keys.keys['ASSETS']['EQUITY']:
            static_tickers_blob = symbols
        elif static_type == keys.keys['ASSETS']['STAT']:
            static_econ_blob = symbols
        return symbols

    return None
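
# NOTE: a minimal sketch of retrieving the cached symbol universe, assuming the static
# files have been initialized (init_static_data() is invoked on a cache miss); not part
# of the original module:
#   cryptos = get_static_data(keys.keys['ASSETS']['CRYPTO'])
#   equities = get_static_data(keys.keys['ASSETS']['EQUITY'])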

# NOTE: output from get_overlapping_symbols:
# OVERLAP = ['ABT', 'AC', 'ADT', 'ADX', 'AE', 'AGI', 'AI', 'AIR', 'AMP', 'AVT', 'BCC', 'BCD', 'BCH', 'BCX', 'BDL', 'BFT', 'BIS', 'BLK', 'BQ', 'BRX',
# 'BTA', 'BTG', 'CAT', 'CMP', 'CMT', 'CNX', 'CTR', 'CURE', 'DAR', 'DASH', 'DBC', 'DCT', 'DDF', 'DFS', 'DTB', 'DYN', 'EBTC', 'ECC', 'EFL', 'ELA', 'ELF',
# 'EMB', 'ENG', 'ENJ', 'EOS', 'EOT', 'EQT', 'ERC', 'ETH', 'ETN', 'EVX', 'EXP', 'FCT', 'FLO', 'FLT', 'FTC', 'FUN', 'GAM', 'GBX', 'GEO', 'GLD', 'GNT',
# 'GRC', 'GTO', 'INF', 'INS', 'INT', 'IXC', 'KIN', 'LBC', 'LEND', 'LTC', 'MAX', 'MCO', 'MEC', 'MED', 'MGC', 'MINT', 'MLN', 'MNE', 'MOD', 'MSP', 'MTH',
# 'MTN', 'MUE', 'NAV', 'NEO', 'NEOS', 'NET', 'NMR', 'NOBL', 'NXC', 'OCN', 'OPT', 'PBT', 'PING', 'PPC', 'PPT', 'PRG', 'PRO', 'PST', 'PTC', 'QLC', 'QTUM',
# 'R', 'RDN', 'REC', 'RVT', 'SALT', 'SAN', 'SC', 'SKY', 'SLS', 'SPR', 'SNX', 'STK', 'STX', 'SUB', 'SWT', 'THC', 'TKR', 'TRC', 'TRST', 'TRUE', 'TRX',
# 'TX', 'UNB', 'VERI', 'VIVO', 'VOX', 'VPN', 'VRM', 'VRS', 'VSL', 'VTC', 'VTR', 'WDC', 'WGO', 'WTT', 'XEL', 'NEM', 'ZEN']

# TODO: need some way to distinguish between overlap.


def get_overlapping_symbols(equities=None, cryptos=None):
    """
    Returns an array of symbols which are contained in both the `scrilla.settings.STATIC_TICKERS_FILE` and `scrilla.settings.STATIC_CRYPTO_FILE`, i.e. ticker symbols which have both a tradeable equity and a tradeable crypto asset.
    """
    if equities is None:
        equities = list(get_static_data(keys.keys['ASSETS']['EQUITY']))
    if cryptos is None:
        cryptos = list(get_static_data(keys.keys['ASSETS']['CRYPTO']))
    overlap = []
    for crypto in cryptos:
        if crypto in equities:
            overlap.append(crypto)
    return overlap


def get_asset_type(symbol: str) -> str:
    """
    Returns the asset type of the supplied ticker symbol.

    Output
    ------
    ``str``
        Represents the asset type of the symbol. Types are statically accessible through the `scrilla.static.keys.keys['ASSETS']` dictionary.
    """
    symbols = list(get_static_data(keys.keys['ASSETS']['CRYPTO']))
    overlap = get_overlapping_symbols(cryptos=symbols)

    if symbol not in overlap:
        if symbol in symbols:
            return keys.keys['ASSETS']['CRYPTO']

        # if other asset types are introduced, then uncomment these lines
        # and add new asset type to conditional. Keep in mind the static
        # equity data is HUGE.
        # symbols = list(get_static_data(keys['ASSETS']['EQUITY']))
        # if symbol in symbols:
        #     return keys['ASSETS']['EQUITY']
        # return None
        return keys.keys['ASSETS']['EQUITY']

    # default to equity for overlap until a better method is determined.
    return keys.keys['ASSETS']['EQUITY']
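
# NOTE: a minimal sketch of asset-type resolution; a symbol found only in the crypto
# listing resolves to crypto, while overlapping symbols such as 'ETH' currently default
# to equity (see the TODO above); not part of the original module:
#   get_asset_type('BTC')    # -> keys.keys['ASSETS']['CRYPTO']
#   get_asset_type('AAPL')   # -> keys.keys['ASSETS']['EQUITY']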


def get_watchlist() -> list:
    """
    Description
    -----------
    Retrieves the list of watchlisted equity ticker symbols saved in /data/common/watchlist.json.
    """
    logger.debug('Loading in Watchlist symbols.', 'get_watchlist')

    if os.path.isfile(settings.COMMON_WATCHLIST_FILE):
        logger.debug('Watchlist found.', 'get_watchlist')
        ext = settings.COMMON_WATCHLIST_FILE.split('.')[-1]
        with open(settings.COMMON_WATCHLIST_FILE, 'r') as infile:
            if ext == "json":
                watchlist = json.load(infile)
                logger.verbose(
                    'Watchlist loaded in JSON format.', 'get_watchlist')
            # TODO: implement other file loading exts
    else:
        logger.error('Watchlist not found.', 'get_watchlist')
        watchlist = []

    return watchlist


def add_watchlist(new_tickers: list) -> None:
    """
    Description
    -----------
    Retrieves the list of watchlisted equity ticker symbols saved in /data/common/watchlist.json and then appends to it the list of tickers supplied as arguments. After appending, the list is sorted in alphabetical order. The tickers to add must exist in the /data/static/tickers.json file in order to be added to the watchlist, i.e. the tickers must have price histories that can be retrieved (the static file tickers.json contains a list of all equities with retrievable price histories).
    """
    logger.debug('Saving tickers to Watchlist', 'add_watchlist')

    current_tickers = get_watchlist()
    all_tickers = get_static_data(keys.keys['ASSETS']['EQUITY'])

    for ticker in new_tickers:
        if ticker not in current_tickers and ticker in all_tickers:
            logger.debug(
                f'New ticker being added to Watchlist: {ticker}', 'add_watchlist')
            current_tickers.append(ticker)

    current_tickers = sorted(current_tickers)

    ext = settings.COMMON_WATCHLIST_FILE.split('.')[-1]
    with open(settings.COMMON_WATCHLIST_FILE, 'w+') as outfile:
        if ext == "json":
            json.dump(current_tickers, outfile)
        # TODO: implement other file extensions
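
# NOTE: a hedged usage sketch; only tickers present in the static equity listing are
# persisted, and the watchlist is kept sorted; not part of the original module:
#   add_watchlist(['MSFT', 'AAPL', 'NOT_A_TICKER'])
#   get_watchlist()   # -> ['AAPL', 'MSFT'] (assuming an empty watchlist beforehand)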


def save_allocation(allocation, portfolio, file_name, investment=None, latest_prices=None):
    save_format = formats.format_allocation(
        allocation=allocation, portfolio=portfolio, investment=investment, latest_prices=latest_prices)
    save_file(file_to_save=save_format, file_name=file_name)


def save_frontier(portfolio, frontier, file_name, investment=None, latest_prices=None):
    save_format = formats.format_frontier(
        portfolio=portfolio, frontier=frontier, investment=investment, latest_prices=latest_prices)
    save_file(file_to_save=save_format, file_name=file_name)


def save_moving_averages(tickers, averages_output, file_name):
    save_format = formats.format_moving_averages(
        tickers=tickers, averages_output=averages_output)
    save_file(file_to_save=save_format, file_name=file_name)


def save_correlation_matrix(tickers, correlation_matrix, file_name):
    save_format = formats.format_correlation_matrix(
        tickers=tickers, correlation_matrix=correlation_matrix)
    save_file(file_to_save=save_format, file_name=file_name)


def clear_directory(directory: str, retain: bool = True) -> bool:
    """
    Wipes a directory of files without deleting the directory itself.

    Parameters
    ----------
    1. **directory**: ``str``
        Path of the directory to be cleared.

    2. **retain**: ``bool``
        If set to True, the method will skip files named '.gitkeep' within the directory, i.e. version control configuration files, and keep the directory structure intact.
    """
    try:
        filelist = list(os.listdir(directory))
        for f in filelist:
            filename = os.path.basename(f)
            if retain and filename == constants.constants['KEEP_FILE']:
                continue
            os.remove(os.path.join(directory, f))
        return True
    except OSError as e:
        logger.error(e, 'clear_directory')
        return False


def is_non_zero_file(fpath: str) -> bool:
    return os.path.isfile(fpath) and os.path.getsize(fpath) > 0
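
# NOTE: a minimal sketch (hypothetical directory), clearing cached artifacts while
# preserving the version-control keep-file; not part of the original module:
#   if is_non_zero_file('/tmp/scrilla-cache/prices.json'):
#       clear_directory('/tmp/scrilla-cache', retain=True)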


def clear_cache(mode: str = settings.CACHE_MODE) -> bool:
    tables = ['prices', 'interest', 'correlations', 'profile']
    memory = get_memory_json()

    for table in tables:
        memory['cache'][mode][table] = False

    save_memory_json(memory)

    if mode == 'sqlite':
        try:
            os.remove(settings.CACHE_SQLITE_FILE)
            return True
        except OSError as e:
            logger.error(e, 'clear_cache')
            return False
    elif mode == 'dynamodb':
        return aws.dynamo_drop_table(tables)

    raise errors.ConfigurationError('`CACHE_MODE` not set!')
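
# NOTE: a hedged usage sketch; the cache backend is selected by `settings.CACHE_MODE`
# ('sqlite' removes the local database file, 'dynamodb' drops the remote tables); not
# part of the original module:
#   cleared = clear_cache(mode='sqlite')
#   assert get_memory_json()['cache']['sqlite'] == {
#       'prices': False, 'interest': False, 'correlations': False, 'profile': False}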