Vipul Vaibhaw

Hooked with Pytorch!

This post comes after a long gap; we were occupied with a few other things. Anyway, in this post we will discuss an amazing feature of PyTorch known as hooks.


This feature comes in handy during debugging. Whenever you want to see the output of an intermediate layer during the forward pass, or the gradients during the backward pass, hooks turn out to be really useful.
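As a quick taste before we build a full model, here is a minimal standalone sketch (not from the repo) of registering a forward hook on a single layer to capture its output:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
captured = {}

def save_output(module, inputs, output):
    # called automatically after every forward pass through `layer`
    captured["out"] = output.detach()

handle = layer.register_forward_hook(save_output)
y = layer(torch.randn(3, 4))   # triggers the hook
print(captured["out"].shape)   # torch.Size([3, 2])
handle.remove()                # detach the hook when no longer needed
```

The hook runs transparently inside the forward call, so you can inspect intermediate tensors without touching the model code itself.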


You can check out this repo - https://github.com/Chanakya-School-of-AI/pytorch-tutorials



Let's dive into the code now -


We will start by importing necessary modules.


import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np 
import torch.nn.functional as F

Great, now we will write a simple neural network which we will train on the MNIST dataset.


class Net(nn.Module):
# We start by inheriting from nn.Module
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3)
        self.m_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64*5*5,
                             out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, input_x):
    # A tip here: don't call nn.ReLU directly in this function.
    # nn.ReLU is a Module class, so nn.ReLU(x) constructs a module
    # object rather than the tensor that self.m_pool() expects;
    # use the functional form F.relu instead.
        x = self.m_pool(F.relu(self.conv1(input_x)))
        x = self.m_pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # no activation on the last layer: nn.CrossEntropyLoss
        # applies log-softmax to the raw logits itself
        x = self.fc3(x)
        return x
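A quick sanity check on where the 64*5*5 in fc1 comes from (an illustrative snippet, not part of the original post): a 28x28 MNIST image shrinks to 26x26 after conv1 (kernel 3), to 13x13 after pooling, to 11x11 after conv2, and to 5x5 after the second pooling:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 32, kernel_size=3)
conv2 = nn.Conv2d(32, 64, kernel_size=3)
pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 28, 28)   # one fake MNIST image
x = pool(F.relu(conv1(x)))      # -> (1, 32, 13, 13)
x = pool(F.relu(conv2(x)))      # -> (1, 64, 5, 5)
print(x.shape)                  # torch.Size([1, 64, 5, 5])
```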

Amazing! Now we also need to get the data ready.

# dataloader
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data/', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=1, shuffle=True)

Now we will set up some parameters.

net = Net()
learning_rate = 0.001
momentum = 0.9
log_interval = 10
n_epochs = 1 # make sure that you tune the above parameters if you want to train seriously.
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
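The Hook class used in the next snippet is never defined in this post; here is a minimal sketch consistent with how it is used below (recording each layer's input and output, with a backward flag for gradients). Note that register_full_backward_hook needs a reasonably recent PyTorch; older versions used the now-deprecated register_backward_hook.

```python
import torch
import torch.nn as nn

class Hook():
    """Wraps a module and records the input/output of its forward
    (or backward) pass each time the module runs."""
    def __init__(self, module, backward=False):
        if backward:
            # for backward hooks, input/output hold the gradient tuples
            self.hook = module.register_full_backward_hook(self.hook_fn)
        else:
            self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()
```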

train_losses = []
train_counter = []
# register forward and backward hooks on each top-level layer
hookF = [Hook(layer) for _, layer in net.named_children()]
hookB = [Hook(layer, backward=True) for _, layer in net.named_children()]

Now that we have registered hooks on each layer, let's train the network and see what the hooks capture.


def train(epoch):
    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = net(data)

        print('***'*3+'  Forward Hooks Inputs & Outputs  '+'***'*3)
        for hook in hookF:
            print(hook.input)
            print(hook.output)
            print('---'*17)

        print('\n')

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        print('***'*3+'  Backward Hooks Inputs & Outputs  '+'***'*3)
        for hook in hookB:
            print(hook.input)
            print(hook.output)
            print('---'*17)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(batch_idx * len(data) +
                                 (epoch - 1) * len(train_loader.dataset))
 
for epoch in range(1, n_epochs + 1):
    train(epoch)
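One last tip: hooks stay attached (and keep firing, holding references to tensors) until their handles are removed, so detach them once you are done inspecting. A small standalone sketch using a plain register_forward_hook handle:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []
handle = layer.register_forward_hook(lambda m, i, o: calls.append(o.shape))

layer(torch.randn(1, 4))   # hook fires once
handle.remove()            # detach the hook
layer(torch.randn(1, 4))   # this pass is no longer recorded
print(len(calls))          # 1
```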

Awesome! We have added hooks to our neural network. I hope you liked reading this article; keep learning and sharing!


Happy coding! :)


©2019 by Deeplearned education pvt ltd