Augmented Normalizing Flow based on Real NVP
In [ ]:
Copied!
# Import required packages
import torch
import numpy as np
import normflows as nf
from matplotlib import pyplot as plt
from tqdm import tqdm
# Import required packages
import torch
import numpy as np
import normflows as nf
from matplotlib import pyplot as plt
from tqdm import tqdm
In [ ]:
Copied!
# Set up model: a Real NVP style flow over a 4-dimensional space whose target
# is the product of a 2-D "two moons" density and a 2-D diagonal Gaussian
# (the augmented coordinates).

# Define flows
K = 32  # number of (coupling layer, ActNorm) pairs
torch.manual_seed(0)  # seed before the networks are created, for repeatability

latent_size = 4
# Binary mask: ones for the first half of the coordinates, zeros for the rest
b = torch.Tensor([1] * (latent_size // 2) + [0] * (latent_size // 2))
flows = []
for i in range(K):
    # Separate networks for the scale (s) and translation (t) of each layer
    s = nf.nets.MLP([latent_size, 4 * latent_size, latent_size], init_zeros=True)
    t = nf.nets.MLP([latent_size, 4 * latent_size, latent_size], init_zeros=True)
    # Alternate between the mask and its complement from layer to layer
    mask = b if i % 2 == 0 else 1 - b
    flows += [nf.flows.MaskedAffineFlow(mask, t, s)]
    flows += [nf.flows.ActNorm(latent_size)]

# Set augmented target
target = nf.distributions.TwoIndependent(nf.distributions.TwoMoons(),
                                         nf.distributions.DiagGaussian(2))

# Set base distribution (same dimensionality as the flow layers)
q0 = nf.distributions.DiagGaussian(latent_size)

# Construct flow model
nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=target)

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)
nfm = nfm.double()

# Initialize ActNorm
# NOTE(review): presumably the ActNorm layers initialise themselves from the
# first batch that passes through them -- confirm against the normflows docs.
z, _ = nfm.sample(num_samples=2 ** 7)
z_np = z.detach().to('cpu').numpy()

# Show the initial samples: coordinates (0, 1) are the standard ones,
# coordinates (2, 3) are the augmented ones.
for (col_x, col_y), panel_title in (((0, 1), "Standard coordinates"),
                                    ((2, 3), "Augmented coordinates")):
    plt.figure(figsize=(15, 15))
    plt.hist2d(z_np[:, col_x], z_np[:, col_y], (50, 50), range=[[-3, 3], [-3, 3]])
    plt.gca().set_aspect('equal', 'box')
    plt.title(panel_title)
    plt.show()
In [ ]:
Copied!
# Plot augmented target: draw samples from `target` and show a 2-D histogram
# of the standard coordinates (0, 1) and of the augmented coordinates (2, 3).
z = target.sample(num_samples=2 ** 16)
z_np = z.detach().to('cpu').numpy()

for (col_x, col_y), panel_title in (((0, 1), "Standard coordinates"),
                                    ((2, 3), "Augmented coordinates")):
    plt.figure(figsize=(15, 15))
    # Column slices of a 2-D array are already 1-D, so no flatten is needed
    plt.hist2d(z_np[:, col_x], z_np[:, col_y], (50, 50), range=[[-3, 3], [-3, 3]])
    plt.gca().set_aspect('equal', 'box')
    plt.title(panel_title)
    plt.show()
In [ ]:
Copied!
# Train model by minimising the reverse KL divergence returned by
# `nfm.reverse_kld`, with the `beta` argument ramped up linearly over the
# first `anneal_iter` iterations.
max_iter = 20000
# NOTE(review): this evaluates to 20 samples per step; if 2 ** 10 was
# intended, confirm before changing it.
num_samples = 2 * 10
anneal_iter = 10000
show_iter = 1000  # plot the current samples every `show_iter` iterations

# Collect losses in a list and convert once at the end; calling np.append
# inside the loop copies the whole history on every iteration.
loss_values = []
loss_hist = np.array([])

optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-6)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()

    # beta starts at 0.01 and is capped at 1
    beta = min(1., 0.01 + it / anneal_iter)
    loss = nfm.reverse_kld(num_samples, beta=beta)

    # Skip the update when the loss is NaN or infinite
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()

    # The loss is recorded even when the step was skipped
    loss_values.append(loss.item())

    # Plot learned posterior
    if (it + 1) % show_iter == 0:
        # Diagnostic sampling only: no gradients are needed here
        with torch.no_grad():
            z, _ = nfm.sample(num_samples=2 ** 14)
        z_np = z.detach().to('cpu').numpy()

        for (col_x, col_y), panel_title in (((0, 1), "Standard coordinates"),
                                            ((2, 3), "Augmented coordinates")):
            plt.figure(figsize=(15, 15))
            plt.hist2d(z_np[:, col_x], z_np[:, col_y], (50, 50),
                       range=[[-3, 3], [-3, 3]])
            plt.gca().set_aspect('equal', 'box')
            plt.title(panel_title)
            plt.show()

# Expose the history as a NumPy array, as later cells expect `loss_hist`
loss_hist = np.array(loss_values)
In [ ]:
Copied!
# Plot loss: one curve with the per-iteration values stored in `loss_hist`
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()
In [ ]:
Copied!
# Plot learned distribution: sample from the flow model and show 2-D
# histograms of the standard coordinates (0, 1) and augmented coordinates (2, 3).
# Sampling is for visualisation only, so gradient tracking is switched off.
with torch.no_grad():
    z, _ = nfm.sample(num_samples=2 ** 16)
z_np = z.detach().to('cpu').numpy()

for (col_x, col_y), panel_title in (((0, 1), "Standard coordinates"),
                                    ((2, 3), "Augmented coordinates")):
    plt.figure(figsize=(15, 15))
    # Column slices of a 2-D array are already 1-D, so no flatten is needed
    plt.hist2d(z_np[:, col_x], z_np[:, col_y], (50, 50), range=[[-3, 3], [-3, 3]])
    plt.gca().set_aspect('equal', 'box')
    plt.title(panel_title)
    plt.show()