Changing the base distribution of a flow model
This example shows how easily one can change the base distribution of a flow model with our API. First, let's look at how a normalizing flow learns a two-moons target distribution when a single Gaussian is used as the base.
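The key point is that the base distribution is simply the first argument of nf.NormalizingFlow, so swapping it amounts to a one-line change in the model setup. The following is a rough sketch of that pattern; make_flows is a hypothetical helper that just stands in for the flow-construction loop shown in the cells below.
# Rough sketch of the pattern used in this example (not part of the original notebook)
# `make_flows` is a hypothetical helper standing in for the flow-construction loop below
import normflows as nf

def make_flows(num_layers=32):
    flows = []
    for i in range(num_layers):
        # Each coupling layer gets its own parameter network
        param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
        flows.append(nf.flows.AffineCouplingBlock(param_map))
        flows.append(nf.flows.Permute(2, mode='swap'))
    return flows

# Same flow architecture, two different base distributions
gaussian_base = nf.distributions.base.DiagGaussian(2)
mixture_base = nf.distributions.base.GaussianMixture(
    2, 2, loc=[[-2, 0], [2, 0]], scale=[[0.3, 0.3], [0.3, 0.3]])
model_gauss = nf.NormalizingFlow(gaussian_base, make_flows())
model_mixture = nf.NormalizingFlow(mixture_base, make_flows())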
# Import packages
import torch
import numpy as np
import normflows as nf
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from tqdm import tqdm
Setting up a flow model with a 2D Gaussian base distribution
# Set up model
# Define 2D Gaussian base distribution
base = nf.distributions.base.DiagGaussian(2)
# Define list of flows
num_layers = 32
flows = []
for i in range(num_layers):
    # Neural network with two hidden layers of 64 units each;
    # the last layer is initialized with zeros, which makes training more stable
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    flows.append(nf.flows.Permute(2, mode='swap'))
# Construct flow model
model = nf.NormalizingFlow(base, flows)
# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
model = model.to(device)
# Define target distribution
target = nf.distributions.TwoMoons()
# Plot target distribution
grid_size = 200
xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))
zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
zz = zz.to(device)
log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(15, 15))
plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()
# Plot initial flow distribution
model.eval()
log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
model.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(15, 15))
plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()
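Because the last layer of each parameter network is initialized with zeros, the coupling layers start out close to the identity, so the untrained flow essentially reproduces its base distribution. One way to confirm this is to draw samples from the untrained model. The following is a quick sketch, not part of the original notebook, and it assumes (as in normflows) that model.sample returns a tuple of samples and their log probabilities.
# Sketch: samples from the untrained flow should look like draws from the Gaussian base,
# since the zero-initialized coupling layers start near the identity
model.eval()
with torch.no_grad():
    z, _ = model.sample(1000)
model.train()
z = z.to('cpu').numpy()
plt.figure(figsize=(8, 8))
plt.scatter(z[:, 0], z[:, 1], s=5)
plt.gca().set_aspect('equal', 'box')
plt.show()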
Training the model
# Train model
max_iter = 4000
num_samples = 2 ** 9
show_iter = 500
loss_hist = np.array([])
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()
    # Get training samples
    x = target.sample(num_samples).to(device)
    # Compute loss
    loss = model.forward_kld(x)
    # Do backprop and optimizer step
    if ~(torch.isnan(loss) | torch.isinf(loss)):
        loss.backward()
        optimizer.step()
    # Log loss
    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())
    # Plot learned distribution
    if (it + 1) % show_iter == 0:
        model.eval()
        log_prob = model.log_prob(zz)
        model.train()
        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))
        prob[torch.isnan(prob)] = 0
        plt.figure(figsize=(15, 15))
        plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
        plt.gca().set_aspect('equal', 'box')
        plt.show()
# Plot loss
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()
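Before comparing densities side by side, a quick sample-level check is also possible: draw samples from the trained flow and overlay them on samples from the target. This sketch is not part of the original notebook and again assumes that model.sample returns a tuple of samples and log probabilities.
# Sketch: compare samples from the trained flow with samples from the target
model.eval()
with torch.no_grad():
    z, _ = model.sample(1000)
model.train()
x = target.sample(1000)
plt.figure(figsize=(8, 8))
plt.scatter(x[:, 0].cpu().numpy(), x[:, 1].cpu().numpy(), s=5, label='target')
plt.scatter(z[:, 0].cpu().numpy(), z[:, 1].cpu().numpy(), s=5, label='flow')
plt.legend()
plt.gca().set_aspect('equal', 'box')
plt.show()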
Visualizing the learned distribution
# Plot target distribution
f, ax = plt.subplots(1, 2, sharey=True, figsize=(15, 7))
log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
ax[0].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
ax[0].set_aspect('equal', 'box')
ax[0].set_axis_off()
ax[0].set_title('Target', fontsize=24)
# Plot learned distribution
model.eval()
log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
model.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
ax[1].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
ax[1].set_aspect('equal', 'box')
ax[1].set_axis_off()
ax[1].set_title('Real NVP', fontsize=24)
plt.subplots_adjust(wspace=0.1)
plt.show()
Notice that there is a bridge between the two modes of the learned distribution. This is usually not a big deal, since the bridge is very thin and, in higher-dimensional spaces, it becomes exponentially unlikely to draw samples that fall within it. However, the shape of each mode is also somewhat distorted, so it would be nice to get rid of the bridge altogether. Let's try using a Gaussian mixture as the base distribution instead of a single Gaussian.
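To see why this should help, it is instructive to plot the proposed base distribution itself on the same grid: with one well-separated blob per mode, the coupling layers only need to deform each blob locally instead of tearing a single Gaussian apart. The sketch below is not part of the original notebook; it assumes the base distributions expose a log_prob method (which model.log_prob relies on), and uses the same mixture parameters as the next cell.
# Sketch: visualize the proposed Gaussian mixture base on the evaluation grid
mixture_base = nf.distributions.base.GaussianMixture(
    2, 2, loc=[[-2, 0], [2, 0]], scale=[[0.3, 0.3], [0.3, 0.3]]).to(device)
log_prob = mixture_base.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(15, 15))
plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()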
Using a Gaussian mixture model as the base instead
# Set up model
# Define a mixture of Gaussians with 2 modes
base = nf.distributions.base.GaussianMixture(2, 2, loc=[[-2, 0], [2, 0]],
                                              scale=[[0.3, 0.3], [0.3, 0.3]])
# Define list of flows
num_layers = 32
flows = []
for i in range(num_layers):
    # Neural network with two hidden layers of 64 units each;
    # the last layer is initialized with zeros, which makes training more stable
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    flows.append(nf.flows.Permute(2, mode='swap'))
# Construct flow model
model = nf.NormalizingFlow(base, flows).to(device)
# Plot initial flow distribution
model.eval()
log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
model.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(15, 15))
plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()
Training the new model
# Train model
max_iter = 4000
num_samples = 2 ** 9
show_iter = 500
loss_hist = np.array([])
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()
    # Get training samples
    x = target.sample(num_samples).to(device)
    # Compute loss
    loss = model.forward_kld(x)
    # Do backprop and optimizer step
    if ~(torch.isnan(loss) | torch.isinf(loss)):
        loss.backward()
        optimizer.step()
    # Log loss
    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())
    # Plot learned distribution
    if (it + 1) % show_iter == 0:
        model.eval()
        log_prob = model.log_prob(zz)
        model.train()
        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))
        prob[torch.isnan(prob)] = 0
        plt.figure(figsize=(15, 15))
        plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
        plt.gca().set_aspect('equal', 'box')
        plt.show()
# Plot loss
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()
Now the modes are in better shape, and there is no longer a bridge between them!
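As a final check, the side-by-side comparison from above can be repeated for the mixture-base model. This block is not part of the original notebook; it simply reuses the plotting code from the earlier comparison.
# Plot target and learned distribution side by side for the mixture-base model
f, ax = plt.subplots(1, 2, sharey=True, figsize=(15, 7))
log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
ax[0].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
ax[0].set_aspect('equal', 'box')
ax[0].set_axis_off()
ax[0].set_title('Target', fontsize=24)
# Plot learned distribution
model.eval()
log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
model.train()
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
ax[1].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')
ax[1].set_aspect('equal', 'box')
ax[1].set_axis_off()
ax[1].set_title('Real NVP, mixture base', fontsize=24)
plt.subplots_adjust(wspace=0.1)
plt.show()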