PyTorch Style Transfer Tutorial
In 2015, Leon A. Gatys et al. published a paper that uses deep neural networks to transfer the style of one image onto another. The beauty of the paper is that it uses a DNN to extract both the content and the style of a picture, which gives good results when imitating the masterpieces of earlier artists.
I read this paper a long time ago, but I still wondered how to perform gradient descent on the input image. Luckily, PyTorch provides a good tutorial on how to do this. In this post I try to extract the core parts that relate to updating the input image.
Why `detach` the target from the graph?

    import torch.nn as nn
    import torch.nn.functional as F

    class ContentLoss(nn.Module):

        def __init__(self, target):
            super(ContentLoss, self).__init__()
            # we 'detach' the target content from the tree used
            # to dynamically compute the gradient: this is a stated value,
            # not a variable. Otherwise the forward method of the criterion
            # will throw an error.
            self.target = target.detach()

        def forward(self, input):
            # store the loss as a side effect and pass the input through unchanged
            self.loss = F.mse_loss(input, self.target)
            return input

    ...
    target = model(img)
    content_loss = ContentLoss(target)
    model.add_module("content_loss", content_loss)

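This is because the target is computed by running the target image `img` through the model, so it still carries the computation graph that produced it (as long as any model parameter requires gradients, which is the default). Later we only want to backpropagate through the graph built from the input image, not through the one built from `img`. Without `detach`, every `backward()` call would also traverse the target's graph. Its buffers are freed after the first call, so the second optimization step fails with the classic error: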
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
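To see the effect in isolation, here is a minimal sketch that reproduces the error on a toy network. The two-layer `net`, the tensor shapes, and the helper `backward_twice` are my own stand-ins and not part of the tutorial, which uses VGG features; the point is only that a non-detached target breaks on the second backward pass while a detached one does not.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-in for the CNN (an assumption; the tutorial uses VGG features).
net = nn.Sequential(nn.Linear(4, 4), nn.Tanh())

content = torch.randn(1, 4)                         # fixed content "image"
input_img = torch.randn(1, 4, requires_grad=True)   # the tensor being optimized


def backward_twice(target):
    """Hypothetical helper: two optimization-style backward passes on one target."""
    for _ in range(2):
        loss = F.mse_loss(net(input_img), target)
        loss.backward()


# Not detached: the target is still linked to a graph through net's weights.
attached = net(content)
try:
    backward_twice(attached)
    second_backward_failed = False
except RuntimeError:
    # The target's graph was freed by the first backward pass.
    second_backward_failed = True

# Detached: the target is a plain constant, so each pass only walks
# the freshly built graph of net(input_img).
detached = net(content).detach()
backward_twice(detached)
```

Passing `retain_graph=True` would also silence the error, but it would keep computing gradients through the target on every step, which is wasted work here since the target never changes.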
Why normalization?

    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    # create a module to normalize input image so we can easily put it in a nn.Sequential
    class Normalization(nn.Module):
        def __init__(self, mean, std):
            super(Normalization, self).__init__()
            # reshape to [C x 1 x 1] so the statistics broadcast over [B x C x H x W];
            # clone().detach() copies an existing tensor without the torch.tensor() warning
            self.mean = mean.clone().detach().view(-1, 1, 1)
            self.std = std.clone().detach().view(-1, 1, 1)

        def forward(self, img):
            # normalize img
            return (img - self.mean) / self.std

    normalization = Normalization(cnn_normalization_mean, cnn_normalization_std).to(device)
    model = nn.Sequential(normalization)

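This is because the pretrained VGG network in PyTorch was trained on images normalized per channel with specific means and standard deviations, which are the values shown above. The network does not normalize its input itself, so we have to apply the same normalization before feeding an image to it.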
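The `view(-1, 1, 1)` reshape is what makes the per-channel statistics line up with an image batch. The following sketch shows the broadcasting on a random tensor; the batch size and the 8x8 resolution are arbitrary choices of mine for illustration.

```python
import torch

# Per-channel statistics reshaped to [3, 1, 1].
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

# Hypothetical batch in B x C x H x W layout with values in [0, 1].
img = torch.rand(1, 3, 8, 8)

# [3, 1, 1] broadcasts against [1, 3, 8, 8]: every pixel of channel c
# is shifted by mean[c] and scaled by std[c].
normalized = (img - mean) / std

# Inverting the operation recovers the original image (up to float rounding).
restored = normalized * std + mean
```

Without the reshape, a tensor of shape `[3]` would be broadcast against the last dimension (the image width) instead of the channel dimension, which fails unless the width happens to be 3.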
How to update the INPUT IMAGE?

    input_img = torch.randn(img.data.size(), device=device)

    def get_input_optimizer(input_img):
        # this line to show that input is a parameter that requires a gradient
        optimizer = optim.LBFGS([input_img.requires_grad_()])
        return optimizer

    optimizer = get_input_optimizer(input_img)

- Make `input_img` require gradient computation by calling `input_img.requires_grad_()`.
- Set up the optimizer to optimize `input_img`.
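Putting the two steps together, a minimal sketch of the update loop might look like the following. The small convolutional `net`, the image size, the number of steps, and the `line_search_fn="strong_wolfe"` argument are assumptions of mine to keep the toy run small and stable; the tutorial itself uses VGG and calls `optim.LBFGS` with its defaults. Note that LBFGS needs a closure that recomputes the loss, because it may evaluate the function several times per step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Toy feature extractor standing in for the VGG layers (an assumption).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
for p in net.parameters():
    p.requires_grad_(False)    # the network stays fixed; only the image changes

content = torch.rand(1, 3, 8, 8)      # hypothetical content image
target = net(content).detach()        # constant target features

# Start from noise and register the image itself as the thing to optimize.
input_img = torch.randn(content.size())
optimizer = optim.LBFGS(
    [input_img.requires_grad_()],
    line_search_fn="strong_wolfe",    # my addition; not in the tutorial's call
)


def closure():
    # Recompute the loss and its gradient with respect to input_img.
    optimizer.zero_grad()
    loss = F.mse_loss(net(input_img), target)
    loss.backward()
    return loss


initial_loss = closure().item()
for _ in range(5):
    optimizer.step(closure)    # updates input_img in place
final_loss = closure().item()
```

The key design point is that the optimizer's parameter list contains the image tensor rather than `net.parameters()`, so gradient descent moves the pixels while the network weights never change.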
The result
Here I only reconstruct the content from one specific layer. The result is consistent with the original paper.



I think these are all the tricks that I did not understand when I first read the tutorial. I hope this post helps you understand it better.