# $i$-RevNet and $i$-ResNet: Minor modifications to get an invertible neural network!

Our objective is to design neural networks that are invertible.

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn

We first consider the following VGG-like model architecture:

In [None]:
class Net(nn.Module):
    """VGG-like baseline: four 16->16 channel 7x7 conv + ReLU stages, then a linear classifier.

    Padding 3 keeps the spatial size unchanged. The final ``view(-1, 28 * 28)``
    therefore requires 16 * H * W == 784 per sample, e.g. the 16x7x7 inputs
    built by the data pipeline.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv2 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv3 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv4 = nn.Conv2d(16, 16, 7, padding=3)
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
        return self.fc(x.view(-1, 28 * 28))

The routines to load the data from the dataset are given by:

In [None]:
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import Subset

# The 28x28 MNIST image (784 values) is reshaped into 16 channels of 7x7 so it
# matches the 16-channel convolutional models of this notebook.
reshape = transforms.Lambda(lambda x: x.view(16, 7, 7))
mnist_transform = transforms.Compose([transforms.ToTensor(), reshape])

# Training subset: the first 500 images.
# For a random subset instead, use torch.randperm(len(dataset))[:500].
dataset = torchvision.datasets.MNIST('./', download=True, transform=mnist_transform, train=True)
train_indices = torch.arange(0, 500)
train_dataset = Subset(dataset, train_indices)

# Test subset: indices 0..9999, i.e. all 10,000 images of the MNIST test split.
dataset = torchvision.datasets.MNIST('./', download=True, transform=mnist_transform, train=False)
test_indices = torch.arange(0, 10000)
test_dataset = Subset(dataset, test_indices)


and the data loaders, which are what we actually use to iterate over the data in mini-batches, are given by:

In [None]:
# Mini-batch iterators over the two subsets: training batches of 64 are
# reshuffled every epoch, test batches of 16 keep a fixed order.
DataLoader = torch.utils.data.DataLoader
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
testloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

1. Try to visualize some of the data that we will use throughout this lab session.

2. Write a train function and a test function on this dataset. Train the neural network, and plot the training loss. __Hint:__ Use https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html .

In [None]:
def train(net, n_epoch=40, learning_rate=0.05, loader=None, loss_history=None):
    """Train ``net`` with SGD (momentum 0.9) on cross-entropy, then print its train accuracy.

    Args:
        net: classifier mapping a batch of inputs to (batch, n_classes) scores.
        n_epoch: number of passes over the training data.
        learning_rate: SGD learning rate.
        loader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``trainloader``.
        loss_history: optional list; if given, the mean training loss of every
            epoch is appended to it (e.g. to plot the training curve).

    Returns:
        The same ``net`` object, trained and left in eval mode.
    """
    if loader is None:
        loader = trainloader

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    net.train()
    for _ in range(n_epoch):  # loop over the dataset multiple times
        running_loss = []
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())
        if loss_history is not None and running_loss:
            loss_history.append(sum(running_loss) / len(running_loss))

    # Accuracy on the data we trained on (no gradients needed).
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total:  # guard against an empty loader (division by zero)
        print('Accuracy of the network on the %d train images: %d %%' % (
            total, 100 * correct / total))

    return net

def test(net, loader=None):
    """Evaluate ``net`` on held-out data, print its accuracy and return it (in %).

    Args:
        net: classifier mapping a batch of inputs to (batch, n_classes) scores.
        loader: iterable of ``(images, labels)`` batches; defaults to the
            notebook-level ``testloader``.

    Returns:
        Accuracy in percent as a float (0.0 for an empty loader).
    """
    if loader is None:
        loader = testloader

    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else 0.0
    print('Accuracy of the network on the %d test images: %d %%' % (total, accuracy))
    return accuracy


# A module-level function literally named ``test`` would otherwise be collected
# by pytest as a test case if this notebook is exported to a module.
test.__test__ = False
    
# Train the VGG-like baseline, then evaluate it on the test set.
net = train(Net())
test(net)

# $i$-RevNets: NNs invertible by construction
3. For a layer $j$, assume the features at layer $j$ are written $X_j=[x_j,\tilde x_j]$, where $\text{dim}(x_j)=n_j$ and $\text{dim}(\tilde x_j)=m_j$, and let $F_j:\mathbb{R}^{n_j}\rightarrow \mathbb{R}^{m_{j+1}}$. Propose an invertible architecture $\Phi_F$ using only $F=(F_1,\dots,F_J)$ and additions. We write $\mathcal{F}_J=\cup_{F} \{\Phi_{F}\}$ for the set of such neural networks of depth $J$, for any widths $m_j,n_j$.

4. Let $\mathcal{F}=\cup_J \mathcal{F}_J$. Give a necessary and sufficient condition for $\mathcal{F}$ to be a group.

5. Write a function that splits a tensor of size $(batch, chan, N, N)$ into two tensors of size $(batch, chan//2, N, N)$

In [None]:
class SplitAlongChannels(nn.Module):
    """Split a (batch, chan, N, N) tensor into two (batch, chan // 2, N, N) halves.

    ``forward`` returns the pair ``(first_half, second_half)``. ``inverse``
    concatenates such a pair back along the channel axis, so
    ``inverse(forward(x))`` reproduces ``x`` exactly.
    """

    def __init__(self):
        super(SplitAlongChannels, self).__init__()

    def forward(self, x):
        n_chan = x.size(1)
        if n_chan % 2:
            raise ValueError('expected an even number of channels, got %d' % n_chan)
        half = n_chan // 2
        return x[:, :half], x[:, half:]

    def inverse(self, X):
        first_half, second_half = X
        return torch.cat((first_half, second_half), dim=1)

6. Write the corresponding invertible operation.

In [None]:
class InvertibleLayer(nn.Module):
    """Additive coupling layer of an i-RevNet, acting on a pair of 8-channel tensors.

    ``forward((x, x_tilde))`` returns ``(x_tilde + F(x), x)`` with
    ``F = relu(conv(.))``. The map is invertible whatever ``F`` is: from
    ``(y, x)`` we recover ``x_tilde = y - F(x)``. Only ``F`` and additions are
    used, so ``F`` itself never has to be inverted.
    """

    def __init__(self, kernel_size=7):
        super(InvertibleLayer, self).__init__()
        # padding = kernel_size // 2 keeps the spatial size for odd kernel sizes.
        self.conv = nn.Conv2d(8, 8, kernel_size, padding=kernel_size // 2)

    def _residual(self, x):
        # F(x) = relu(conv(x)); added to (forward) or subtracted from (inverse) the other half.
        return F.relu(self.conv(x))

    def forward(self, X):
        x, x_tilde = X
        return x_tilde + self._residual(x), x

    def inverse(self, X):
        y, x = X
        return x, y - self._residual(x)

7. Write a class i-RevNet that consists of 4 invertible layers and is analogous to *Net*.

In [None]:
class iRevNet(nn.Module):
    """Four-layer i-RevNet, the invertible analogue of ``Net``.

    The input is split into two halves by ``self.split``. The resulting pair goes
    through four ``InvertibleLayer`` couplings (each maps a pair to a pair) and is
    merged back with ``self.split.inverse``. With ``classif=True`` the merged
    features are flattened to 784 values and fed to the linear classifier.
    Otherwise the invertible representation itself is returned, and ``inverse``
    maps it back to the input.
    """

    def __init__(self):
        super(iRevNet, self).__init__()
        self.split = SplitAlongChannels()
        self.conv1 = InvertibleLayer()
        self.conv2 = InvertibleLayer()
        self.conv3 = InvertibleLayer()
        self.conv4 = InvertibleLayer()
        self.fc = nn.Linear(28*28, 10)

    def _layers(self):
        # Coupling layers in forward order.
        return (self.conv1, self.conv2, self.conv3, self.conv4)

    def forward(self, x, classif=True):
        X = self.split(x)
        for layer in self._layers():
            X = layer(X)
        z = self.split.inverse(X)
        if classif:
            return self.fc(z.view(-1, 28 * 28))
        return z

    def inverse(self, x):
        # Undo the couplings in reverse order on the split representation.
        X = self.split(x)
        for layer in reversed(self._layers()):
            X = layer.inverse(X)
        return self.split.inverse(X)

8. Implement and run this architecture using your answers to the questions above. Compare the accuracies.

In [None]:
# Train the i-RevNet with the same routine as the baseline, then evaluate it.
net = train(iRevNet())
test(net)

9. Validate on an example that your model is exactly invertible.

In [None]:
x = torch.randn((1, 16, 7, 7))  # random probe with the (batch, 16, 7, 7) input shape

This type of model is directly linked to: https://openreview.net/pdf?id=HJsjkMb0Z.

# $i$-ResNets: a simple modification to get an invertible resnet

We now consider the following model class, which we refer to as ResNet (see https://arxiv.org/abs/1512.03385):

In [None]:
class ResNet(nn.Module):
    """Residual counterpart of ``Net``: each stage is ``x <- x - relu(conv_k(x))``.

    The four convolutions are 16->16 channels with 7x7 kernels and padding 3, so
    the spatial size is preserved. The 784 flattened features then go through
    a linear layer producing 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv2 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv3 = nn.Conv2d(16, 16, 7, padding=3)
        self.conv4 = nn.Conv2d(16, 16, 7, padding=3)
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = x - F.relu(conv(x))  # skip connection around each conv block
        return self.fc(x.view(-1, 28 * 28))

Note that this model has identity loops, which we refer to as skip connections; they were introduced in https://arxiv.org/abs/1512.03385.

10. Train the corresponding model.

In [None]:
# Train the plain residual network, then evaluate it.
net = train(ResNet())
test(net)

11. Consider a neural network whose layers are given by $x_{j+1}=x_j - F_j x_j$. Propose a condition on $F=(F_1,\dots,F_J)$ for $\Phi x$ to be invertible. Propose a simple iterative way to compute the inverse. How many iterations are needed to reach a precision $\epsilon$?

12. Let $N$ be an integer and $x\in \ell^2([0,N-1])$. Compute the operator norm of $W:y\mapsto x\circledast y$, where $\circledast$ denotes circular convolution. Deduce a simple renormalization procedure to constrain $\Vert DF_j\Vert\leq\rho$ for some $0<\rho<1$.

13. Implement this layer and incorporate the normalization in the train function. You might want to look at the documentation of nn.Conv2d to access the weight of the convolution. __Hint__: modify a parameter through .data to avoid breaking PyTorch's computation graph.

In [None]:
from torch.fft import fft
from torch.fft import ifft
class InvertibleResidual(nn.Module):
    """Residual block ``x -> x - F(x)`` with ``F(x) = relu(conv(x))``, invertible once normalized.

    ReLU is 1-Lipschitz and the conv bias does not change Lipschitz constants, so
    ``F`` is Lipschitz with constant at most the operator norm of the convolution.
    ``normalize`` caps that norm at ``rho < 1``. Then ``y = x - F(x)`` is
    equivalent to ``x = y + F(x)``, a ``rho``-contraction with a unique fixed
    point (Banach fixed-point theorem), which ``inverse`` computes by iteration.
    """

    def __init__(self, kernel_size=7):
        super(InvertibleResidual, self).__init__()
        # padding = kernel_size // 2 keeps the spatial size for odd kernel sizes.
        self.conv = nn.Conv2d(16, 16, kernel_size, padding=kernel_size // 2)

    def _residual(self, x):
        # F(x) = relu(conv(x)), the branch subtracted by the skip connection.
        return F.relu(self.conv(x))

    def forward(self, x):
        return x - self._residual(x)

    def normalize(self, rho=0.9, fft_size=32):
        """Rescale the conv weight in place so the convolution's operator norm is <= ``rho``.

        The circular convolution on an ``fft_size`` x ``fft_size`` grid has operator
        norm equal to the largest singular value, over all frequencies, of the
        (out_channels x in_channels) matrix of kernel DFTs. For inputs of spatial
        size n with ``n + kernel_size // 2 <= fft_size``, the zero-padded
        convolution is a restriction of that circular one. Its norm, and hence
        ``Lip(F)``, is therefore bounded by the same value. The weight is modified
        through ``.data`` so that autograd does not record the rescaling.

        Returns:
            ``self``, to allow chaining.
        """
        weight = self.conv.weight.data
        if fft_size < max(weight.shape[-2:]):
            raise ValueError('fft_size (%d) must be at least the kernel size' % fft_size)
        with torch.no_grad():
            kernel_hat = torch.fft.fft2(weight, s=(fft_size, fft_size))
            # (out, in, f1, f2) -> (f1, f2, out, in): one channel-mixing matrix per frequency.
            per_freq = kernel_hat.permute(2, 3, 0, 1)
            sigma = torch.linalg.matrix_norm(per_freq, ord=2).max().item()
            if sigma > rho:
                weight.mul_(rho / sigma)
        return self

    def inverse(self, y, n_iter=100, tol=1e-6):
        """Invert ``forward`` with the fixed-point iteration ``x <- y + F(x)``, starting at ``y``.

        After ``normalize(rho)`` the iteration converges geometrically: after k
        steps the error is at most ``rho**k / (1 - rho) * ||x_1 - x_0||``. The loop
        stops early once the largest absolute change between iterates is <= ``tol``.
        Call it under ``torch.no_grad()`` unless gradients through the iterations
        are wanted.
        """
        if y.numel() == 0:
            return y
        x = y
        for _ in range(n_iter):
            x_next = y + self._residual(x)
            if (x_next - x).abs().max() <= tol:
                return x_next
            x = x_next
        return x

In [None]:
def train_norm(net, n_epoch=40, learning_rate=0.05, rho=0.9, loader=None, loss_history=None):
    """Train ``net`` with SGD (momentum 0.9) while keeping it invertible via ``net.normalize(rho)``.

    ``net`` must provide a ``normalize(rho)`` method. It is called once before
    training and again after every optimizer step. This projects the weights back
    onto the set where each residual branch is ``rho``-Lipschitz, so the network
    stays invertible throughout training.

    Args:
        net: classifier mapping a batch of inputs to (batch, n_classes) scores.
        n_epoch: number of passes over the training data.
        learning_rate: SGD learning rate.
        rho: Lipschitz bound passed to ``net.normalize`` (should be < 1).
        loader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``trainloader``.
        loss_history: optional list receiving the mean loss of every epoch.

    Returns:
        The same ``net`` object, trained and left in eval mode.
    """
    if loader is None:
        loader = trainloader

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    net.normalize(rho)  # start from an invertible network
    net.train()
    for _ in range(n_epoch):
        running_loss = []
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            net.normalize(rho)  # the SGD step may have left the constraint set
            running_loss.append(loss.item())
        if loss_history is not None and running_loss:
            loss_history.append(sum(running_loss) / len(running_loss))

    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total:
        print('Accuracy of the network on the %d train images: %d %%' % (
            total, 100 * correct / total))

    return net

14. Write a class i-ResNet, and train it.

In [None]:
class iResNet(nn.Module):
    """Four-block invertible ResNet: a stack of ``InvertibleResidual`` layers plus a classifier.

    ``normalize`` forwards the Lipschitz bound to every residual block.
    ``forward`` applies the blocks in order, then either returns the
    representation (``classif=False``) or flattens it to 784 values and applies
    the linear layer. ``inverse`` undoes the blocks in reverse order.
    """

    def __init__(self):
        super(iResNet, self).__init__()
        self.conv1 = InvertibleResidual()
        self.conv2 = InvertibleResidual()
        self.conv3 = InvertibleResidual()
        self.conv4 = InvertibleResidual()
        self.fc = nn.Linear(28*28, 10)

    def _layers(self):
        # Residual blocks in forward order.
        return (self.conv1, self.conv2, self.conv3, self.conv4)

    def normalize(self, rho=0.9):
        for layer in self._layers():
            layer.normalize(rho)
        return self

    def forward(self, x, classif=True):
        for layer in self._layers():
            x = layer(x)
        if classif:
            return self.fc(x.view(-1, 28 * 28))
        return x

    def inverse(self, x):
        for layer in reversed(self._layers()):
            x = layer.inverse(x)
        return x


In [None]:
# Train the i-ResNet with per-step spectral normalization, then evaluate it.
net = train_norm(iResNet())
test(net)

15. Verify the invertibility of the final model.

In [None]:
x = torch.randn((1, 16, 7, 7))  # random probe with the (batch, 16, 7, 7) input shape


This type of model is discussed in detail in https://arxiv.org/pdf/1811.00995.pdf.

16. For some invertible neural network $\Phi$ and two images $x_0,x_1$, visualize $x_t=\Phi^{-1}((1-t)\Phi x_0+t\Phi x_1)$ for $0<t<1$. What do you observe?

17. Explain the major differences between the two approaches, the pros and the cons.