Source code for offspect.cache.file

"""
CacheFile
---------
The Python interface to the :py:class:`~.CacheFile`, which checks the filename
for validity during instantiation. Whenever one of its properties is called, it
loads and parses the metadata and datasets fresh from the hdf5 file and aggregates them.

Examples
++++++++

Peek
****

The most straightforward example is loading a CacheFile and printing its contents.

.. code-block:: python

   from offspect.api import CacheFile
   cf = CacheFile("example.hdf5")
   print(cf)
   


Iterate
*******


Another use case is printing a TraceAttribute across all traces in the file, using the iterator interface of the CacheFile, which yields the data and attributes of each trace.

.. code-block:: python

   from offspect.api import CacheFile
   cf = CacheFile("example.hdf5")
   for data, attrs in cf:
        print("rejected?:", attrs["reject"])

Manipulate
**********

We can change the value of a key in the annotations of a specific trace by calling :meth:`CacheFile.get_trace_attrs` with a specific index. Please note that we decode and encode the values of the attrs. This is because they are stored as strings in the CacheFile, but we need them in their respective types to manipulate them properly. We therefore encode them again before we set the attributes with :meth:`~.CacheFile.set_trace_attrs`.

.. code-block:: python

    from offspect.api import CacheFile, encode
    cf = CacheFile("example.hdf5")
    attrs = cf.get_trace_attrs(0)
    attrs["stimulation_intensity_mso"] = encode(66)
    cf.set_trace_attrs(0, attrs)
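
To read such a value back in its proper type, decode it again. This is a minimal
sketch, assuming the same "example.hdf5" file and the "stimulation_intensity_mso"
key from the example above:

.. code-block:: python

    from offspect.api import CacheFile, decode
    cf = CacheFile("example.hdf5")
    # attributes are stored as strings, so decode them before any arithmetic
    mso = decode(cf.get_trace_attrs(0)["stimulation_intensity_mso"])
    print(mso)  # should print 66 if the attribute was set as shown above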


Batch-Manipulate
****************

Another typical use case is changing one TraceAttribute across all traces in the file. Here, we iterate over all traces and shift the onset of the TMS pulse 5 samples to the right.

.. code-block:: python

   from offspect.api import CacheFile, decode, encode
   cf = CacheFile("merged.hdf5")
   for ix, (data, attrs) in enumerate(cf):
       key = "onset_shift"
       old = decode(attrs[key])
       print(f"Trace {ix} {key}:", old, end=" ")
       new = old + 5
       attrs["onset_shift"] = encode(new)
       cf.set_trace_attrs(ix, attrs)
       test = decode(cf.get_trace_attrs(ix)["onset_shift"])
       print("to", test)

Plotting
********

Eventually, and ideally after visual inspection, you might want to plot the resulting map. You can do so using :py:func:`~.plot_map`, as in the following example.

.. code-block:: python

    from offspect.api import plot_map, CacheFile    
    # we load a cachefile
    cf = CacheFile("example.hdf5")
    # and plot and show it.
    display = plot_map(cf)
    display.show()
    # you can also save the figure with
    display.savefig("example_map.png")

There is a variety of options to tune the plot to your needs. For example, you can normalize the values, e.g. by taking the logarithm or by thresholding, by passing a sensible Callable as the foo argument. Note that in the logarithm example we add 1 to be able to deal with a Vpp of 0 from e.g. MEP-negative traces.

.. code-block:: python

    from math import log10
    # taking the log10
    plot_map(cf, foo = lambda x : log10(x + 1))
    # thresholding
    def threshold(x):
        return float(x>50)
    plot_map(cf, foo = threshold)
    
Additionally, you can use all the keywords from :py:func:`~.plot_glass` to beautify your plot.

.. code-block:: python

    plot_map(cf, vmax=100, title="Example", smooth=25)


"""
from typing import Union, List, Dict, Tuple, Iterator
from pathlib import Path
import h5py
import yaml
import numpy as np
from numpy import ndarray
from functools import partial
import ast
from numpy import ndarray as TraceData
from functools import lru_cache
from offspect.cache.check import (
    check_consistency,
    check_valid_suffix,
    FileName,
    MetaData,
    Annotations,
    TraceData,
    TraceAttributes,
    isindex,
)
from math import inf, nan
from offspect.cache.attrs import encode
from offspect.cache.steps import PreProcessor

read_file = partial(
    h5py.File, mode="r", libver="latest", swmr=True
)  #: open an hdf5 file for reading in single-writer-multiple-reader mode


write_file = partial(
    h5py.File, mode="r+", libver="latest", swmr=True
)  #: open an hdf5 file for writing in single-writer-multiple-reader mode

# -----------------------------------------------------------------------------
class CacheFile:
    """instantiate a new cachefile from HDD

    args
    ----
    fname: FileName
        path to the file

    For each readout, a specific set of fields must be in the metadata of a
    trace. Whenever attributes are read or written, the validity of the
    metadata will automatically be checked to be consistent with its
    'readout'.
    """

    def __init__(self, fname: FileName):
        self.fname = Path(fname).expanduser().absolute()
        if not self.fname.exists():
            raise FileNotFoundError(f"{self.fname} does not exist")
        check_valid_suffix(fname)

    def get_trace_data(self, idx: int) -> TraceData:
        """return the TraceData for a specific trace in the file

        args
        ----
        idx: int
            which trace to pick

        returns
        -------
        data: TraceData
            the data stored for this trace

        .. note::

           This is a read-only attribute, and raw data can never be
           overwritten with the CacheFile interface. If you need to perform
           any preprocessing steps, manage the TraceData with a low-level
           interface, e.g. :func:`~.populate`.
        """
        return read_trace(self, idx=idx, what="data")

    def get_trace_attrs(self, idx: int) -> TraceAttributes:
        """read the TraceAttributes for a specific trace in the file

        args
        ----
        idx: int
            which trace to pick

        returns
        -------
        attrs: TraceAttributes
            the collapsed attributes for this trace

        Example::

            cf = CacheFile("example.hdf5")
            for i in range(len(cf)):
                attrs = cf.get_trace_attrs(i)

        .. note::

           The TraceAttributes contain the metadata of this trace and the
           metadata of its parent group, i.e. the sourcefile. Additionally,
           two fields will be added, containing information about the
           'cache_file' and the 'cache_file_index'. The number of fields is
           therefore larger than the number of fields valid for
           TraceAttributes according to :func:`~.filter_trace_attrs`. This is
           no problem, because when you update with :meth:`~.set_trace_attrs`,
           these fields will be used for safety checks and subsequently
           discarded.
        """
        return read_trace(self, idx=idx, what="attrs")

    def set_trace_attrs(self, idx: int, attrs: TraceAttributes):
        """update the attributes of a specific trace

        args
        ----
        idx: int
            at which index to overwrite
        attrs: TraceAttributes
            with which attributes to overwrite

        Example::

            import datetime

            now = str(datetime.datetime.now())
            cf = CacheFile("example.hdf5")
            attrs = cf.get_trace_attrs(0)
            attrs["comment"] = now
            cf.set_trace_attrs(0, attrs)

        .. note::

           Because we expect the TraceAttributes to originate from a
           CacheFile's :meth:`~.get_trace_attrs` method, we expect them to
           include information about their original file and index. For
           safety reasons, you have to specify the index when calling this
           setter. Additionally, the original file must be this instance of
           CacheFile. If you want to directly overwrite an arbitrary
           attribute without these safety checks, update the values for
           'cache_file' and 'cache_file_index' and use
           :func:`~.update_trace_attributes`.

           Additionally, please note that while :meth:`~.get_trace_attrs`
           returns a complete dictionary of attributes, including those that
           apply to the whole group or origin file, only valid fields for
           trace metadata will be saved, i.e. those fields which are in
           correspondence with the "readout" parameter (see
           :func:`~.filter_trace_attrs`).
        """
        if "cache_file" not in attrs.keys() or "cache_file_index" not in attrs.keys():
            raise ValueError(
                "These attributes do not originate from a CacheFile. Information about their origin is missing"
            )
        if not str(self.fname) == attrs["cache_file"]:
            raise ValueError("These attributes did not originate from this CacheFile")
        if not idx == int(attrs["cache_file_index"]):
            raise ValueError(
                "These attributes originated from a different trace in this CacheFile"
            )
        update_trace_attributes(attrs)

    @property
    def origins(self) -> List[str]:
        "returns a list of original files used in creating this cachefile"
        with read_file(self.fname) as f:
            origins = list(f.keys())
        return origins

    def __str__(self) -> str:
        s = ""
        h = "-" * 79 + "\n"
        gap = 20
        for ox, attrs in enumerate(recover_annotations(self), start=1):
            k = "origin"
            v = attrs[k]
            o = f"{k:{gap}s} : {v}\n"
            a = ""
            for k, v in attrs["attrs"].items():
                a += f"{k:{gap}s} : {v}\n"
            v = len(attrs["traces"])
            k = "traces_count"
            tc = f"{k:{gap}s} : {v}\n"
            s += "".join((h, o, a, tc))
        s += h
        s += f"In total, there are {len(self)} traces from {ox} origins"
        return s

    @lru_cache(maxsize=1)
    def __len__(self) -> int:
        cnt = 0
        with read_file(self.fname) as f:
            for origin in f.keys():
                idx = None
                for idx, _ in enumerate(f[origin]["traces"], start=cnt + 1):
                    pass
                if idx is None:
                    print(f"WARNING: Origin {f[origin].name} had zero traces")
                cnt = idx or cnt
        return int(cnt)

    def __iter__(self):
        """iterate over all traces in the cachefile

        returns
        -------
        data: TraceData
            the data of this trace
        attrs: TraceAttributes
            the attributes of this trace
        """
        for i in range(len(self)):
            yield self.get_trace_data(i), self.get_trace_attrs(i)

# -----------------------------------------------------------------------------
def sort_keys(okeys: List[str]) -> List[str]:
    # because keys are stored as strings, they are sorted alphanumerically,
    # but we need them sorted numerically
    keys = [int(k) for k in okeys]
    # for indexing, we need them again as str, though
    sorted_keys = [str(k) for k in sorted(keys)]
    return sorted_keys
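
# For illustration (hypothetical input): sort_keys(["10", "2", "1"]) returns
# ["1", "2", "10"], i.e. the keys in numeric order, while sorting the raw
# strings would give ["1", "10", "2"].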


def update_trace_attributes(attrs: TraceAttributes):
    """overwrite the traceattributes of a trace

    the original file and index of the trace are specified as fields within
    the TraceAttributes

    args
    ----
    attrs: TraceAttributes
    """
    index: int
    if isindex(attrs["cache_file_index"]):  # this is a transient attribute
        index = int(attrs["cache_file_index"])
    else:
        raise ValueError("Index must be an integer")
    fname = attrs["cache_file"]
    # attrs = filter_trace_attrs(attrs)
    # skip transient attributes
    old = attrs
    attrs = dict()
    for key, value in old.items():
        if key == "cache_file_index":
            continue
        if key == "cache_file":
            continue
        attrs[key] = value

    if index >= 0:
        cnt = -1
        with write_file(fname) as f:
            for origin in f.keys():
                # because keys are stored as strings, they are sorted
                # alphanumerically, but we need them sorted numerically
                keys = sort_keys(f[origin]["traces"].keys())
                # we use a running index across origin files, so we start at
                # the last index (defaulting to -1, so -1 + 1 => 0)
                for idx, key in enumerate(keys, start=cnt + 1):
                    if idx == index:
                        dset = f[origin]["traces"][key]
                        for key in attrs.keys():
                            dset.attrs[encode(key)] = encode(attrs[key])
                        return
                    cnt = idx
    raise IndexError(f"{index} not in cachefile")


def read_trace(
    cf: CacheFile, idx: int, what: str = "attrs"
) -> Union[TraceData, TraceAttributes]:
    """read either the data or the attributes for a specific trace

    args
    ----
    cf: CacheFile
        for which file
    idx: int
        which trace to load
    what: str
        whether to load 'data' or 'attrs'. defaults to 'attrs'
    """
    if type(idx) != int:
        raise ValueError("Index must be an integer")
    if idx >= 0:
        cnt = -1  # we use cnt to allow indexing across origin files
        with read_file(cf.fname) as f:
            for origin in f.keys():
                # because keys are stored as strings, they are sorted
                # alphanumerically, but we need them sorted numerically
                keys = sort_keys(f[origin]["traces"].keys())
                # we use a running index across origin files, so we start at
                # the last index (defaulting to -1, so -1 + 1 => 0)
                for ix, key in enumerate(keys, start=cnt + 1):
                    # if the trace is the one indexed, we load the dset
                    # fresh from hdd
                    if idx == ix:
                        dset = f[origin]["traces"][key]
                        dset.id.refresh()  # load fresh from file
                        if what == "attrs":
                            # attrs = parse_traceattrs(dset.attrs)
                            attrs = asdict(dset.attrs)
                            # attrs["origin"] = encode(str(origin))
                            attrs["origin"] = encode(origin)
                            attrs["cache_file"] = encode(cf.fname)
                            attrs["cache_file_index"] = encode(idx)
                            # check_metadata(str(attrs["readout"]), attrs)
                            return attrs
                        elif what == "data":
                            data = parse_tracedata(dset)
                            return data
                        else:
                            raise NotImplementedError(f"{what} can not be loaded")
                    # we set cnt to ix to allow cross-origin indexing
                    cnt = ix
    raise IndexError(f"{idx} not in cachefile")


def write_tracedata(cf, data: ndarray, idx: int):
    if type(idx) != int:
        raise ValueError("Index must be an integer")
    if idx >= 0:
        cnt = -1  # we use cnt to allow indexing across origin files
        with write_file(cf.fname) as f:
            for origin in f.keys():
                # because keys are stored as strings, they are sorted
                # alphanumerically, but we need them sorted numerically
                keys = sort_keys(f[origin]["traces"].keys())
                print(f"CF: Searching through {len(keys)} traces from {origin}")
                # we use a running index across origin files, so we start at
                # the last index (defaulting to -1, so -1 + 1 => 0)
                for ix, key in enumerate(keys, start=cnt + 1):
                    # if the trace is the one indexed, we overwrite its dset
                    if idx == ix:
                        dset = f[origin]["traces"][key]
                        if dset.shape == data.shape:
                            dset[:] = data
                            print(
                                "CF: Overwriting data for trace #",
                                idx,
                                "id:",
                                key,
                                "from",
                                origin,
                            )
                            return
                        else:
                            print(
                                "CF: Trace shape #",
                                idx,
                                "id:",
                                key,
                                "from",
                                origin,
                                "does not conform. Can not overwrite",
                            )
                            return
                    cnt = ix


def asdict(attrs: h5py.AttributeManager) -> Dict[str, str]:
    "parse the metadata from a cachefile and return it as a dictionary"
    return dict(attrs)


def parse_traceattrs(attrs: h5py.AttributeManager) -> MetaData:
    """parse any metadata from a cachefile and return it as a Dict"""
    d = dict(attrs)
    for key, val in d.items():
        try:
            d[key] = ast.literal_eval(val)
        except (SyntaxError, ValueError):
            # for subject and filedate
            if val == "inf":
                d[key] = inf
            if val == "nan":
                d[key] = nan
            if key == "xyz_coords":
                xyz = yaml.load("[nan, nan, nan]", Loader=yaml.Loader)
                d[key] = [float(p) for p in xyz]
    return d


def parse_tracedata(dset: h5py.Dataset) -> TraceData:
    "parse an hdf5 dataset from a cachefile and return it as a trace"
    return np.asanyarray(dset, dtype=float)


def recover_annotations(cf: CacheFile) -> List[Annotations]:
    """recover the file and annotations from a cachefile

    args
    ----
    cf: CacheFile
        the cachefile from which to recover

    returns
    -------
    annotations: List[Annotations]
        a list of annotations, where annotations are the collapsed metadata
        of all sourcefiles in the cachefile organized as
        [sourcefiles][Annotations], see :class:`~.offspect.cache.file.Annotations`
    """
    with read_file(cf.fname) as f:
        events = []
        for origin in f.keys():
            yml = dict()
            yml["origin"] = origin
            yml["attrs"] = parse_traceattrs(f[origin].attrs)
            readout = yml["attrs"]["readout"]
            trace_attrs = []
            for idx in f[origin]["traces"]:
                dset = f[origin]["traces"][idx]
                dset.id.refresh()  # load fresh from file
                tattr = parse_traceattrs(dset.attrs)
                # check_metadata(readout, tattr)
                trace_attrs.append(tattr)
            yml["traces"] = trace_attrs
            events.append(yml)
    return events


def recover_parts(cf: CacheFile) -> Tuple[List[Annotations], List[List[TraceData]]]:
    """recover the two parts of a cachefile, i.e. annotations and traces

    args
    ----
    cf: CacheFile
        the cachefile from which to recover

    returns
    -------
    annotations: List[Annotations]
        a list of annotations, i.e. the metadata of all sourcefiles in the
        cachefile organized as [sourcefiles][Annotations]
    traces: List[List[TraceData]]
        a list of the traces of all sourcefiles saved in the cachefile
        organized as [sourcefiles][traceidx][TraceData]
    """
    with read_file(cf.fname) as f:
        events, traces = [], []
        for origin in f.keys():
            yml = dict()
            yml["origin"] = origin
            yml["attrs"] = parse_traceattrs(f[origin].attrs)
            trace_attrs = []
            trace_data = []
            for idx in f[origin]["traces"]:
                dset = f[origin]["traces"][idx]
                dset.id.refresh()  # load fresh from file
                trace_attrs.append(parse_traceattrs(dset.attrs))
                trace_data.append(parse_tracedata(dset))
            yml["traces"] = trace_attrs
            events.append(yml)
            traces.append(trace_data)
    return events, traces


def populate(
    tf: FileName, annotations: List[Annotations], traceslist: List[List[TraceData]]
) -> FileName:
    """create a new cachefile from annotations and traces

    args
    ----
    tf: FileName
        the name of the file to be created. will overwrite an existing file
    annotations: List[Annotations]
        a list of annotation dictionaries
    traceslist: List[List[TraceData]]
        a list of lists of traces

    returns
    -------
    fname: FileName
        the path to the freshly populated cachefile
    """
    tf = Path(tf).expanduser().absolute()
    # populate the cachefile
    with h5py.File(tf, "w") as f:
        print(f"Merging into {tf.name} from:")
        for settings, traces in zip(annotations, traceslist):
            print(" -", settings["origin"])
            ofile = f.create_group(settings["origin"])
            # fill with ofile-attributes
            attrs = settings["attrs"]
            for key, val in attrs.items():
                ofile.attrs.modify(str(key), str(val))
            # fill with trace-data and trace-attributes
            tracegrp = ofile.create_group("traces")
            for tattr, trace in zip(settings["traces"], traces):
                tattr.update(**attrs)
                idx = str(tattr["id"])
                cftrace = tracegrp.create_dataset(idx, data=trace)
                for k, v in tattr.items():
                    cftrace.attrs.modify(str(k), str(v))
    return tf


def merge(to: FileName, sources: List[FileName]) -> FileName:
    """merge one or more cachefiles into one file

    args
    ----
    to: FileName
        the name of the file to be written into. will be overwritten, if
        already existing
    sources: List[FileName]
        a list of source files from which we will read traces and annotations

    returns
    -------
    fname: FileName
        the name of the target file
    """
    sources = [Path(source).expanduser().absolute() for source in sources]
    to = Path(to).expanduser().absolute()
    check_valid_suffix(to)
    if to.exists():
        print(f"MERGE:WARNING: {to.name} already exists and will be overwritten")
        to.unlink()

    a: List[Dict] = []
    t: List[ndarray] = []
    for source in sources:
        attrs, traces = recover_parts(CacheFile(source))
        a.extend(attrs)
        t.extend(traces)
    print(a)
    check_consistency(a)
    fname = populate(tf=to, annotations=a, traceslist=t)
    return fname
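

# A minimal usage sketch for merge (not part of the module; the filenames are
# hypothetical): combine several cachefiles into a single file, which can then
# be inspected as a CacheFile again.
#
#   merged = merge(to="merged.hdf5", sources=["day1.hdf5", "day2.hdf5"])
#   print(CacheFile(merged))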