</head>

Train a Cellular Attention Network (CAN)

Cellular complexes are graphs that consisted of nodes,edges, and faces. Intuitively, Cellular Attention Networks work by learning edge features via a combination of incident edges that share a node, and incident edges that share a face.This implementation was contributed to the topological deep learning library TopoModelX.

We create and train a neural network for cellular complexes with layers using a message passing scheme provided by the down and up Laplacians as proposed in Rodenberry et. al: Signal processing on cell complexes (2022). We also train layers utilizing the cell attention mechanism originally proposed in Giusti et. al: Cell Attention Networks (2022). Intuitively updates are made

The Neural Network:

The equations of one layer of this neural network without the attention mechanism are given by:

  • A convolution from edges to edges using the down and up laplacian to pass messages:

🟥   my → {z} → x(1→0→1) = L↓,1 ⋅ hyt, (1) ⋅ Θt, (1→0→1)

🟥   my → {z} → x(1→2→1) = L↑,1 ⋅ hyt, (1) ⋅ Θt, (1→2→1)

🟥   mx → x(1→1) = hxt, (1) ⋅ Θt, (1→1)

🟧   mx(1→0→1) = ∑y ∈ ℬ(x)my → x(1→0→1)

🟧   mx(1→2→1) = ∑y ∈ 𝒞(x)my → x(1→2→1)

🟩:   mx(1) = mx(1→0→1) + mx → x(1→1) + mx(1→2→1)

🟦   hxt + 1, (1) = σ(mx(1))

Where the notations are defined in Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).

The Task:

We train this model to perform entire complex classification on a small version of shrec16.

Set-up

import torch
import random
import numpy as np
from sklearn.model_selection import train_test_split

import toponetx.datasets as datasets

from topomodelx.nn.cell.can_layer_bis import CANLayer

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing

Import data

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a cell complex.

We also retrieve:

  • input signals x_0,x_1, and x_2 on the nodes (0-cells), edges (1-cells), and faces (2-cells) for each complex: these will be the model's inputs,
  • a binary classification label y associated to the cell complex.
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading dataset...

done!
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)
print(f"The {i_complex}th simplicial complex has label {ys[i_complex]}.")
The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.
The 6th simplicial complex has label 9.

Lift into cell complex domain and define neighborhood structures

We lift each simplicial complex into a cell complex.

Then, we retrieve the neighborhood structures (i.e. their representative matrices) taht we will use to send messages on each cell complex. In th case of this architecture we need the down and up laplacians acting on 1-cells denoted by L↓,1, L↑,1

cc_list = []
down_laplacian_list = []
up_laplacian_list = []
for simplex in simplexes:
    cell_complex = simplex.to_cell_complex()
    cc_list.append(cell_complex)

    down_laplacian = cell_complex.down_laplacian_matrix(rank=1)
    up_laplacian = cell_complex.up_laplacian_matrix(rank=1)
    down_laplacian = torch.from_numpy(down_laplacian.todense()).to_sparse()
    up_laplacian = torch.from_numpy(up_laplacian.todense()).to_sparse()
    down_laplacian_list.append(down_laplacian)
    up_laplacian_list.append(up_laplacian)
i_complex = 6
print(
    f"The {i_complex}th cell complex has a down_laplacian matrix of shape {down_laplacian_list[i_complex].shape}."
)
print(
    f"The {i_complex}th cell complex has an up_laplacian matrix of shape {up_laplacian_list[i_complex].shape}."
)
The 6th cell complex has a down_laplacian matrix of shape torch.Size([750, 750]).
The 6th cell complex has an up_laplacian matrix of shape torch.Size([750, 750]).

Define neighborhood structures.

Implementing the CAN architecture will require to perform message passing along neighborhood structures of the cell complexes.

Thus, now we retrieve these neighborhood structures (i.e. their representative matrices) that we will use to send messages.

For the CAN, we need the down Laplacian matrix L↓,1 and the up Laplacian matrix L↑,1 of each cell complex.

up_laplacian_list = []
down_laplacian_list = []
for cell_complex in cc_list:
    up_laplacian = cell_complex.up_laplacian_matrix(rank=1)
    down_laplacian = cell_complex.down_laplacian_matrix(rank=1)
    up_laplacian = torch.from_numpy(up_laplacian.todense()).to_sparse()
    down_laplacian = torch.from_numpy(down_laplacian.todense()).to_sparse()
    up_laplacian_list.append(up_laplacian)
    down_laplacian_list.append(down_laplacian)

i_cc = 0
print(f"Up Laplacian of the {i_cc}-th complex: {up_laplacian_list[i_cc].shape}.")
print(f"Down Laplacian of the {i_cc}-th complex: {down_laplacian_list[i_cc].shape}.")
Up Laplacian of the 0-th complex: torch.Size([750, 750]).
Down Laplacian of the 0-th complex: torch.Size([750, 750]).

Create the Neural Network

Using the CANLayer class, we create a neural network which applies a CAN layer to the edges followed by linear layers on nodes, edges, and faces.

in_channels_0 = x_0s[0].shape[-1]
in_channels_1 = x_1s[0].shape[-1]
in_channels_2 = x_2s[0].shape[-1]
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)
The dimension of input features on nodes, edges and faces are: 6, 10 and 7.
class CAN(torch.nn.Module):
    """CAN.

    Parameters
    ----------
    in_channels_0 : int
        Dimension of input features on nodes.
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CAN layers.
    att : bool
        Whether to use attention.

    """

    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        in_channels_2,
        num_classes,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(CANLayer(channels=in_channels_1, att=att))
        self.layers = layers
        self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

    def forward(self, x_0, x_1, x_2, down_laplacian, up_laplacian):
        """Forward computation through layers, then linear layers, then avg pooling.

        Parameters
        ----------
        x_0 : torch.Tensor, shape = [n_nodes, in_channels_0]
            Input features on the nodes (0-cells).
        x_1 : torch.Tensor, shape = [n_edges, in_channels_1]
            Input features on the edges (1-cells).
        x_2 : torch.Tensor, shape = [n_faces, in_channels_2]
            Input features on the faces (2-cells).
        down_laplacian : tensor, shape = [n_edges, n_edges]
            Down Laplacian of rank 1.
        up_laplacian : tensor, shape = [n_edges, n_edges]
            Up Laplacian of rank 1.

        Returns
        -------
        _ : tensor, shape = [1]
            Label assigned to whole complex.
        """
        for layer in self.layers:
            x_1 = layer(x_1, down_laplacian, up_laplacian)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        # Take the average of the 2D, 1D and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            one_dimensional_cells_mean
            + zero_dimensional_cells_mean
            + two_dimensional_cells_mean
        )

Train the Neural Network

We specify the model, initialize loss, and specify an optimizer. We first try it without any attention mechanism.

model = CAN(in_channels_0, in_channels_1, in_channels_2, num_classes=1, n_layers=2)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

We split the dataset into train and test sets.

test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)
up_laplacian_train, up_laplacian_test = train_test_split(
    up_laplacian_list, test_size=test_size, shuffle=False
)
down_laplacian_train, down_laplacian_test = train_test_split(
    down_laplacian_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

We train the CAN using 10 epochs: we keep training minimal for the purpose of rapid testing.

test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, down_laplacian, up_laplacian, y in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        down_laplacian_train,
        up_laplacian_train,
        y_train,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        down_laplacian, up_laplacian = down_laplacian.float().to(
            device
        ), up_laplacian.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, x_2, down_laplacian, up_laplacian)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, x_1, x_2, down_laplacian, up_laplcian, y in zip(
                x_0_test,
                x_1_test,
                x_2_test,
                down_laplacian_test,
                up_laplacian_test,
                y_test,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                down_laplacian, up_laplacian = down_laplacian.float().to(
                    device
                ), up_laplacian.float().to(device)
                y_hat = model(x_0, x_1, x_2, down_laplacian, up_laplacian)
                test_loss = loss_fn(y_hat, y)
            print(f"Test_loss: {test_loss:.4f}", flush=True)
C:\Users\abrah\anaconda3\envs\topological_2\Lib\site-packages\torch\nn\modules\loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch: 1 loss: 89.3527
Epoch: 2 loss: 81.9469
Test_loss: 59.5947
Epoch: 3 loss: 81.1082
Epoch: 4 loss: 80.3476
Test_loss: 55.3188
Epoch: 5 loss: 79.6493
Epoch: 6 loss: 79.0137
Test_loss: 51.6529
Epoch: 7 loss: 78.4377
Epoch: 8 loss: 77.9167
Test_loss: 48.4984
Epoch: 9 loss: 77.4454
Epoch: 10 loss: 77.0191
Test_loss: 45.7851

Train the Neural Network with Attention

Now we create a new neural network, that uses the attention mechanism.

model = CAN(
    in_channels_0, in_channels_1, in_channels_2, num_classes=1, n_layers=2, att=True
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

We run the training for this neural network:

test_interval = 2
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, down_laplacian, up_laplacian, y in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        down_laplacian_train,
        up_laplacian_train,
        y_train,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        down_laplacian, up_laplacian = down_laplacian.float().to(
            device
        ), up_laplacian.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, x_2, down_laplacian, up_laplacian)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, x_1, x_2, down_laplacian, up_laplcian, y in zip(
                x_0_test,
                x_1_test,
                x_2_test,
                down_laplacian_test,
                up_laplacian_test,
                y_test,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                down_laplacian, up_laplacian = down_laplacian.float().to(
                    device
                ), up_laplacian.float().to(device)
                y_hat = model(x_0, x_1, x_2, down_laplacian, up_laplacian)
                test_loss = loss_fn(y_hat, y)
            print(f"Test_loss: {test_loss:.4f}", flush=True)
Epoch: 1 loss: 96.6442
Epoch: 2 loss: 76.8734
Test_loss: 52.6675
Epoch: 3 loss: 76.2222
Epoch: 4 loss: 75.4701
Test_loss: 46.0780
Epoch: 5 loss: 74.7502
Epoch: 6 loss: 74.0797
Test_loss: 40.4300
Epoch: 7 loss: 73.4617
Epoch: 8 loss: 72.8948
Test_loss: 35.5483
Epoch: 9 loss: 72.3760
Epoch: 10 loss: 71.9013
Test_loss: 31.3371

</html>

–>