Learn Distribution given by an Image using Real NVP
In [ ]:
# Import required packages
import torch
import numpy as np
import normflows as nf
from matplotlib import pyplot as plt
from tqdm import tqdm
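If the `normflows` import above fails, the package can be installed from PyPI (run this in a notebook cell):

In [ ]:
# Install the normflows package from PyPI
%pip install normflows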
In [ ]:
# Set up model

# Define flows
K = 32
# torch.manual_seed(0)

b = torch.tensor([0, 1])
flows = []
for i in range(K):
    s = nf.nets.MLP([2, 4, 4, 2])
    t = nf.nets.MLP([2, 4, 4, 2])
    if i % 2 == 0:
        flows += [nf.flows.MaskedAffineFlow(b, t, s)]
    else:
        flows += [nf.flows.MaskedAffineFlow(1 - b, t, s)]

# Set target and base distribution
img = 1 - plt.imread('img.png')[:, :, 0]  # Specify the path to your image here
target = nf.distributions.ImagePrior(img)
q0 = nf.distributions.DiagGaussian(2)

# Construct flow model
nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=target)

# Move model to GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)
nfm = nfm.double()
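Each `MaskedAffineFlow` above is one Real NVP affine coupling layer. As a rough sketch of what it computes (simplified for illustration; `masked_affine_forward` is just an illustrative name, not a library function), the binary mask `b` selects the components that pass through unchanged and condition the scale and translation networks, while the remaining components are transformed:

In [ ]:
# Illustrative sketch of one masked affine coupling step
def masked_affine_forward(z, b, s, t):
    z_masked = b * z     # masked components pass through unchanged
    scale = s(z_masked)  # log-scale from the MLP s
    trans = t(z_masked)  # translation from the MLP t
    z_out = z_masked + (1 - b) * (z * torch.exp(scale) + trans)
    # Triangular Jacobian: the log-determinant is the sum of active log-scales
    log_det = torch.sum((1 - b) * scale, dim=-1)
    return z_out, log_det

Alternating between `b` and `1 - b` across the K layers ensures that both coordinates get transformed as the flow deepens.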
In [ ]:
# Plot prior distribution
grid_size = 200
xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size), indexing='ij')
zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
zz = zz.double().to(device)
log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
plt.figure(figsize=(10, 10))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.show()

# Plot initial flow distribution
log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(10, 10))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.show()
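As an optional sanity check (a small sketch assuming the cells above have been run), samples drawn directly from the image prior should concentrate where the target puts high density:

In [ ]:
# Scatter-plot samples from the target distribution
x = target.sample(10000).to('cpu').data.numpy()
plt.figure(figsize=(10, 10))
plt.scatter(x[:, 0], x[:, 1], s=1, alpha=0.3)
plt.show()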
In [ ]:
# Train model
max_iter = 10000
num_samples = 2 * 16
show_iter = 2000

loss_hist = np.array([])

optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-4)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()
    x = nfm.p.sample(num_samples).double()
    loss = nfm.forward_kld(x)
    loss.backward()
    optimizer.step()
    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())

    # Plot learned distribution
    if (it + 1) % show_iter == 0:
        log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)
        prob = torch.exp(log_prob)
        prob[torch.isnan(prob)] = 0
        plt.figure(figsize=(10, 10))
        plt.pcolormesh(xx, yy, prob.data.numpy())
        plt.show()

# Plot loss history
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()
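For reference, `forward_kld` minimizes the forward KL divergence KL(p || q), which on samples from the target is equivalent, up to an additive constant, to maximum likelihood. A minimal sketch of the same loss, assuming the model defined above:

In [ ]:
# Equivalent negative log-likelihood loss on target samples
x = nfm.p.sample(num_samples).double()
loss = -torch.mean(nfm.log_prob(x))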
In [ ]:
# Plot learned distribution
log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
plt.figure(figsize=(10, 10))
plt.pcolormesh(xx, yy, prob.data.numpy())
plt.show()
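Besides evaluating the density on a grid, the trained flow can also be sampled directly (a brief sketch; `nfm.sample` returns the samples together with their log-probabilities):

In [ ]:
# Draw samples from the trained flow and visualize them
z, _ = nfm.sample(10000)
z = z.to('cpu').data.numpy()
plt.figure(figsize=(10, 10))
plt.scatter(z[:, 0], z[:, 1], s=1, alpha=0.3)
plt.show()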