import spektral
import tensorflow as tf
import numpy as np
import math
import random
import gc

from typing import List
from menten_gcn.wrappers import WrappedPose

class DecoratorDataCache:
    """
    Cache that prevents re-calculating the same node/edge data many times.

    A separate cache is needed for every pose you work with. We highly recommend
    creating it through the DataMaker (``data_maker.make_data_cache()``), which
    allows for further caching and speedups.

    Parameters
    ----------
    wrapped_pose: WrappedPose
        The pose this cache is being made for.
    """

    def __init__(self, wrapped_pose: WrappedPose):
        # One slot per residue plus one extra; presumably residues are
        # 1-indexed so slot 0 goes unused (TODO confirm against WrappedPose).
        n_slots = wrapped_pose.n_residues() + 1

        # Edge data is looked up as edge_cache[i][j]: one dict per slot i.
        self.edge_cache = [{} for _ in range(n_slots)]
        # Node data: one (initially empty) entry per slot.
        self.node_cache = [None] * n_slots
        # Free-form keyed storage.
        self.dict_cache = {}
class NullDecoratorDataCache:
    """
    Stand-in exposing the same attribute names as DecoratorDataCache
    (``edge_cache``, ``node_cache``, ``dict_cache``), all set to None.

    Presumably used where no caching is wanted; confirm against callers.
    """

    def __init__(self):
        self.edge_cache = self.node_cache = self.dict_cache = None
class DataHolder:
    """
    DataHolder is a wonderful class that automatically stores the direct output of the DataMaker.
    The DataHolder can then feed your data directly into Keras's methods using the generators below.

    There are descriptions for each method below but perhaps the best way to grasp
    the DataHolder's usage is to see the example at the bottom.

    Parameters
    ----------
    dtype: np.dtype
        What NumPy dtype should we use to represent your data?
    """

    def __init__(self, dtype: np.dtype = np.float32):
        # Parallel per-example storage; see append() for the expected shapes.
        self.Xs = []
        self.As = []
        self.Es = []
        self.outs = []
        self.dtype = dtype

    def assert_mode(self, mode=spektral.layers.ops.modes.BATCH):
        """
        For those of you using spektral, this ensures that your data is in the correct shape.
        Unfortunately this only currently checks X and A. More development is incoming.

        Raises
        ------
        RuntimeError
            If no data has been added yet.
        AssertionError
            If spektral detects a mode other than ``mode``.
        """
        # len() rather than truthiness: after load_from_file these are ndarrays,
        # whose truth value is ambiguous.
        if len(self.Xs) == 0:
            raise RuntimeError("DataHolder.assert_mode is called before any data is added")
        tf_As = tf.convert_to_tensor(np.asarray(self.As))
        tf_Xs = tf.convert_to_tensor(np.asarray(self.Xs))
        assert spektral.layers.ops.modes.autodetect_mode(tf_As, tf_Xs) == mode

    def append(self, X: np.ndarray, A: np.ndarray, E: np.ndarray, out: np.ndarray):
        """
        This is the most important method in this class: it gives the data to the dataholder.

        Note: load_from_file replaces the internal lists with NumPy arrays, which
        have no ``append``; call this only on a holder that was not loaded from disk.

        Parameters
        ----------
        X: array-like
            Node features, shape=(N,F)
        A: array-like
            Adjacency Matrix, shape=(N,N)
        E: array-like
            Edge features, shape=(N,N,S)
        out: array-like
            What is the output of your model supposed to be? You decide the shape.
        """
        # TODO assert shape
        self.Xs.append(np.asarray(X, dtype=self.dtype))
        self.As.append(np.asarray(A, dtype=self.dtype))
        self.Es.append(np.asarray(E, dtype=self.dtype))
        self.outs.append(np.asarray(out, dtype=self.dtype))

    def size(self) -> int:
        """Number of examples currently stored."""
        return len(self.Xs)

    def get_batch(self, begin: int, end: int):
        """
        Return the contiguous examples [begin, end) as ``([x, a, e], o)``.
        """
        assert begin >= 0
        assert end <= self.size()
        x = np.asarray(self.Xs[begin:end])
        a = np.asarray(self.As[begin:end])
        e = np.asarray(self.Es[begin:end])
        o = np.asarray(self.outs[begin:end])
        return [x, a, e], o

    def get_indices(self, inds):
        """
        Return the examples at positions ``inds`` as ``([x, a, e], o)``.

        Elements are gathered one by one so this works whether the storage is a
        Python list (fancy indexing is not supported) or a loaded NumPy array.
        """
        x = np.asarray([self.Xs[i] for i in inds])
        a = np.asarray([self.As[i] for i in inds])
        e = np.asarray([self.Es[i] for i in inds])
        o = np.asarray([self.outs[i] for i in inds])
        return [x, a, e], o

    def save_to_file(self, fileprefix: str):
        """
        Want to save this data for later? Use this method to cache it to disk.
        Users of this method may be interested in the CachedDataHolderInputGenerator below.

        Parameters
        ----------
        fileprefix: str
            Filename prefix for cache.
            fileprefix="foo/bar" will result in creating "./foo/bar.npz"
        """
        np.savez_compressed(
            fileprefix + '.npz',
            x=np.asarray(self.Xs, dtype=self.dtype),
            a=np.asarray(self.As, dtype=self.dtype),
            e=np.asarray(self.Es, dtype=self.dtype),
            o=np.asarray(self.outs, dtype=self.dtype))

    def load_from_file(self, fileprefix: str = None, filename: str = None):
        """
        save_to_file's partner. Use this to load in caches already saved.
        Please provide either fileprefix or filename, but not both.

        This duplicity may seem silly. The goal for fileprefix is to be consistent with
        save_to_file (the two "fileprefix" args will be identical strings for both)
        whereas the goal for filename is to simply list the name of the file verbosely.

        After loading, Xs/As/Es/outs are NumPy arrays (not lists).

        Parameters
        ----------
        fileprefix: str
            Filename prefix for cache.
            fileprefix="foo/bar" will result in reading "./foo/bar.npz"
        filename: str
            Filename for cache.
            filename="foo/bar.npz" will result in reading "./foo/bar.npz"

        Raises
        ------
        AssertionError
            If both or neither of fileprefix/filename are given, or if any loaded
            array sums to NaN (the message is the file that was read).
        """
        assert filename is None or fileprefix is None, "Please provide either fileprefix or filename"
        assert filename is not None or fileprefix is not None, "Please provide either fileprefix or filename"

        fn = fileprefix + '.npz' if filename is None else filename

        # np.load on an .npz returns an NpzFile that holds the file open; indexing it
        # reads each array fully into memory, so closing afterwards is safe.
        with np.load(fn) as cache:
            self.Xs = cache['x']
            self.As = cache['a']
            self.Es = cache['e']
            self.outs = cache['o']

        # Report the resolved path: `filename` is None when fileprefix was used.
        assert not np.isnan(np.sum(self.Xs)), fn
        assert not np.isnan(np.sum(self.As)), fn
        assert not np.isnan(np.sum(self.Es)), fn
        assert not np.isnan(np.sum(self.outs)), fn
class DataHolderInputGenerator(tf.keras.utils.Sequence):
    """
    This class is used to feed a DataHolder directly into Keras's protocol.
    See the example code below.

    Parameters
    ----------
    data_holder: DataHolder
        A DataHolder that you just made
    batch_size: int
        How many elements should be grouped together in batches during training?
    """

    def __init__(self, data_holder: DataHolder, batch_size: int = 32):
        # Keras 3's PyDataset (which Sequence aliases) warns when its initializer
        # is skipped; on older Keras this resolves to a no-op.
        super().__init__()
        self.holder = data_holder
        self.batch_size = batch_size
        # Snapshot of example positions; reshuffled in place after each epoch.
        self.indices = list(range(data_holder.size()))

    def n_elem(self) -> int:
        """Number of examples in the underlying DataHolder."""
        return self.holder.size()

    def __len__(self):
        """Number of batches; the final batch may be smaller than batch_size."""
        # Integer ceiling division (avoids float rounding on large sizes).
        return (self.holder.size() + self.batch_size - 1) // self.batch_size

    def __getitem__(self, item_index):
        """Return batch ``item_index`` as ``([x, a, e], out)``."""
        begin = item_index * self.batch_size
        end = min(begin + self.batch_size, len(self.indices))
        inp, out = self.holder.get_indices(self.indices[begin:end])
        # Fail fast on NaN/inf rather than feeding them into training.
        for arr in inp:
            assert np.isfinite(arr).all()
        assert np.isfinite(out).all()
        return inp, out

    def on_epoch_end(self):
        """Reshuffle the example order for the next epoch."""
        np.random.shuffle(self.indices)
        gc.collect()
class CachedDataHolderInputGenerator(tf.keras.utils.Sequence):
    """
    This class is used to feed a DataHolder directly into Keras's protocol.
    The difference with this class is that it reads one or more DataHolders that
    have been saved onto disk. See the example code below.

    Parameters
    ----------
    data_list_lines: list
        A list of filenames, each one for a different DataHolder.
    cache: bool
        If true, this class will load every DataHolder into memory once and keep them there.
        This can require a lot of memory.
        Otherwise, we will only read in one DataHolder at a time (once per epoch).
        This increases disk IO but is often worth it.
    batch_size: int
        How many elements should be grouped together in batches during training?
    autoshuffle: bool
        This is very nuanced, so we recommend keeping the default value of None
        (this lets us pick the appropriate action).
        Long story short: you do NOT want shuffle=True in Keras's fit() when
        cache=False, because disk IO goes through the roof. To counter this, we
        handle shuffling internally in a way that minimizes disk IO.
        However, you DO want shuffle=True if cache=True, because everything is in
        memory anyway. I know this is confusing. Maybe this will be cleaner in the future.

    Notes
    -----
    Each file's element count is rounded DOWN to a multiple of batch_size, so no
    batch spans two files; the leftover elements of each file are never served.
    """

    def __init__(self, data_list_lines: List[str], cache: bool = False,
                 batch_size: int = 32, autoshuffle: bool = None):
        # Keras 3's PyDataset (which Sequence aliases) warns when its initializer
        # is skipped; on older Keras this resolves to a no-op.
        super().__init__()
        print("Generating from", str(len(data_list_lines)), "files")
        self.data_list_lines = data_list_lines

        if autoshuffle is None:
            self.autoshuffle = not cache
        else:
            self.autoshuffle = autoshuffle
        assert not(self.autoshuffle and cache), "Autoshuffle is not compatible with caching yet."

        self.cache = cache
        self.cached_data = [None for _ in self.data_list_lines]
        self.batch_size = batch_size
        self.sizes = []
        self.total_size = 0
        # Cached mode serves contiguous slices (indices is None); otherwise the
        # index list is rebuilt (and optionally shuffled) whenever a file is loaded.
        self.indices = None if self.cache else []
        self.holder = None

        for i, filename in enumerate(self.data_list_lines):
            holder = DataHolder()
            holder.load_from_file(filename=filename)
            # round DOWN to nearest multiple of batch size
            size = (holder.size() // self.batch_size) * self.batch_size
            print("rounding", holder.size(), "to", size)
            self.sizes.append(size)
            self.total_size += size
            if self.cache:
                self.cached_data[i] = holder
            else:
                del holder
                gc.collect()
        print(" ", self.total_size, "elements")

        self.sizes = np.asarray(self.sizes)
        self.cum_sizes = np.cumsum(self.sizes)
        self.currently_loaded_npz_index = -1

    def __len__(self):
        """It is mandatory to implement it on Keras Sequence"""
        return self.total_size // self.batch_size

    def get_npz_index_for_item(self, item_index):
        """
        Map a global batch index to ``(file index, batch index within that file)``.

        Raises
        ------
        IndexError
            If ``item_index`` is past the last batch.
        """
        resized_i = item_index * self.batch_size
        for i, cum_size in enumerate(self.cum_sizes):
            if resized_i < cum_size:
                if i == 0:
                    return i, item_index
                # cum_sizes are multiples of batch_size, so this division is exact.
                return i, int(item_index - (self.cum_sizes[i - 1] // self.batch_size))
        raise IndexError("batch index {} is out of range for {} batches".format(
            item_index, len(self)))

    def __getitem__(self, item_index):
        """Return batch ``item_index`` as ``([x, a, e], out)``."""
        npz_i, i = self.get_npz_index_for_item(item_index)
        if self.cache:
            self.holder = self.cached_data[npz_i]
        elif npz_i != self.currently_loaded_npz_index:
            # Release the previous file's arrays before reading the next one.
            self.holder = DataHolder()
            gc.collect()
            self.currently_loaded_npz_index = npz_i
            self.holder.load_from_file(filename=self.data_list_lines[npz_i])
            self.indices = list(range(self.holder.size()))
            if self.autoshuffle:
                np.random.shuffle(self.indices)

        begin = i * self.batch_size
        end = min(begin + self.batch_size, len(self.holder.As))
        if self.indices is None:
            inp, out = self.holder.get_batch(begin, end)
        else:
            assert end <= len(self.indices)
            inp, out = self.holder.get_indices(self.indices[begin:end])

        # Fail fast on NaN/inf rather than feeding them into training.
        for arr in inp:
            assert np.isfinite(arr).all()
        assert np.isfinite(out).all()
        return inp, out

    def on_epoch_end(self):
        """Reshuffle the file order (when autoshuffle is on) for the next epoch."""
        if self.autoshuffle:
            self.shuffle()
        gc.collect()

    def shuffle(self):
        """
        Shuffle the order of the files (keeping each file paired with its size)
        and reset the loaded holder so the next batch reloads from disk.
        """
        # TODO: get this to work with cached data
        assert not self.cache
        # zip(*[]) cannot be unpacked into two names, so skip when there are no files.
        if len(self.data_list_lines) > 0:
            pairs = list(zip(self.data_list_lines, self.sizes))
            random.shuffle(pairs)
            self.data_list_lines, self.sizes = zip(*pairs)
            self.cum_sizes = np.cumsum(self.sizes)
        self.holder = None
        self.currently_loaded_npz_index = -1