How to create Anime Faces using GANs in PyTorch?

AI-based filtering, face-swapping and photo-editing software has taken the internet by storm. People today spend endless hours searching for the perfect filter to beautify their photos, trying quirky animal or rockstar filters, or simply anticipating what their future selves might look like.

With the advancements in virtual reality and augmented reality, developers are constantly trying to inject more life into their models. This desire to generate realistic AI faces became far easier to satisfy with the development of Generative Adversarial Networks (GANs for short) by Ian J. Goodfellow and his co-authors in 2014. In this blog, we will learn how to use GANs and PyTorch to create anime faces.

What are Generative Adversarial Networks (GANs)?

As defined by Machine Learning Mastery, ‘Generative modelling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset.’

What this essentially means is that GANs belong to a family of generative models that can produce novel content based on the dataset provided to them. Although generative modelling is an unsupervised learning task, the GAN architecture frames the training of the model as a supervised learning problem.

This can be described with the help of the following diagram:

Source: Machine Learning Mastery

Mostly used in the field of image generation and transformation, GANs create novel content with the help of two sub-models: a generator and a discriminator. The generator model takes in a random input vector and generates new fake content, while the discriminator model tries to classify examples as real (present in the training data) or fake (produced by the generator).

These two models are trained in tandem: the discriminator is trained for a few epochs, then the generator for a few epochs, and this process is repeated continuously so that both models get better with each iteration. The two models compete against each other in a zero-sum adversarial game. Each time the discriminator successfully distinguishes real samples from fake ones, it is rewarded in the sense that no change is made to its parameters, while the generator is penalised and its parameters are updated.

Conversely, each time the generator fools the discriminator, the generator is rewarded and no change is made to its parameters, while the discriminator is penalised and its parameters are updated. Training continues until the discriminator is fooled roughly 50% of the time, which indicates that the generator is producing plausible content. To get a sense of the power of these generative models, visit thispersondoesnotexist.com, which shows the face of a person who does not exist every time the page is reloaded.
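Formally, this adversarial game is captured by the minimax objective from the original 2014 GAN paper, in which the discriminator D tries to maximise and the generator G tries to minimise the same value function:

```latex
\min_G \max_D \, V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

Here D(x) is the discriminator's estimated probability that x is real, and G(z) is the image the generator produces from a random latent vector z.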

What is PyTorch?

PyTorch is a free and open-source machine learning library developed by Facebook’s AI Research lab (FAIR) under the modified BSD licence. It is widely used for computer vision and natural language processing, as it is an optimised tensor library for complex deep learning solutions that can utilise both GPUs and CPUs. We use PyTorch for generating AI faces because it provides the two high-level features we need: tensor computing and deep neural networks.

Similar to NumPy arrays, PyTorch tensors store and operate on homogeneous multidimensional rectangular arrays of numbers, while its deep neural networks are built on a tape-based automatic differentiation system. Hence, PyTorch perfectly suits our requirements for creating novel anime faces.
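As a minimal illustration of these two features, the snippet below creates a tensor, builds a tiny computation, and lets the tape-based autograd system compute the gradient:

```python
import torch

# Tensor computing: a 2x2 tensor that records operations for autograd
x = torch.ones(2, 2, requires_grad=True)

# A small computation graph: y = sum of 3 * x
y = (x * 3).sum()

# Tape-based automatic differentiation: dy/dx = 3 for every element
y.backward()
print(x.grad)  # tensor([[3., 3.], [3., 3.]])
```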

Creating Anime Faces Using GANs:

The detailed code for generating Anime faces can be found at https://www.kaggle.com/kmldas/gan-in-pytorch-deep-fake-anime-faces. Here, I will briefly explain the various steps and functions involved in the process.

  • Dataset:

A dataset of around 63.5k cropped, high-quality anime faces in JPG format is available at https://www.kaggle.com/splcher/animefacedataset. Before training the GAN, we load this dataset using the ImageFolder class from torchvision, then resize and crop the images to a standard size of 64 x 64 px. We then normalise the pixel values with a mean and standard deviation of 0.5 per channel so that they fall in the range (-1, 1), which is desirable for training the discriminator, as sketched below.
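Here is a sketch of that loading step; the directory path is a placeholder, and batch size and worker count are illustrative choices rather than values taken from the notebook:

```python
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

image_size = 64
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)  # per-channel mean and std

# "./animefaces" is a placeholder for wherever the dataset is extracted
train_ds = ImageFolder("./animefaces", transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats),  # maps pixel values into (-1, 1)
]))

train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)
```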

  • Generator Network:

The generator takes in a vector of random inputs, called a latent tensor, which serves as the seed for generating a novel image. It converts latent tensors of shape 128 x 1 x 1 into image tensors of shape 3 x 64 x 64 using PyTorch's ConvTranspose2d layers, which perform transposed convolution (often loosely called deconvolution): an upsampling operation that learns to expand a small input into a larger output, rather than a mathematical inversion of a convolution.

We use ReLU activations inside the generator and a Tanh activation on its output layer, since a bounded activation such as Tanh allows the model to learn quickly to saturate and cover the entire colour space of the training distribution. A sketch of such a generator follows.
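The layer sizes below follow the common DCGAN recipe and may differ in detail from the linked notebook; the shape comments show how each ConvTranspose2d doubles the spatial size on the way from 128 x 1 x 1 to 3 x 64 x 64:

```python
import torch.nn as nn

latent_size = 128

generator = nn.Sequential(
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),        # 512 x 4 x 4
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),        # 256 x 8 x 8
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),        # 128 x 16 x 16
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),        # 64 x 32 x 32
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh()             # 3 x 64 x 64, pixel values in (-1, 1)
)
```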

  • Discriminator Network:

The discriminator is essentially a binary classifier that distinguishes real images from fake ones with the help of a binary cross-entropy loss function. We use a convolutional neural network (CNN) that outputs a single number for every image, with Leaky ReLU activations, which support higher-resolution modelling well since they allow small gradient signals from negative values to pass through. The output is a single number between 0 and 1 representing the probability that the input image is real. A sketch of such a discriminator follows.
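Mirroring the generator, this DCGAN-style sketch halves the spatial size at each Conv2d until a single probability remains; again, exact layer sizes are an assumption, not necessarily those of the notebook:

```python
import torch.nn as nn

discriminator = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.LeakyReLU(0.2, inplace=True),   # 64 x 32 x 32
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),   # 128 x 16 x 16
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),   # 256 x 8 x 8
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),   # 512 x 4 x 4
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),  # 1 x 1 x 1
    nn.Flatten(),
    nn.Sigmoid()                       # probability that the image is real
)
```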

  • Generator Training:

To train the generator model, we make use of a loss function. We first generate a batch of images from random input vectors and pass them to the discriminator, which classifies each image as real or fake. If the discriminator labels a generated image as real, the generator has fooled it. For images labelled fake, the loss is high, and we perform gradient descent on the generator's weights, fine-tuning the parameters so that it produces images more likely to fool the discriminator. A sketch of one such update is shown below.
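This sketch of one generator update assumes the generator and discriminator defined above; the function and optimiser names (train_generator, opt_g) are illustrative, not taken from the notebook:

```python
import torch
import torch.nn.functional as F

def train_generator(opt_g, batch_size, latent_size, device):
    """One generator update: push the discriminator's predictions on
    generated images toward 'real' (1)."""
    opt_g.zero_grad()

    # Generate a batch of fake images from random latent vectors
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Loss is low when the discriminator is fooled into predicting "real"
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    loss.backward()
    opt_g.step()
    return loss.item()
```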

  • Discriminator Training:

Just like the generator, the discriminator is trained with a loss function of its own. The discriminator classifies each image it is given as real or fake. Whenever its classification does not match the true origin of an image, for example when a generated image is labelled real, the generator has fooled the discriminator, and the discriminator uses the loss to perform gradient descent on its weights so that it classifies better in subsequent iterations. A sketch of one such update is shown below.
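Continuing the sketch above with the same assumptions (train_discriminator and opt_d are illustrative names), one discriminator update pushes real images toward 1 and generated images toward 0:

```python
import torch
import torch.nn.functional as F

def train_discriminator(real_images, opt_d, latent_size, device):
    """One discriminator update against a batch of real images."""
    opt_d.zero_grad()

    # Real images should be classified as real (target 1)
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)

    # Generated images should be classified as fake (target 0);
    # detach so this update does not touch the generator's weights
    latent = torch.randn(real_images.size(0), latent_size, 1, 1, device=device)
    fake_images = generator(latent)
    fake_preds = discriminator(fake_images.detach())
    fake_targets = torch.zeros(real_images.size(0), 1, device=device)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item()
```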

  • Output:

The images generated after the 1st, 5th, 10th, 20th and 25th epochs of training are shown below. As we can see, with more epochs the new anime faces become sharper and more realistic.

Source: https://www.kaggle.com/kmldas/gan-in-pytorch-deep-fake-anime-faces

Conclusion:

To conclude, we can use Generative Adversarial Networks to create very believable anime faces. The domains in which we can utilise this technology are endless: more realistic video games with better graphics, better virtual reality and augmented reality systems, higher-quality filters, and image-to-image transformation.

We hope you liked this article. Explore our courses on the website now.

By Saarthak Jain