Lecture 17: Generative adversarial networks implementation
In this lecture, we are going to implement a version of generative adversarial training.
Prepare the codebase
To get started, please clone a version of the needle repo from GitHub. You should be able to use the needle repo after finishing HW2.
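The setup cells themselves are not reproduced here; a sketch consistent with the outputs below, assuming the course's Colab layout (the Drive path and repository URL are assumptions), would be:

```python
# Mount Google Drive and clone the lecture repository.
# The paths and repository URL below are assumptions.
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive
%cd 10714f22
!git clone https://github.com/dlsyscourse/lecture17
!pip3 install pybind11
```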
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive
/content/drive/MyDrive/10714f22
Cloning into 'lecture17'...
remote: Enumerating objects: 917, done.
remote: Counting objects: 100% (184/184), done.
remote: Compressing objects: 100% (115/115), done.
remote: Total 917 (delta 104), reused 122 (delta 68), pack-reused 733
Receiving objects: 100% (917/917), 265.21 KiB | 1.99 MiB/s, done.
Resolving deltas: 100% (531/531), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pybind11
Downloading pybind11-2.10.1-py3-none-any.whl (216 kB)
|████████████████████████████████| 216 kB 5.3 MB/s
Installing collected packages: pybind11
Successfully installed pybind11-2.10.1
We can then run the following commands to make the path to the package available in Colab's environment as well as in the PYTHONPATH.
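A sketch of such a cell, assuming the repository was cloned into the Drive folder above (the exact paths are assumptions):

```python
# Expose the cloned needle package to both Colab and the PYTHONPATH.
%set_env PYTHONPATH /env/python:/content/drive/MyDrive/10714f22/lecture17
import sys
sys.path.append("/content/drive/MyDrive/10714f22/lecture17/python")
```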
Components of a generative adversarial network
There are two main components in a generative adversarial network:
A generator G that takes a random vector z and maps it to generated (fake) data G(z).
A discriminator D that attempts to tell the difference between the real data and the fake data.
Prepare the training dataset
For demonstration purposes, we create our "real" dataset from a two-dimensional Gaussian distribution.
Our goal is to train a generator that can produce a distribution that matches this real data distribution.
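A minimal sketch of such a dataset; the particular transformation matrix A and mean mu below are illustrative assumptions:

```python
import numpy as np

# "Real" data: x = z A + mu with z ~ N(0, I), i.e. a 2-D Gaussian.
A = np.array([[1.0, 2.0],
              [-0.2, 0.5]], dtype="float32")
mu = np.array([2.0, 1.0], dtype="float32")
num_sample = 3200
data = np.random.normal(0, 1, (num_sample, 2)).astype("float32") @ A + mu
```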
Generator network
Now we are ready to build our generator network G. To keep things simple, we make the generator a one-layer linear neural network.
At initialization, the weights of G are random, so the generated distribution certainly does not match the training data. Our goal is to set up generative adversarial training to bring it close to the training data.
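A sketch of such a generator, assuming needle's PyTorch-like nn API from HW2:

```python
import needle as ndl
import needle.nn as nn

# One-layer linear generator: maps 2-D noise z to a 2-D sample G(z).
model_G = nn.Sequential(nn.Linear(2, 2))

def sample_G(model_G, num_samples):
    # Draw noise z ~ N(0, I) and push it through the generator.
    Z = ndl.Tensor(np.random.normal(0, 1, (num_samples, 2)).astype("float32"))
    return model_G(Z).numpy()
```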
Discriminator
Now let us build a discriminator network D that separates the real data from the fake data. Here we use a three-layer neural network. Additionally, we make use of the softmax loss to measure the classification likelihood. Because we are only classifying two classes, the softmax function reduces to the sigmoid function for prediction.
We simply reuse SoftmaxLoss here since it is readily available in our current set of homework iterations. Most implementations will use a binary classification loss (BCELoss) instead.
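A sketch of such a discriminator, again assuming the needle modules from HW2; the layer widths are illustrative:

```python
# Three-layer classifier: 2-D input -> logits over the classes {fake, real}.
model_D = nn.Sequential(
    nn.Linear(2, 20), nn.ReLU(),
    nn.Linear(20, 10), nn.ReLU(),
    nn.Linear(10, 2),
)
loss_D = nn.SoftmaxLoss()
```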
Generative adversarial training
Generative adversarial training iteratively updates the generator and the discriminator to play a "minimax" game:

$$\min_G \max_D \; E_{x \sim \text{data}}\left[\log D(x)\right] + E_{z}\left[\log\left(1 - D(G(z))\right)\right]$$

Note, however, that in practice the update steps usually use an alternative objective function, as described below.
Generator update
Now we are ready to set up the generator update. In the generator update step, we need to optimize the following goal:

$$\min_G \; -E_{z}\left[\log D(G(z))\right]$$
Let us first set up an optimizer for G's parameters.
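For example, assuming needle's Adam optimizer from HW2 (the learning rate is an assumption):

```python
opt_G = ndl.optim.Adam(model_G.parameters(), lr=0.01)
```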
To optimize the above loss function, we just need to generate fake data G(z), send it through the discriminator, and compute the negative log-likelihood that the fake data is categorized as real. In other words, we will feed in 1 (real) as the label here.
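A sketch of the generator update under the needle API assumed above:

```python
def update_G(Z, model_G, model_D, loss_D, opt_G):
    fake_X = model_G(Z)
    fake_Y = model_D(fake_X)
    batch_size = Z.shape[0]
    # Label every generated sample as 1 ("real"), so the softmax loss
    # computes the negative log-likelihood -log D(G(z)).
    ones = ndl.Tensor(np.ones(batch_size, dtype="float32"))
    loss = loss_D(fake_Y, ones)
    opt_G.reset_grad()
    loss.backward()
    opt_G.step()
```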
Discriminator update
Now, let us also set up the discriminator update step. The discriminator step optimizes the following objective:

$$\min_D \; -E_{x \sim \text{data}}\left[\log D(x)\right] - E_{z}\left[\log\left(1 - D(G(z))\right)\right]$$
Let us first set up an optimizer for D's parameters.
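Mirroring the generator side (again assuming needle's Adam optimizer; the learning rate is illustrative):

```python
opt_D = ndl.optim.Adam(model_D.parameters(), lr=0.01)
```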
The discriminator loss is also a normal classification loss, obtained by labeling the generated data as 0 (fake) and the real data as 1 (real). Importantly, we do not need to propagate gradients back to the generator in the discriminator update, so we use the detach function to stop the gradient propagation.
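A sketch of the discriminator update, with detach stopping gradients at the generator output:

```python
def update_D(X, Z, model_G, model_D, loss_D, opt_D):
    # detach() prevents gradients from flowing back into the generator.
    fake_X = model_G(Z).detach()
    fake_Y = model_D(fake_X)
    real_Y = model_D(X)
    batch_size = X.shape[0]
    ones = ndl.Tensor(np.ones(batch_size, dtype="float32"))    # label for real
    zeros = ndl.Tensor(np.zeros(batch_size, dtype="float32"))  # label for fake
    loss = loss_D(real_Y, ones) + loss_D(fake_Y, zeros)
    opt_D.reset_grad()
    loss.backward()
    opt_D.step()
```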
Putting it together
Now we can put it all together. To summarize, generative adversarial training cycles through the following steps, as sketched in the code after the list:
The discriminator update step
The generator update step
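A sketch of the full loop, reusing the update functions above (batch size and epoch count are illustrative):

```python
def train_gan(data, batch_size, num_epochs):
    assert data.shape[0] % batch_size == 0
    for epoch in range(num_epochs):
        # Take the next minibatch of real data and a fresh noise batch.
        begin = (batch_size * epoch) % data.shape[0]
        X = ndl.Tensor(data[begin: begin + batch_size, :])
        Z = ndl.Tensor(np.random.normal(0, 1, (batch_size, 2)).astype("float32"))
        # Alternate the two update steps.
        update_D(X, Z, model_G, model_D, loss_D, opt_D)
        update_G(Z, model_G, model_D, loss_D, opt_G)

train_gan(data, batch_size=32, num_epochs=2000)
```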
We can plot the data produced by the trained generator after a number of iterations. As we can see, the generated dataset gets closer to the real data after training.
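For instance, with matplotlib:

```python
import matplotlib.pyplot as plt

fake_data = sample_G(model_G, num_samples=3200)
plt.scatter(data[:, 0], data[:, 1], color="blue", label="real data")
plt.scatter(fake_data[:, 0], fake_data[:, 1], color="red", label="generated data")
plt.legend()
plt.show()
```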
Inspect the trained generator
We can compare the weight/bias of the trained generator to the parameters we used to generate the dataset. Importantly, we need to compare the covariances here rather than the transformation matrices, since different transformation matrices can produce the same Gaussian distribution.
We can also compare the means.
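One way to check both empirically (np.cov expects variables in rows, hence the transposes):

```python
fake_data = sample_G(model_G, num_samples=3200)

# For x = z A + mu with z ~ N(0, I), cov(x) = A^T A, so the learned
# weight only has to match A up to a rotation.
print("data covariance:\n", np.cov(data.T))
print("generated covariance:\n", np.cov(fake_data.T))

print("data mean:", data.mean(axis=0))
print("generated mean:", fake_data.mean(axis=0))
```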
Modularizing GAN "Loss"
We can modularize the GAN update step in a similar way to a loss function. The following code block shows one way to do so.
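One possible sketch, wrapping the discriminator update inside a loss-like object; the class name and structure are assumptions modeled on the update functions above:

```python
class GANLoss:
    """Behaves like a loss for G: each call first updates D once, then
    returns the generator loss -log D(G(z)) for backprop through G."""

    def __init__(self, model_D, opt_D):
        self.model_D = model_D
        self.opt_D = opt_D
        self.loss_D = nn.SoftmaxLoss()

    def _update_D(self, real_X, fake_X):
        real_Y = self.model_D(real_X)
        # detach() keeps the discriminator update from touching G.
        fake_Y = self.model_D(fake_X.detach())
        batch_size = real_X.shape[0]
        ones = ndl.Tensor(np.ones(batch_size, dtype="float32"))
        zeros = ndl.Tensor(np.zeros(batch_size, dtype="float32"))
        loss = self.loss_D(real_Y, ones) + self.loss_D(fake_Y, zeros)
        self.opt_D.reset_grad()
        loss.backward()
        self.opt_D.step()

    def forward(self, fake_X, real_X):
        self._update_D(real_X, fake_X)
        fake_Y = self.model_D(fake_X)
        ones = ndl.Tensor(np.ones(fake_X.shape[0], dtype="float32"))
        return self.loss_D(fake_Y, ones)

# Usage: the generator step then looks like an ordinary loss computation.
# gan_loss = GANLoss(model_D, opt_D)
# loss = gan_loss.forward(model_G(Z), X)
# opt_G.reset_grad(); loss.backward(); opt_G.step()
```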