PyTorch Style Transfer Tutorial

Fangda Han
3 min read · Aug 22, 2018

In 2015, Leon A. Gatys et al. published a paper that uses deep neural networks to transfer the style of one image onto another. The beauty of the paper is that it uses a DNN to extract the content and the style of images separately, which produces convincing imitations of the masterpieces of earlier artists.

I read this paper a long time ago, but I still wondered how to perform gradient descent on the input image. Luckily, PyTorch provides a good tutorial on how to do this. In this post I try to extract the core parts that relate to updating the input image.

Why `detach` the target from the graph?

class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

...

target = model(img)
content_loss = ContentLoss(target)
model.add_module("content_loss", content_loss)

This is because later we only want to back-propagate through the graph built from the input image, not through the graph that produced the target. Without the detach, PyTorch would try to back-propagate through that stale graph a second time, and the classic error appears:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
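Here is a minimal sketch of that failure mode, using a toy nn.Linear layer instead of the tutorial's VGG (the names are illustrative, not from the tutorial). The target is produced by the model once, but each iteration builds a fresh graph only for the input image, so the first backward() traverses and frees the graph behind the target, and the second backward() raises the error above.

import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Linear(4, 4)            # stand-in for the VGG feature extractor
img = torch.randn(1, 4)            # stand-in for the content image

target = layer(img)                # still attached to the graph that produced it
# target = layer(img).detach()     # <- detaching here avoids the error below

input_img = torch.randn(1, 4, requires_grad=True)

for step in range(2):
    loss = F.mse_loss(layer(input_img), target)
    loss.backward()                # second iteration raises the RuntimeError
    input_img.grad = None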

Why normalization?

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# create a module to normalize the input image so we can easily put it in a nn.Sequential
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std

normalization = Normalization(cnn_normalization_mean, cnn_normalization_std).to(device)
model = nn.Sequential(normalization)

This is because the pretrained VGG networks in torchvision expect their input to be normalized with these specific per-channel means and standard deviations, which are shown above.
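For comparison, this is the same per-channel operation that torchvision.transforms.Normalize performs in a normal classification pipeline; the sketch below (with a dummy image, not from the post) just shows the equivalence. Putting the normalization inside the model as a module means the image we optimize can stay in plain [0, 1] space.

import torch
import torchvision.transforms as transforms

# Same statistics as above, applied the "usual" torchvision way.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

img = torch.rand(3, 224, 224)   # dummy image in [0, 1]
out = normalize(img)            # equivalent to (img - mean) / std per channel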

How to update the INPUT IMAGE?

input_img = torch.randn(img.data.size(), device=device)

def get_input_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer

optimizer = get_input_optimizer(input_img)
  1. Make input_img require gradient computation by calling input_img.requires_grad_().
  2. Set up the optimizer so that it optimizes input_img itself rather than the network weights (a sketch of the resulting update loop follows).
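Because L-BFGS may re-evaluate the loss several times per step, optimizer.step expects a closure. Below is a minimal sketch of the update loop, assuming the model, content_loss, input_img and optimizer defined above; the number of steps and the clamping are my assumptions, roughly following the structure of the official tutorial.

num_steps = 300   # assumed iteration count

for step in range(num_steps):

    def closure():
        # keep the optimized image in the valid [0, 1] range
        input_img.data.clamp_(0, 1)

        optimizer.zero_grad()
        model(input_img)            # forward pass fills content_loss.loss
        loss = content_loss.loss
        loss.backward()             # gradients reach input_img; only it is
        return loss                 # passed to the optimizer, so only it updates

    optimizer.step(closure)         # L-BFGS may call closure multiple times

input_img.data.clamp_(0, 1)         # final clamp before displaying / saving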

The result

Here I just reconstruct the content from a specific layer; the results are consistent with the original paper.

Input image
Reconstruction of conv2 in VGG19
Reconstruction of conv5 in VGG19

I think these are all the tricks I did not understand when I first read the tutorial. I hope this post helps you understand it better.
