I'm learning about hooks and working with binarized neural networks. The issue is that sometimes my gradients are 0 in the backward pass, and I'm trying to replace those gradients with a certain value.
Say I have the following network:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Model()
opt = optim.Adam(net.parameters())
And also some features:
features = torch.rand((3,1))
I can train it normally using:
for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()
How can I attach a hook function that will apply the following conditions during the backward pass (for each layer):
If all the gradients in a single layer are 0, change them to 1.0.
If one of the gradients is 0 but there's at least one gradient that is not 0, change it to 0.5.
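To make the conditions concrete, this is roughly the kind of hook I have in mind. It's just a rough sketch using Tensor.register_hook on each parameter tensor (so each weight/bias of a layer is handled separately; fix_zero_grads is a placeholder name, and I'm not sure this is the right approach):

def fix_zero_grads(grad):
    # grad is the gradient tensor of one parameter during the backward pass
    if torch.all(grad == 0):
        # every gradient in this tensor is 0 -> replace them all with 1.0
        return torch.ones_like(grad)
    if torch.any(grad == 0):
        # some, but not all, gradients are 0 -> replace only the zeros with 0.5
        return torch.where(grad == 0, torch.full_like(grad, 0.5), grad)
    return grad  # nothing to change

# attach the hook to every parameter of the network
for p in net.parameters():
    p.register_hook(fix_zero_grads)

Is something like this the right way to do it, or should I be using a module-level backward hook instead?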