Sparse Mode

This variant of the “Simple Train” example uses Spektral’s disjoint mode to build a sparse representation of each graph.

Depending on how densely connected your graphs are, this can reduce memory usage.

The key differences are:

  • data_maker.generate_graph_for_resid has sparse=True

  • data_maker.generate_XAE_input_tensors has sparse=True and returns a 4th input

    • inputs=[X_in,A_in,E_in,I_in] when building the model

  • We are making a Spektral Dataset and feeding it into the DisjointLoader

  • We are using a Spektral Graph instead of free-floating lists. This change can be made in dense mode too.

    • ‘y’ holds the target (label) value in Spektral graphs

      • Please read Spektral’s documentation for options regarding ‘y’

import pyrosetta

import menten_gcn as mg
import menten_gcn.decorators as decs

from spektral.data import Dataset, DisjointLoader
from spektral.layers import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

import numpy as np

# Feature decorators handed to the DataMaker.
decorators = [
    decs.StandardBBGeometry(),
    decs.Sequence(),
]

# Graph-construction settings (distances in Angstroms):
#   edge_distance_cutoff_A - create edges between all residues within 10 A of each other
#   max_residues           - do not include more than 20 residues total in one graph
#   nbr_distance_cutoff_A  - skip any residue more than 25 A from the focus residue(s)
data_maker = mg.DataMaker(
    decorators=decorators,
    edge_distance_cutoff_A=10.0,
    max_residues=20,
    nbr_distance_cutoff_A=25.0,
)


class MyDataset(
    def __init__(self, **kwargs):
        self.graphs = [], **kwargs)

    def read(self):
        return self.graphs

dataset = MyDataset()

# PyRosetta must be initialised before any pose can be loaded.
pyrosetta.init()

for pdb in ["test1.pdb", "test2.pdb", "test3.pdb", "test4.pdb", "test5.pdb"]:

    pose = pyrosetta.pose_from_pdb(pdb)
    wrapped_pose = mg.RosettaPoseWrapper(pose)
    # Build the cache once per pose and reuse it for every residue.
    cache = data_maker.make_data_cache(wrapped_pose)

    # Residues are numbered 1..pose.size() (inclusive).
    for resid in range(1, pose.size() + 1):
        # sparse=True makes the DataMaker emit a Spektral Graph for disjoint mode.
        g, _ = data_maker.generate_graph_for_resid(
            wrapped_pose, resid, data_cache=cache, sparse=True
        )

        # For the sake of keeping this simple, the model predicts whether this
        # residue is an N-term: label 1.0 if it is, 0.0 otherwise.
        g.y = [1.0] if wrapped_pose.resid_is_N_term(resid) else [0.0]

        dataset.graphs.append(g)

# Note we have a 4th input now: I_in, which in Spektral's disjoint mode
# presumably maps each node to the graph it belongs to.
X_in, A_in, E_in, I_in = data_maker.generate_XAE_input_tensors(sparse=True)

# ECCConv is called EdgeConditionedConv in older versions of spektral
L1 = ECCConv(30, activation='relu')([X_in, A_in, E_in])
# Try this if the first one fails:
# L1 = EdgeConditionedConv(30, activation='relu')([X_in, A_in, E_in])

# In disjoint mode the pooling layer needs I_in to sum nodes per graph;
# without it, every node in the batch is pooled into a single vector.
L2 = GlobalSumPool()([L1, I_in])
L3 = Flatten()(L2)
# Sigmoid keeps predictions in (0, 1), the range binary cross-entropy expects.
output = Dense(1, activation='sigmoid', name="out")(L3)

# Make sure to include the 4th input because the DisjointLoader will pass it
model = Model(inputs=[X_in, A_in, E_in, I_in], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy')

# DisjointLoader merges each batch of graphs into one disjoint graph and yields
# ((X, A, E, I), y), matching the order of the model's inputs above.
loader = DisjointLoader(dataset, batch_size=1)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch)
# This part can sometimes fail due to tensorflow / numpy versioning.
# See the troubleshooting page of our documentation for details