Generate Handwritten Digits Using GAN

project
gen-ai
GAN
pytorch
Author

Akshara Soman

Published

July 12, 2024

Modified

August 15, 2024

Objective of the Project

Build a DC-GAN (Deep Convolutional Generative Adversarial Network) to generate images of handwritten digits.

Important Details

  1. Dataset: MNIST handwritten digits dataset (grayscale)
  2. Model: Generative Adversarial Network (GAN)
  3. Code available at: https://github.com/aksharasoman/dcgan
  4. It can be run in Google Colab: python-notebook

Overview

A Generative Adversarial Network (GAN) model has two major components: a generator and a discriminator. Figure 1 gives an outline of a GAN model.

Figure 1: Basic GAN architecture

A generator creates fake samples that mimic the real samples provided to the discriminator network. The discriminator is a binary classifier that evaluates these inputs, determining whether each one is real or fake. The generator’s objective is to produce fake samples that are so similar to real ones that the discriminator incorrectly identifies them as genuine.

The GAN loss function consists of two parts: the generator loss and the discriminator loss.
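As a rough illustration of these two parts, the sketch below expresses both losses with binary cross-entropy on the discriminator's raw outputs (logits). The helper names `real_loss`, `fake_loss`, `discriminator_loss` and `generator_loss` are hypothetical, and using `nn.BCEWithLogitsLoss` is an assumption here, not necessarily what the project code does.

```python
import torch
import torch.nn as nn

# Assumption: the discriminator outputs raw logits, so the sigmoid
# is applied inside the loss function.
criterion = nn.BCEWithLogitsLoss()

def real_loss(disc_logits: torch.Tensor) -> torch.Tensor:
    """Loss measured against the 'real' label (1)."""
    return criterion(disc_logits, torch.ones_like(disc_logits))

def fake_loss(disc_logits: torch.Tensor) -> torch.Tensor:
    """Loss measured against the 'fake' label (0)."""
    return criterion(disc_logits, torch.zeros_like(disc_logits))

def discriminator_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """The discriminator wants real -> 1 and fake -> 0."""
    return (real_loss(real_logits) + fake_loss(fake_logits)) / 2

def generator_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    """The generator wants its fakes to be classified as real (1)."""
    return real_loss(fake_logits)
```

When the discriminator confidently separates real from fake, `discriminator_loss` is close to zero while `generator_loss` is large; this opposition is what makes the training adversarial.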

GAN Training Strategy

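The two networks are trained alternately. While the generator is being trained, the discriminator’s weights are held fixed and are not updated; while the discriminator is being trained, the generator’s weights are held fixed.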
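One way to realise this alternating scheme in PyTorch is to give each network its own optimizer and call only one optimizer's `step()` per phase. The sketch below is a minimal example under stated assumptions: `G` and `D` are tiny fully connected stand-ins (not the convolutional networks of this project), and `noise_dim`, the learning rate and the Adam betas are illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
noise_dim = 16  # assumed size of the random input vector

# Tiny stand-in networks (hypothetical; the project uses convolutional layers).
G = nn.Sequential(nn.Linear(noise_dim, 64), nn.ReLU(), nn.Linear(64, 28 * 28), nn.Tanh())
D = nn.Sequential(nn.Linear(28 * 28, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))

criterion = nn.BCEWithLogitsLoss()
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

def train_step(real_images: torch.Tensor):
    batch_size = real_images.size(0)
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)

    # Discriminator phase: detach() stops gradients from reaching G,
    # and only opt_D.step() is called, so G's weights stay fixed.
    opt_D.zero_grad()
    fake_images = G(torch.randn(batch_size, noise_dim)).detach()
    d_loss = (criterion(D(real_images), ones) + criterion(D(fake_images), zeros)) / 2
    d_loss.backward()
    opt_D.step()

    # Generator phase: gradients pass through D, but only opt_G.step()
    # is called, so D's weights stay fixed.
    opt_G.zero_grad()
    g_loss = criterion(D(G(torch.randn(batch_size, noise_dim))), ones)
    g_loss.backward()
    opt_G.step()

    return d_loss.item(), g_loss.item()
```

Calling `train_step` once per batch inside an epoch loop gives the basic training loop. Gradients that accumulate on `D` during the generator phase are cleared by `opt_D.zero_grad()` at the start of the next step.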

Implementation

This project can be divided into 7 tasks.

  1. Configurations
  2. Load dataset
  3. Load dataset into batches
  4. Create discriminator network
  5. Create generator network
  6. Create loss function & optimizer
  7. Training Loop
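
To give a feel for task 3, the sketch below wraps a dataset in a `DataLoader` so that the training loop receives fixed-size batches. It is only an illustration: a random tensor with the MNIST image shape (1 x 28 x 28) stands in for the real dataset so that nothing needs to be downloaded, and the batch size of 128 is an assumed value.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 128  # assumed value, set in the configuration step

# Random stand-in for MNIST: 1,000 grayscale 28x28 "images" with dummy labels.
images = torch.rand(1000, 1, 28, 28)
labels = torch.zeros(1000, dtype=torch.long)
dataset = TensorDataset(images, labels)

# shuffle=True reshuffles the samples at the start of every epoch.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

first_images, first_labels = next(iter(loader))
print(first_images.shape)  # torch.Size([128, 1, 28, 28])
print(len(loader))         # 8 batches: 7 full batches and one of 104 images
```

With the real data, the stand-in `TensorDataset` would be replaced by the MNIST dataset object loaded in task 2.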

For ease of understanding, you may refer to the IPython notebook, where each task is coded in a separate section.

  1. What is a Generative Adversarial Network?
  2. Applications (current state-of-the-art performers for these applications)
  3. What is a generator?
  4. What is a discriminator?
  5. Understanding the architecture
  6. Loss functions
  7. How to generate a fake image using a GAN?
  8. How to download and transform data in PyTorch?
  9. How to calculate the input image size for each layer?
  10. How to build a GAN model from scratch in PyTorch?
  11. How to train a Generative Adversarial Network?
    1. How to train the model on Colab with a GPU?
    2. How to train the model in a remote cluster environment?
  12. Challenges in GANs
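
For the question of image size at each layer, the standard output-size formulas for convolution and transposed convolution (with dilation 1) can be checked with a few lines of plain Python. The kernel and stride values below are illustrative assumptions chosen so that a 28x28 image shrinks to 1x1 and grows back to 28x28; they are not necessarily the settings used in this project.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Output size of a transposed convolution layer along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Discriminator-style downsampling, starting from a 28x28 image.
size = 28
disc_sizes = [size]
for kernel, stride in [(3, 2), (5, 2), (5, 2)]:  # assumed (kernel, stride) pairs
    size = conv_out(size, kernel, stride)
    disc_sizes.append(size)
print(disc_sizes)  # [28, 13, 5, 1]

# Generator-style upsampling, starting from a 1x1 feature map.
size = 1
gen_sizes = [size]
for kernel, stride in [(3, 2), (4, 1), (3, 2), (4, 2)]:  # assumed (kernel, stride) pairs
    size = conv_transpose_out(size, kernel, stride)
    gen_sizes.append(size)
print(gen_sizes)  # [1, 3, 6, 13, 28]
```

Working through the sizes this way before writing the networks helps confirm that the generator's final layer produces images of the same size the discriminator expects.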

Results Snapshot

Digits generated after the first epoch

Digits generated after 15 epochs

References

  1. Coursera Guided Project: “Deep Learning with PyTorch : Generative Adversarial Network”