Simple Train
This model builds on the hello-world example, adds some extra complexity, and takes us all the way through training.
import pyrosetta
pyrosetta.init()
import menten_gcn as mg
import menten_gcn.decorators as decs
from spektral.layers import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import numpy as np
# Choose which decorators will generate the node/edge features for the network
decorators = [decs.StandardBBGeometry(), decs.Sequence()]

data_maker = mg.DataMaker(
    decorators=decorators,
    edge_distance_cutoff_A=10.0,  # connect residues within 10 Angstroms of each other
    max_residues=20,              # cap the graph at 20 residues total
    nbr_distance_cutoff_A=25.0,   # drop residues >25 Angstroms from the focus residue(s)
)
data_maker.summary()
# Accumulate one training example per residue across all input PDBs.
Xs = []    # node feature tensors
As = []    # adjacency matrices
Es = []    # edge feature tensors
outs = []  # binary labels

# This part is all very hand-wavy
for pdb in ["test1.pdb", "test2.pdb", "test3.pdb", "test4.pdb", "test5.pdb"]:
    pose = pyrosetta.pose_from_pdb(pdb)
    wrapped_pose = mg.RosettaPoseWrapper(pose)
    # The cache lets the data_maker reuse per-pose computations across residues
    cache = data_maker.make_data_cache(wrapped_pose)
    # Rosetta residue numbering is 1-based and inclusive of pose.size()
    for resid in range(1, pose.size() + 1):
        X, A, E, resids = data_maker.generate_input_for_resid(wrapped_pose, resid, data_cache=cache)
        Xs.append(X)
        As.append(A)
        Es.append(E)
        # For the sake of keeping this simple, let's have this model predict
        # whether this residue is an N-term: label 1.0 if so, 0.0 otherwise.
        if wrapped_pose.resid_is_N_term(resid):
            outs.append([1.0])
        else:
            outs.append([0.0])
# Okay now we need to define a model.
# The data_maker knows the correct tensor sizes, and better yet it can
# simply create the Keras input layers for us:
X_in, A_in, E_in = data_maker.generate_XAE_input_tensors()

# GCN model architectures are tricky; this is a very simple one to get us
# off the ground.
# NOTE: ECCConv is called EdgeConditionedConv in older versions of spektral.
conv = ECCConv(30, activation='relu')([X_in, A_in, E_in])
# Try this if the first one fails:
# conv = EdgeConditionedConv(30, activation='relu')([X_in, A_in, E_in])
pooled = GlobalSumPool()(conv)
flat = Flatten()(pooled)
output = Dense(1, name="out")(flat)

model = Model(inputs=[X_in, A_in, E_in], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()
# Convert the accumulated python lists into numpy arrays and sanity-check
# their shapes before training.
Xs, As, Es, outs = (np.asarray(data) for data in (Xs, As, Es, outs))
for arr in (Xs, As, Es, outs):
    print(arr.shape)

model.fit(x=[Xs, As, Es], y=outs, batch_size=32, epochs=10, validation_split=0.2)