Neural Style Transfer#
Summary#
Introduction
Principle
Example: apply style transfer to an image
Introduction#
Neural style transfer in a nutshell#
Reproduce an image with a new artistic style provided by another image.
Blend a content image and a style reference image into a stylized output image.
First described in A Neural Algorithm of Artistic Style by Gatys et al. (2015). Many refinements and variations have been proposed since.
Example#
Principle#
Loss definition#
The resulting image \(\pmb{x}^*\) minimizes a weighted sum of three losses: the content loss \(\mathcal{L}_{content}(\pmb{c}, \pmb{x})\), the style loss \(\mathcal{L}_{style}(\pmb{s}, \pmb{x})\) and the total variation loss \(\mathcal{L}_{total\_variation}(\pmb{x})\).
\(\pmb{x}\): generated image.
\(\pmb{c}\): content image.
\(\pmb{s}\): reference style image.
\(\alpha\), \(\beta\) and \(\gamma\): weights.
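In other words, the generated image is obtained by solving the following optimization problem:

$$\pmb{x}^* = \underset{\pmb{x}}{\mathrm{argmin}} \left[ \alpha \, \mathcal{L}_{content}(\pmb{c}, \pmb{x}) + \beta \, \mathcal{L}_{style}(\pmb{s}, \pmb{x}) + \gamma \, \mathcal{L}_{total\_variation}(\pmb{x}) \right]$$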
The content loss#
Content = high-level structure of an image.
Can be captured by an upper layer of a convolutional neural network.
Content loss for a layer = distance between the feature maps of the content and generated images.
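As a sketch, denoting by \(F^l(\cdot)\) the feature maps produced by a chosen upper layer \(l\), the content loss for that layer can be written as a squared Euclidean distance (constant factors are absorbed into the weight \(\alpha\)):

$$\mathcal{L}_{content}(\pmb{c}, \pmb{x}) = \sum_{i,j} \left( F^l_{ij}(\pmb{x}) - F^l_{ij}(\pmb{c}) \right)^2$$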
The style loss#
Style = low-level features of an image (textures, colors, visual patterns).
Can be captured by using correlations across the different feature maps (filter responses) of a convnet.
Feature correlations are computed via a Gram matrix (outer product of the feature maps for a given layer).
Style loss for a layer = distance between the Gram matrices of the feature maps for the style and generated images.
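As a sketch, for a layer \(l\) with \(K\) feature maps \(F^l_k(\cdot)\), the Gram matrix entries and the resulting per-layer style loss can be written as follows (normalization factors depending on the layer size are omitted here):

$$G^l_{kk'}(\pmb{x}) = \sum_{i,j} F^l_{ijk}(\pmb{x}) \, F^l_{ijk'}(\pmb{x}) \qquad \mathcal{L}^l_{style}(\pmb{s}, \pmb{x}) = \sum_{k,k'} \left( G^l_{kk'}(\pmb{s}) - G^l_{kk'}(\pmb{x}) \right)^2$$

The total style loss averages this quantity over a set of selected layers.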
The total variation loss#
Sum of the absolute differences between neighboring pixel values in an image. Measures how much noise is in the image.
Encourages local spatial continuity in the generated image (denoising).
Acts as a regularization loss.
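As a sketch, for pixel values \(x_{i,j}\) of the generated image:

$$\mathcal{L}_{total\_variation}(\pmb{x}) = \sum_{i,j} \left( |x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}| \right)$$

(The example below uses a smooth variant of this formula, with squared differences raised to the power 1.25.)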
Gradient descent#
Objective: minimize the total loss.
Optimizer: L-BFGS (original choice made by Gatys et al.) or Adam.
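A minimal sketch of this pixel-space optimization, assuming TensorFlow and a dummy loss standing in for the real one: the optimization variable is the generated image itself, not the network weights. The full example below follows the same pattern with the actual style transfer loss.

import tensorflow as tf

# The generated image is the variable being optimized
image = tf.Variable(tf.random.uniform((1, 64, 64, 3)))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

for step in range(100):
    with tf.GradientTape() as tape:
        # Dummy loss: the real total loss would be computed here
        loss = tf.reduce_mean(tf.square(image))
    # Differentiate the loss w.r.t. the image pixels and update them
    grads = tape.gradient(loss, image)
    optimizer.apply_gradients([(grads, image)])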
Example: apply style transfer to an image#
(Heavily inspired by this Keras example)
Environment setup#
import platform
print(f"Python version: {platform.python_version()}")
# Compare version numbers as integers (string comparison would fail for e.g. Python 3.10)
assert tuple(int(n) for n in platform.python_version_tuple()[:2]) >= (3, 6)
from IPython.display import Image, display
import numpy as np
Python version: 3.8.2
import tensorflow as tf
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")
print('GPU found :)' if tf.config.list_physical_devices("GPU") else 'No GPU :(')
from tensorflow.keras import Model
from tensorflow.keras.applications import vgg19
from tensorflow.keras.utils import get_file
from tensorflow.keras.preprocessing.image import load_img, save_img, img_to_array
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import ExponentialDecay
TensorFlow version: 2.2.0
Keras version: 2.3.0-tf
GPU found :)
Image loading#
# Download content and style images
base_image_path = get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = get_file(
"starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
)
width, height = load_img(base_image_path).size
print(f"Input image dimensions: {width, height}")
Input image dimensions: (910, 607)
# Show the content image
display(Image(base_image_path))
# Show style reference image
display(Image(style_reference_image_path))
Utility functions#
def preprocess_image(image_path, height, width):
"""Open, resize and format a picture into appropriate tensors"""
img = load_img(
image_path, target_size=(height, width)
)
img = img_to_array(img)
img = np.expand_dims(img, axis=0)
# Convert image from RGB to BGR and zero-center each color channel w.r.t. the ImageNet dataset
img = vgg19.preprocess_input(img)
return tf.convert_to_tensor(img)
def deprocess_image(x, height, width):
"""Convert a tensor into a valid image"""
x = x.reshape((height, width, 3))
# Remove zero-center by mean pixel applied by vgg19.preprocess_input()
# The following values are the mean pixel values of each color channel for the ImageNet dataset
x[:, :, 0] += 103.939
x[:, :, 1] += 116.779
x[:, :, 2] += 123.68
# 'BGR'->'RGB'
x = x[:, :, ::-1]
x = np.clip(x, 0, 255).astype("uint8")
return x
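A quick sanity check (illustrative only, reusing the images downloaded above): preprocessing followed by deprocessing should approximately recover the original picture.

# Preprocessing adds a batch dimension and BGR/zero-centering; deprocessing reverses it
check = preprocess_image(base_image_path, 400, 600)
print(check.shape)  # (1, 400, 600, 3)
restored = deprocess_image(check.numpy(), 400, 600)
print(restored.shape, restored.dtype)  # (400, 600, 3) uint8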
Loss functions#
# The gram matrix of a 3D tensor (correlations between the feature maps of a convolutional layer)
def gram_matrix(x):
    # Transpose the feature maps tensor so that the 3rd dimension (filters) becomes the 1st
x = tf.transpose(x, (2, 0, 1))
# Reshape feature maps tensor into a matrix. First dimension is the number of filters/maps
features = tf.reshape(x, (tf.shape(x)[0], -1))
# Compute the outer product of feature vectors with themselves
gram = tf.matmul(features, tf.transpose(features))
return gram
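For instance (illustrative check), the Gram matrix of a feature maps tensor with 64 filters is a 64x64 matrix, whatever the spatial dimensions:

# Dummy feature maps tensor: 7x7 spatial positions, 64 filters
dummy_features = tf.random.uniform((7, 7, 64))
print(gram_matrix(dummy_features).shape)  # (64, 64)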
# The style loss is designed to maintain the style of the reference image in the generated image
# It is based on the gram matrices (which capture style) of feature maps from the style reference image
# and from the generated image
def style_loss(style, combination, height, width):
S = gram_matrix(style)
C = gram_matrix(combination)
channels = 3
size = height * width
# Compute distance between Gram matrices of style and generated images
return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))
# The content loss is designed to maintain the "content" of the base image in the generated image
def content_loss(base, combination):
return tf.reduce_sum(tf.square(combination - base))
# The total variation loss is designed to keep the generated image locally coherent
def total_variation_loss(x, height, width):
a = tf.square(
x[:, : height - 1, : width - 1, :] - x[:, 1:, : width - 1, :]
)
b = tf.square(
x[:, : height - 1, : width - 1, :] - x[:, : height - 1, 1:, :]
)
return tf.reduce_sum(tf.pow(a + b, 1.25))
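A small illustrative check of the behavior described above: a noisy image yields a much higher total variation loss than a constant one.

noisy = tf.random.uniform((1, 64, 64, 3))
constant = tf.ones((1, 64, 64, 3))
print(total_variation_loss(noisy, 64, 64).numpy())     # large positive value
print(total_variation_loss(constant, 64, 64).numpy())  # 0.0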
Model definition#
We use a VGG model pretrained on the ImageNet dataset.
# Using the convolutional base of VGG19, loaded with pre-trained ImageNet weights
vgg = vgg19.VGG19(weights="imagenet", include_top=False)
# Get the symbolic outputs of each layer (each layer of VGG19 has a unique name)
outputs_dict = dict([(layer.name, layer.output) for layer in vgg.layers])
# Set up a model that returns the activation values for every layer in VGG19 (as a dict)
feature_extractor = Model(inputs=vgg.inputs, outputs=outputs_dict)
vgg.summary()
Model: "vgg19"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, None, None, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, None, None, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, None, None, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, None, None, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, None, None, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, None, None, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, None, None, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, None, None, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, None, None, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, None, None, 256) 590080
_________________________________________________________________
block3_conv4 (Conv2D) (None, None, None, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, None, None, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, None, None, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block4_conv4 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, None, None, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block5_conv4 (Conv2D) (None, None, None, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, None, None, 512) 0
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
Loss computation#
# List of layers to use for the style loss
style_layer_names = [
"block1_conv1",
"block2_conv1",
"block3_conv1",
"block4_conv1",
"block5_conv1",
]
# The layer to use for the content loss
content_layer_name = "block5_conv2"
# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8
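Before computing the loss, it can be useful to check that the selected layer names actually exist in the feature extractor (they appear in the model summary above):

# Sanity check: all selected layers must be outputs of the VGG19 convolutional base
for name in style_layer_names + [content_layer_name]:
    assert name in outputs_dict, f"Unknown layer: {name}"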
def compute_loss(combination_image, base_image, style_reference_image, height, width):
input_tensor = tf.concat(
[base_image, style_reference_image, combination_image], axis=0
)
features = feature_extractor(input_tensor)
# Initialize the loss
loss = tf.zeros(shape=())
# Add content loss
layer_features = features[content_layer_name]
base_image_features = layer_features[0, :, :, :]
combination_features = layer_features[2, :, :, :]
loss = loss + content_weight * content_loss(
base_image_features, combination_features
)
# Add style loss
for layer_name in style_layer_names:
layer_features = features[layer_name]
style_reference_features = layer_features[1, :, :, :]
combination_features = layer_features[2, :, :, :]
sl = style_loss(style_reference_features, combination_features, height, width)
loss += (style_weight / len(style_layer_names)) * sl
# Add total variation loss
loss += total_variation_weight * total_variation_loss(combination_image, height, width)
return loss
# Compile the loss and gradient computation into a graph to make it fast
@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image, height, width):
with tf.GradientTape() as tape:
loss = compute_loss(combination_image, base_image, style_reference_image, height, width)
grads = tape.gradient(loss, combination_image)
return loss, grads
Training loop#
# Generated image height
gen_height = 400
# Compute the generated width so that the generated image keeps the input image's aspect ratio
gen_width = int(width * gen_height / height)
print(f"Generated image dimensions: {gen_width, gen_height}")
Generated image dimensions: (599, 400)
# Gradient descent with an exponentially decaying learning rate
optimizer = SGD(
ExponentialDecay(
initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
)
)
base_image = preprocess_image(base_image_path, gen_height, gen_width)
style_reference_image = preprocess_image(style_reference_image_path, gen_height, gen_width)
combination_image = tf.Variable(preprocess_image(base_image_path, gen_height, gen_width))
# Training loop: each "epoch" is a single gradient descent step on the generated image
n_epochs = 4000
for epoch in range(1, n_epochs + 1):
loss, grads = compute_loss_and_grads(
combination_image, base_image, style_reference_image, gen_height, gen_width
)
optimizer.apply_gradients([(grads, combination_image)])
if epoch % 100 == 0:
print(f"Epoch [{epoch}/{n_epochs}], loss: {loss:.2f}")
# Save final image
final_img = deprocess_image(combination_image.numpy(), gen_height, gen_width)
result_image_path = f"generated_image_epoch_{n_epochs}.png"
save_img(result_image_path, final_img)
Epoch [100/4000], loss: 11018.34
Epoch [200/4000], loss: 8514.22
Epoch [300/4000], loss: 7571.62
Epoch [400/4000], loss: 7064.01
Epoch [500/4000], loss: 6736.22
Epoch [600/4000], loss: 6501.81
Epoch [700/4000], loss: 6323.29
Epoch [800/4000], loss: 6181.49
Epoch [900/4000], loss: 6065.40
Epoch [1000/4000], loss: 5967.84
Epoch [1100/4000], loss: 5884.72
Epoch [1200/4000], loss: 5812.83
Epoch [1300/4000], loss: 5750.32
Epoch [1400/4000], loss: 5695.63
Epoch [1500/4000], loss: 5647.32
Epoch [1600/4000], loss: 5604.25
Epoch [1700/4000], loss: 5565.57
Epoch [1800/4000], loss: 5530.69
Epoch [1900/4000], loss: 5499.09
Epoch [2000/4000], loss: 5470.43
Epoch [2100/4000], loss: 5444.34
Epoch [2200/4000], loss: 5420.45
Epoch [2300/4000], loss: 5398.52
Epoch [2400/4000], loss: 5378.34
Epoch [2500/4000], loss: 5359.73
Epoch [2600/4000], loss: 5342.55
Epoch [2700/4000], loss: 5326.65
Epoch [2800/4000], loss: 5311.85
Epoch [2900/4000], loss: 5298.08
Epoch [3000/4000], loss: 5285.24
Epoch [3100/4000], loss: 5273.25
Epoch [3200/4000], loss: 5262.03
Epoch [3300/4000], loss: 5251.54
Epoch [3400/4000], loss: 5241.72
Epoch [3500/4000], loss: 5232.52
Epoch [3600/4000], loss: 5223.90
Epoch [3700/4000], loss: 5215.81
Epoch [3800/4000], loss: 5208.19
Epoch [3900/4000], loss: 5201.01
Epoch [4000/4000], loss: 5194.25
Generated image display#
# Show final generated image
display(Image(result_image_path))