jaxchem.models
GCN
class PadGCNPredicator(*args, **kwargs)
GCN Predicator: a wrapper module that combines a GCN with an MLP head.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, pooling_method='mean', predicator_hidden_feats=128, predicator_dropout=0.0, n_out=1, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
hidden_feats (list[int]) – List of output node feature sizes; its length determines the number of GCN layers.
activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, ReLU is used for every layer.
batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.
dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.
pooling_method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to 'mean'.
predicator_hidden_feats (int) – Size of the hidden graph representation in the predicator, defaults to 128.
predicator_dropout (float) – Dropout probability in the predicator, defaults to 0.0.
n_out (int) – Output size, defaults to 1.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, is_training)
Predict logits or values.
- Parameters
node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the padded (maximum) number of nodes per graph in the batch.
adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrices.
is_training (bool) – Whether the model is training or not.
- Returns
out – Predicator output.
- Return type
ndarray of shape (batch_size, n_out)
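A minimal usage sketch for the pad pattern, assuming PadGCNPredicator is a Haiku-style module that must be wrapped with hk.transform_with_state (suggested by the name argument and the (*args, **kwargs) constructor); the exact wrapping may differ from jaxchem's own examples:

    # Hedged usage sketch -- the Haiku wrapping below is an assumption, not confirmed jaxchem API.
    import jax
    import jax.numpy as jnp
    import haiku as hk
    from jaxchem.models import PadGCNPredicator

    batch_size, N, in_feats = 32, 50, 64            # N = padded node count per graph

    def forward_fn(node_feats, adj, is_training):
        model = PadGCNPredicator(in_feats=in_feats, hidden_feats=[64, 64], n_out=1)
        return model(node_feats, adj, is_training)

    forward = hk.transform_with_state(forward_fn)

    rng = jax.random.PRNGKey(0)
    node_feats = jnp.zeros((batch_size, N, in_feats))    # zero-padded node features
    adj = jnp.zeros((batch_size, N, N))                  # zero-padded adjacency matrices
    params, state = forward.init(rng, node_feats, adj, is_training=True)
    out, state = forward.apply(params, state, rng, node_feats, adj, is_training=True)
    print(out.shape)                                     # (batch_size, n_out) == (32, 1)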
class PadGCN(*args, **kwargs)
GCN module. Paper: Semi-Supervised Classification with Graph Convolutional Networks.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
hidden_feats (list[int]) – List of output node feature sizes; its length determines the number of GCN layers.
activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, ReLU is used for every layer.
batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.
dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, is_training)
Update node features.
- Parameters
node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the padded (maximum) number of nodes per graph in the batch.
adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrices.
is_training (bool) – Whether the model is training or not.
- Returns
new_node_feats – Batch new node features.
- Return type
ndarray of shape (batch_size, N, out_feats)
class PadGCNLayer(*args, **kwargs)
Single GCN layer from Semi-Supervised Classification with Graph Convolutional Networks.

__init__(in_feats, out_feats, activation=None, bias=True, normalize=True, batch_norm=False, dropout=0.0, w_init=None, b_init=None, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
out_feats (int) – Number of output node features.
activation (Activation or None) – Activation function, defaults to ReLU.
bias (bool) – Whether to add a bias after the affine transformation, defaults to True.
normalize (bool) – Whether to normalize the adjacency matrix or not, defaults to True.
batch_norm (bool) – Whether to use BatchNormalization or not, defaults to False.
dropout (float) – The probability for dropout, defaults to 0.0.
w_init (Callable[[Sequence[int], Any], ndarray] or None) – Initialization function for the weights, defaults to a He truncated normal distribution.
b_init (Callable or None) – Initialization function for the bias, defaults to a truncated normal distribution.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, is_training)
Update node features.
- Parameters
node_feats (ndarray of shape (batch_size, N, in_feats)) – Batch input node features. N is the padded (maximum) number of nodes per graph in the batch.
adj (ndarray of shape (batch_size, N, N)) – Batch adjacency matrices.
is_training (bool) – Whether the model is training or not.
- Returns
new_node_feats – Batch new node features.
- Return type
ndarray of shape (batch_size, N, out_feats)
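PadGCNLayer implements the propagation rule of Kipf & Welling on dense, padded inputs. The following is a conceptual sketch of that rule in plain JAX (not jaxchem's implementation); batch normalization and dropout are omitted:

    import jax
    import jax.numpy as jnp

    def pad_gcn_layer_sketch(node_feats, adj, W, b):
        """Conceptual pad-pattern GCN update (illustrative only, not jaxchem's code).

        node_feats: (batch_size, N, in_feats), adj: (batch_size, N, N),
        W: (in_feats, out_feats), b: (out_feats,).
        """
        N = adj.shape[-1]
        adj_hat = adj + jnp.eye(N)                         # add self-loops
        deg = jnp.sum(adj_hat, axis=-1)                    # node degrees, (batch_size, N)
        d_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)
        # symmetric normalization D^{-1/2} (A + I) D^{-1/2} -- the normalize=True case
        norm_adj = adj_hat * d_inv_sqrt[..., :, None] * d_inv_sqrt[..., None, :]
        out = norm_adj @ (node_feats @ W) + b              # transform features, then aggregate neighbours
        return jax.nn.relu(out)                            # default activation is ReLU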
class SparseGCNPredicator(*args, **kwargs)
GCN Predicator: a wrapper module that combines a GCN with an MLP head.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, pooling_method='mean', predicator_hidden_feats=128, predicator_dropout=0.0, n_out=1, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
hidden_feats (list[int]) – List of output node feature sizes; its length determines the number of GCN layers.
activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, ReLU is used for every layer.
batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.
dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.
pooling_method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to 'mean'.
predicator_hidden_feats (int) – Size of the hidden graph representation in the predicator, defaults to 128.
predicator_dropout (float) – Dropout probability in the predicator, defaults to 0.0.
n_out (int) – Output size, defaults to 1.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, graph_idx, is_training)
Predict logits or values.
- Parameters
node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.
adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.
graph_idx (ndarray of shape (N,)) – Graph index for each node in the batch; nodes that share the same graph_idx value belong to the same graph.
is_training (bool) – Whether the model is training or not.
- Returns
out – Predicator output.
- Return type
ndarray of shape (batch_size, n_out)
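In the sparse pattern, all graphs in a batch are packed along one node axis, edges are stored as a (2, E) index array, and graph_idx maps each node to its graph. A hand-built toy batch, shown only to illustrate the expected shapes (the variable names and the source/destination ordering of adj are assumptions):

    import numpy as np

    in_feats = 8
    num_nodes = [3, 2]                                        # graph 0: a triangle, graph 1: a single edge
    node_feats = np.random.randn(sum(num_nodes), in_feats)    # (N, in_feats) with N = 5

    # Edge pairs per graph; node indices of later graphs are offset by the nodes before them.
    edges_g0 = np.array([[0, 1, 1, 2, 2, 0],
                         [1, 0, 2, 1, 0, 2]])                 # triangle, both directions
    edges_g1 = np.array([[0, 1],
                         [1, 0]]) + 3                         # offset by the 3 nodes of graph 0
    adj = np.concatenate([edges_g0, edges_g1], axis=1)        # (2, E) with E = 8

    graph_idx = np.repeat(np.arange(len(num_nodes)), num_nodes)   # [0, 0, 0, 1, 1]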
class SparseGCN(*args, **kwargs)
GCN module. Paper: Semi-Supervised Classification with Graph Convolutional Networks.

__init__(in_feats, hidden_feats, activation=None, batch_norm=None, dropout=None, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
hidden_feats (list[int]) – List of output node feature sizes; its length determines the number of GCN layers.
activation (list[Activation] or None) – activation[i] is the activation function of the i-th GCN layer. len(activation) equals the number of GCN layers. By default, ReLU is used for every layer.
batch_norm (list[bool] or None) – batch_norm[i] decides whether batch normalization is applied to the output of the i-th GCN layer. len(batch_norm) equals the number of GCN layers. By default, batch normalization is not applied to any GCN layer.
dropout (list[float] or None) – dropout[i] is the dropout probability applied to the output of the i-th GCN layer. len(dropout) equals the number of GCN layers. By default, no dropout is applied to any layer.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, is_training)
Update node features.
- Parameters
node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.
adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.
is_training (bool) – Whether the model is training or not.
- Returns
new_node_feats – Batch new node features.
- Return type
ndarray of shape (N, out_feats)
class SparseGCNLayer(*args, **kwargs)
Single GCN layer from Semi-Supervised Classification with Graph Convolutional Networks.

__init__(in_feats, out_feats, activation=None, bias=True, normalize=True, batch_norm=False, dropout=0.0, w_init=None, b_init=None, name=None)
Initializes the module.
- Parameters
in_feats (int) – Number of input node features.
out_feats (int) – Number of output node features.
activation (Activation or None) – Activation function, defaults to ReLU.
bias (bool) – Whether to add a bias after the affine transformation, defaults to True.
normalize (bool) – Whether to normalize or not, defaults to True.
batch_norm (bool) – Whether to use BatchNormalization or not, defaults to False.
dropout (float) – The probability for dropout, defaults to 0.0.
w_init (Callable[[Sequence[int], Any], ndarray] or None) – Initialization function for the weights, defaults to a He truncated normal distribution.
b_init (Callable or None) – Initialization function for the bias, defaults to a truncated normal distribution.
name (Optional[str]) – Name of the module.
__call__(node_feats, adj, is_training)
Update node features.
- Parameters
node_feats (ndarray of shape (N, in_feats)) – Batch input node features. N is the total number of nodes in the batch.
adj (ndarray of shape (2, E)) – Batch adjacency list. E is the total number of edges in the batch.
is_training (bool) – Whether the model is training or not.
- Returns
new_node_feats – Batch new node features.
- Return type
ndarray of shape (N, out_feats)
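SparseGCNLayer applies the same propagation rule to edge-list inputs. A conceptual sketch using jax.ops.segment_sum (not jaxchem's implementation); degree normalization (normalize=True), batch normalization, and dropout are omitted, and the (source, destination) row ordering of adj is an assumption:

    import jax
    import jax.numpy as jnp

    def sparse_gcn_layer_sketch(node_feats, adj, W, b):
        """Conceptual sparse-pattern GCN update (illustrative only, not jaxchem's code).

        node_feats: (N, in_feats), adj: (2, E), W: (in_feats, out_feats), b: (out_feats,).
        """
        N = node_feats.shape[0]
        src, dst = adj[0], adj[1]
        h = node_feats @ W + b                                    # affine transform of every node
        agg = jax.ops.segment_sum(h[src], dst, num_segments=N)    # sum messages into destination nodes
        return jax.nn.relu(agg + h)                               # self-connection + default ReLU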
Readout

pad_graph_pooling(method='mean')
Pooling function for pad pattern graph data.
- Parameters
method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to 'mean'.
- Returns
A pooling function that aggregates node_feats along axis=1 (the node axis).
- Return type
Function
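A hedged stand-in for the function returned by pad_graph_pooling('mean'); it only illustrates the axis=1 reduction, and a real implementation may additionally mask out padded nodes:

    import jax.numpy as jnp

    def mean_pad_pooling(node_feats):
        # node_feats: (batch_size, N, feats) -> graph features: (batch_size, feats)
        return jnp.mean(node_feats, axis=1)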
sparse_graph_pooling(method='mean')
Pooling function for sparse pattern graph data.
- Parameters
method (Literal['max', 'min', 'mean', 'sum']) – Pooling method name, defaults to 'mean'.
- Returns
A pooling function that aggregates node_feats per graph using graph_idx.
- Return type
Function
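A hedged stand-in for the sparse case, where graph_idx drives a per-graph segment reduction; the num_graphs argument is illustrative and may not match the real signature:

    import jax
    import jax.numpy as jnp

    def sum_sparse_pooling(node_feats, graph_idx, num_graphs):
        # node_feats: (N, feats), graph_idx: (N,) -> graph features: (num_graphs, feats)
        return jax.ops.segment_sum(node_feats, graph_idx, num_segments=num_graphs)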