jaxchem.models


GCN

class PadGCNPredicator(*args, **kwargs)

GCN Predicator: a wrapper module that combines a GCN, a graph pooling (readout) step, and an MLP predicator.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, pooling_method='mean', predicator_hidden_feats=128, predicator_dropout=0.0, n_out=1, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • hidden_feats (list[int]) – Output node feature size of each GCN layer. len(hidden_feats) equals the number of GCN layers.

  • activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, every layer uses ReLU.

  • batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.

  • dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.

  • pooling_method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to ‘mean’.

  • predicator_hidden_feats (int) – Size of the hidden graph representations in the predicator, defaults to 128.

  • predicator_dropout (float) – Dropout probability in the predicator, defaults to 0.0.

  • n_out (int) – Output size, defaults to 1.

  • name (Optional[str]) – Optional name of the module.
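The list-valued arguments are per-layer: entry i configures the i-th GCN layer. The sketch below shows one way the documented defaults could be filled in and how layer sizes chain, assuming layer i maps hidden_feats[i-1] (or in_feats for the first layer) to hidden_feats[i]. `expand_layer_args` and `layer_shapes` are hypothetical helpers, not part of jaxchem.

```python
from typing import Callable, List, Optional, Sequence

import jax


def expand_layer_args(
    hidden_feats: Sequence[int],
    activation: Optional[Sequence[Callable]] = None,
    batch_norm: Optional[Sequence[bool]] = None,
    dropout: Optional[Sequence[float]] = None,
):
    """Hypothetical helper: fill in the documented per-layer defaults."""
    n_layers = len(hidden_feats)
    # Documented defaults: ReLU on every layer, no batch norm, no dropout.
    activation = list(activation) if activation is not None else [jax.nn.relu] * n_layers
    batch_norm = list(batch_norm) if batch_norm is not None else [False] * n_layers
    dropout = list(dropout) if dropout is not None else [0.0] * n_layers
    # Every per-layer list must have exactly one entry per GCN layer.
    for arg_name, values in (("activation", activation),
                             ("batch_norm", batch_norm),
                             ("dropout", dropout)):
        if len(values) != n_layers:
            raise ValueError(f"len({arg_name}) must equal len(hidden_feats) = {n_layers}")
    return activation, batch_norm, dropout


def layer_shapes(in_feats: int, hidden_feats: Sequence[int]) -> List[tuple]:
    """(input, output) feature sizes of each stacked GCN layer (assumed chaining)."""
    dims = [in_feats, *hidden_feats]
    return list(zip(dims[:-1], dims[1:]))
```

For example, `layer_shapes(30, [64, 64])` gives `[(30, 64), (64, 64)]`, and passing a `dropout` list of the wrong length raises `ValueError` in this sketch.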

__call__(node_feats, adj, is_training)

Predict logits or values.

Parameters
  • node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the number of nodes per graph after padding.

  • adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrix.

  • is_training (bool) – Whether the model is training or not.

Returns

out – Predicator output.

Return type

ndarray of shape (batch_size, n_out)
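In the pad pattern every graph in the batch is padded to the same node count N. A minimal sketch of building such a batch from graphs of different sizes; `pad_batch` is a hypothetical helper, and zero-filling the padding rows and columns is an assumption about how padded inputs are prepared.

```python
import jax.numpy as jnp
import numpy as np


def pad_batch(graphs, n_max):
    """Stack graphs into the pad pattern: (batch_size, N, F) and (batch_size, N, N).

    graphs: list of (node_feats (n_i, F), adj (n_i, n_i)) NumPy arrays.
    n_max: padded node count N, must be >= every n_i.
    """
    batch_size = len(graphs)
    in_feats = graphs[0][0].shape[1]
    node_feats = np.zeros((batch_size, n_max, in_feats), dtype=np.float32)
    adj = np.zeros((batch_size, n_max, n_max), dtype=np.float32)
    for b, (x, a) in enumerate(graphs):
        n = x.shape[0]
        node_feats[b, :n] = x  # real nodes first, zero rows after them
        adj[b, :n, :n] = a     # padding nodes get no edges
    return jnp.asarray(node_feats), jnp.asarray(adj)
```

Padded nodes still occupy the node axis, so any readout that reduces over axis=1 sees them unless they are masked.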

class PadGCN(*args, **kwargs)

GCN module. Paper: Semi-Supervised Classification with Graph Convolutional Networks

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • hidden_feats (list[int]) – Output node feature size of each GCN layer. len(hidden_feats) equals the number of GCN layers.

  • activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, every layer uses ReLU.

  • batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.

  • dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.

  • name (Optional[str]) – Optional name of the module.

__call__(node_feats, adj, is_training)

Update node features.

Parameters
  • node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the number of nodes per graph after padding.

  • adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrix.

  • is_training (bool) – Whether the model is training or not.

Returns

new_node_feats – Updated node features for the batch.

Return type

ndarray of shape (batch_size, N, hidden_feats[-1])

class PadGCNLayer(*args, **kwargs)

Single GCN layer from Semi-Supervised Classification with Graph Convolutional Networks

__init__(in_feats, out_feats, activation=None, bias=True, normalize=True, batch_norm=False, dropout=0.0, w_init=None, b_init=None, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • out_feats (int) – Number of output node features.

  • activation (Activation or None) – Activation function, defaults to ReLU.

  • bias (bool) – Whether to add a bias after the affine transformation, defaults to True.

  • normalize (bool) – Whether to normalize the adjacency matrix, defaults to True.

  • batch_norm (bool) – Whether to apply batch normalization, defaults to False.

  • dropout (float) – Dropout probability, defaults to 0.0.

  • w_init (Optional[Callable[[Sequence[int], Any], jax.numpy.lax_numpy.ndarray]]) – Initializer for the weight, defaults to a He truncated normal distribution.

  • b_init (initializer function for the bias) – Initializer for the bias, defaults to a truncated normal distribution.

  • name (Optional[str]) – Optional name of the module.
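For reference, the layer-wise propagation rule in the GCN paper (Kipf & Welling) adds self-loops and symmetrically normalizes the adjacency matrix. Whether normalize=True in this module follows exactly this formulation is not stated here, so treat the following as the paper's rule rather than a description of the implementation:

```latex
\tilde{A} = A + I_N, \qquad
\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}, \qquad
\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}, \qquad
H' = \sigma\left(\hat{A} H W + b\right)
```

Here H is the (N, in_feats) node feature matrix, W the (in_feats, out_feats) weight, b the optional bias (bias=True), and σ the activation.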

__call__(node_feats, adj, is_training)

Update node features.

Parameters
  • node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the number of nodes per graph after padding.

  • adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrix.

  • is_training (bool) – Whether the model is training or not.

Returns

new_node_feats – Updated node features for the batch.

Return type

ndarray of shape (batch_size, N, out_feats)
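The dense update can be sketched in plain jax.numpy as below. This illustrates the paper's rule, assuming self-loops are added before symmetric normalization; `normalize_adj` and `pad_gcn_layer` are hypothetical names, and batch normalization and dropout are omitted.

```python
import jax
import jax.numpy as jnp


def normalize_adj(adj):
    """Symmetric normalization D^-1/2 (A + I) D^-1/2 for a (batch, N, N) array."""
    n = adj.shape[-1]
    a_tilde = adj + jnp.eye(n, dtype=adj.dtype)   # add self-loops
    deg = a_tilde.sum(axis=-1)                     # (batch, N), always >= 1
    d_inv_sqrt = 1.0 / jnp.sqrt(deg)
    # Scale rows and columns: A_hat[b, i, j] = A_tilde[b, i, j] / sqrt(d_i * d_j)
    return a_tilde * d_inv_sqrt[..., :, None] * d_inv_sqrt[..., None, :]


def pad_gcn_layer(node_feats, adj, w, b, activation=jax.nn.relu):
    """One dense GCN update: activation(A_hat @ (X @ W) + b)."""
    h = node_feats @ w               # (batch, N, out_feats)
    h = normalize_adj(adj) @ h       # batched aggregation over neighbours
    return activation(h + b)
```

Adding self-loops also keeps padded (isolated) nodes at degree 1, so the normalization never divides by zero; their rows still receive the bias, though.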

class SparseGCNPredicator(*args, **kwargs)

GCN Predicator: a wrapper module that combines a GCN, a graph pooling (readout) step, and an MLP predicator.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, pooling_method='mean', predicator_hidden_feats=128, predicator_dropout=0.0, n_out=1, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • hidden_feats (list[int]) – Output node feature size of each GCN layer. len(hidden_feats) equals the number of GCN layers.

  • activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, every layer uses ReLU.

  • batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.

  • dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.

  • pooling_method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to ‘mean’.

  • predicator_hidden_feats (int) – Size of the hidden graph representations in the predicator, defaults to 128.

  • predicator_dropout (float) – Dropout probability in the predicator, defaults to 0.0.

  • n_out (int) – Output size, defaults to 1.

  • name (Optional[str]) – Optional name of the module.

__call__(node_feats, adj, graph_idx, is_training)

Predict logits or values.

Parameters
  • node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.

  • adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.

  • graph_idx (ndarray of shape (N,)) – Graph index of each node: graph_idx[i] is the graph that node i belongs to. Two nodes with the same graph_idx value belong to the same graph.

  • is_training (bool) – Whether the model is training or not.

Returns

out – Predicator output.

Return type

ndarray of shape (batch_size, n_out)
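In the sparse pattern the graphs are concatenated along the node axis, edge indices are shifted to global node ids, and graph_idx records which graph each node came from. A minimal sketch; `sparse_batch` is hypothetical, and treating adj[0] as source and adj[1] as target ids (with undirected edges listed in both directions) is an assumption.

```python
import jax.numpy as jnp
import numpy as np


def sparse_batch(graphs):
    """Concatenate graphs into the sparse pattern.

    graphs: list of (node_feats (n_i, F), edges (2, e_i)) NumPy arrays whose
            node indices are local to each graph.
    Returns node_feats (N, F), adj (2, E), graph_idx (N,).
    """
    feats, edges, graph_idx = [], [], []
    offset = 0
    for g, (x, e) in enumerate(graphs):
        n = x.shape[0]
        feats.append(x)
        edges.append(np.asarray(e) + offset)           # shift to global node ids
        graph_idx.append(np.full(n, g, dtype=np.int32))  # every node tagged with its graph
        offset += n
    return (jnp.asarray(np.concatenate(feats, axis=0)),
            jnp.asarray(np.concatenate(edges, axis=1)),
            jnp.asarray(np.concatenate(graph_idx)))
```

With this layout, batch_size in the predicator output equals the number of distinct values in graph_idx.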

class SparseGCN(*args, **kwargs)

GCN module. Paper: Semi-Supervised Classification with Graph Convolutional Networks

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • hidden_feats (list[int]) – Output node feature size of each GCN layer. len(hidden_feats) equals the number of GCN layers.

  • activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, every layer uses ReLU.

  • batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.

  • dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.

  • name (Optional[str]) – Optional name of the module.

__call__(node_feats, adj, is_training)

Update node features.

Parameters
  • node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.

  • adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.

  • is_training (bool) – Whether the model is training or not.

Returns

new_node_feats – Updated node features for the batch.

Return type

ndarray of shape (N, hidden_feats[-1])

class SparseGCNLayer(*args, **kwargs)

Single GCN layer from Semi-Supervised Classification with Graph Convolutional Networks

__init__(in_feats, out_feats, activation=None, bias=True, normalize=True, batch_norm=False, dropout=0.0, w_init=None, b_init=None, name=None)

Initializes the module.

Parameters
  • in_feats (int) – Number of input node features.

  • out_feats (int) – Number of output node features.

  • activation (Activation or None) – Activation function, defaults to ReLU.

  • bias (bool) – Whether to add a bias after the affine transformation, defaults to True.

  • normalize (bool) – Whether to apply normalization, defaults to True.

  • batch_norm (bool) – Whether to apply batch normalization, defaults to False.

  • dropout (float) – Dropout probability, defaults to 0.0.

  • w_init (Optional[Callable[[Sequence[int], Any], jax.numpy.lax_numpy.ndarray]]) – Initializer for the weight, defaults to a He truncated normal distribution.

  • b_init (initializer function for the bias) – Initializer for the bias, defaults to a truncated normal distribution.

  • name (Optional[str]) – Optional name of the module.

__call__(node_feats, adj, is_training)

Update node features.

Parameters
  • node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.

  • adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.

  • is_training (bool) – Whether the model is training or not.

Returns

new_node_feats – Updated node features for the batch.

Return type

ndarray of shape (N, out_feats)
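The sparse update replaces the dense matrix product with a gather over edges followed by `jax.ops.segment_sum`. A sketch of the paper's rule (self-loops, symmetric normalization), treating adj[0] as source and adj[1] as target node ids; `sparse_gcn_layer` is a hypothetical name, and batch normalization and dropout are omitted.

```python
import jax
import jax.numpy as jnp


def sparse_gcn_layer(node_feats, adj, w, b, activation=jax.nn.relu):
    """One GCN update on an edge list adj of shape (2, E); rows are (src, dst)."""
    n = node_feats.shape[0]
    loops = jnp.arange(n)
    src = jnp.concatenate([adj[0], loops])   # append one self-loop per node
    dst = jnp.concatenate([adj[1], loops])
    # In-degree of every node (self-loop included), so it is always >= 1.
    deg = jax.ops.segment_sum(jnp.ones_like(dst, dtype=node_feats.dtype), dst,
                              num_segments=n)
    norm = jax.lax.rsqrt(deg[src] * deg[dst])  # 1 / sqrt(d_src * d_dst) per edge
    h = node_feats @ w                         # (N, out_feats)
    messages = h[src] * norm[:, None]          # gather source features, scale
    h = jax.ops.segment_sum(messages, dst, num_segments=n)  # sum into targets
    return activation(h + b)
```

For an undirected graph listed in both directions this matches the dense D^-1/2 (A + I) D^-1/2 product, while memory scales with E rather than N².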

Readout

pad_graph_pooling(method='mean')

Pooling function for pad pattern graph data.

Parameters
  • method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to ‘mean’.

Returns

A function that aggregates node_feats over axis=1 (the node axis).

Return type

Function
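A sketch of the pad-pattern readout: pick a reducer by name and reduce over axis=1. `make_pad_pooling` is hypothetical; it does not mask padded nodes, and whether jaxchem does is not documented here.

```python
import jax.numpy as jnp


def make_pad_pooling(method="mean"):
    """Return a function reducing (batch_size, N, F) node features over axis=1."""
    reducers = {"max": jnp.max, "min": jnp.min, "mean": jnp.mean, "sum": jnp.sum}
    if method not in reducers:
        raise ValueError(f"unknown pooling method: {method!r}")
    reduce_fn = reducers[method]

    def pool(node_feats):
        return reduce_fn(node_feats, axis=1)  # -> (batch_size, F)

    return pool
```

For x = [[[1, 2], [3, 4]], [[5, 6], [0, 0]]], where the second graph has one zero padding row, 'mean' returns [[2, 3], [2.5, 3]]: the padding row pulls the second graph's mean down from (5, 6). Unmasked padding therefore matters for 'mean' and 'min' (and for 'max' when features can be negative).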

sparse_graph_pooling(method='mean')

Pooling function for sparse pattern graph data.

Parameters
  • method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to ‘mean’.

Returns

A function that aggregates node_feats per graph according to graph_idx.

Return type

Function
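The sparse readout maps naturally onto JAX segment reductions keyed by graph_idx. A minimal sketch; `make_sparse_pooling` is hypothetical, and passing num_graphs explicitly (so output shapes stay static under jit) is an assumption about how the returned function is called, not the documented signature.

```python
import jax
import jax.numpy as jnp


def make_sparse_pooling(method="mean"):
    """Return a function pooling (N, F) node features into (num_graphs, F) by graph_idx."""
    if method not in ("max", "min", "mean", "sum"):
        raise ValueError(f"unknown pooling method: {method!r}")

    def pool(node_feats, graph_idx, num_graphs):
        if method == "sum":
            return jax.ops.segment_sum(node_feats, graph_idx, num_segments=num_graphs)
        if method == "max":
            return jax.ops.segment_max(node_feats, graph_idx, num_segments=num_graphs)
        if method == "min":
            return jax.ops.segment_min(node_feats, graph_idx, num_segments=num_graphs)
        # mean = per-graph sum divided by per-graph node count
        total = jax.ops.segment_sum(node_feats, graph_idx, num_segments=num_graphs)
        count = jax.ops.segment_sum(jnp.ones_like(graph_idx, dtype=node_feats.dtype),
                                    graph_idx, num_segments=num_graphs)
        return total / jnp.maximum(count, 1.0)[:, None]  # guard empty graphs

    return pool
```

With node_feats [[1, 2], [3, 4], [5, 6]] and graph_idx [0, 0, 1], 'mean' returns [[2, 3], [5, 6]]; there is no padding, so each graph's mean uses only its own nodes.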