Skip to content

Commit

Permalink
add store_dict
Browse files Browse the repository at this point in the history
parts are copied from store_xarray, and could probably be merged
  • Loading branch information
mschrimpf committed Jan 25, 2019
1 parent 6bcf307 commit c7605c5
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 158 deletions.
82 changes: 75 additions & 7 deletions result_caching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,73 @@ def load_file(self, path):
return pickle.load(f)['data']


class _DictStorage(_DiskStorage):
"""
All fields in _combine_fields are combined into one file and loaded lazily
"""

def __init__(self, dict_key: str, *args, **kwargs):
"""
:param dict_key: the argument representing the dictionary key.
"""
super().__init__(*args, **kwargs)
self._dict_key = dict_key

def __call__(self, function):
def wrapper(*args, **kwargs):
call_args = self.getcallargs(function, *args, **kwargs)
assert self._dict_key in call_args
infile_call_args = {self._dict_key: call_args[self._dict_key]}
function_identifier = self.get_function_identifier(function, call_args)
stored_result, reduced_call_args = None, call_args
if self.is_stored(function_identifier):
self._logger.debug(f"Loading from storage: {function_identifier}")
stored_result = self.load(function_identifier)
infile_missing_call_args = self.missing_call_args(infile_call_args, stored_result)
if len(infile_missing_call_args) == 0:
# nothing else to run, but still need to filter
result = stored_result
reduced_call_args = None
else:
# need to run more args
non_variable_call_args = {key: value for key, value in call_args.items() if key != self._dict_key}
infile_missing_call_args = {self._dict_key: infile_missing_call_args}
reduced_call_args = {**non_variable_call_args, **infile_missing_call_args}
self._logger.debug(f"Computing missing: {reduced_call_args}")
if reduced_call_args:
# run function if some args are uncomputed
result = function(**reduced_call_args)
if not self.callargs_present(result, {self._dict_key: reduced_call_args[self._dict_key]}):
raise ValueError("result does not contain requested keys")
if stored_result is not None:
result = self.merge_results(stored_result, result)
# only save if new results
self._logger.debug("Saving to storage: {}".format(function_identifier))
self.save(result, function_identifier)
assert self.callargs_present(result, infile_call_args)
result = self.filter_callargs(result, infile_call_args)
return result

return wrapper

def merge_results(self, stored_result, result):
return {**stored_result, **result}

def callargs_present(self, result, infile_call_args):
# make sure coords are set equal to call_args
return len(self.missing_call_args(infile_call_args, result)) == 0

def missing_call_args(self, call_args, data):
assert len(call_args) == 1 and list(call_args.keys())[0] == self._dict_key
keys = list(call_args.values())[0]
return [key for key in keys if key not in data]

def filter_callargs(self, data, call_args):
assert len(call_args) == 1 and list(call_args.keys())[0] == self._dict_key
keys = list(call_args.values())[0]
return type(data)((key, value) for key, value in data.items() if key in keys)


class _XarrayStorage(_DiskStorage):
"""
All fields in _combine_fields are combined into one file and loaded lazily
Expand Down Expand Up @@ -289,13 +356,13 @@ def get_calling_function():
fr = inspect.stack()[1][0]
co = fr.f_code
for get in (
lambda: fr.f_globals[co.co_name],
lambda: getattr(fr.f_locals['self'], co.co_name),
lambda: getattr(fr.f_locals['cls'], co.co_name),
lambda: fr.f_back.f_locals[co.co_name], # nested
lambda: fr.f_back.f_locals['func'], # decorators
lambda: fr.f_back.f_locals['meth'],
lambda: fr.f_back.f_locals['f'],
lambda: fr.f_globals[co.co_name],
lambda: getattr(fr.f_locals['self'], co.co_name),
lambda: getattr(fr.f_locals['cls'], co.co_name),
lambda: fr.f_back.f_locals[co.co_name], # nested
lambda: fr.f_back.f_locals['func'], # decorators
lambda: fr.f_back.f_locals['meth'],
lambda: fr.f_back.f_locals['f'],
):
try:
func = get()
Expand All @@ -309,4 +376,5 @@ def get_calling_function():

cache = _MemoryStorage
store = _DiskStorage
store_dict = _DictStorage
store_xarray = _XarrayStorage
Loading

0 comments on commit c7605c5

Please sign in to comment.