API references

ClassCondFlow

Bases: Module

Class conditional normalizing Flow model, providing the class to be conditioned on only to the base distribution, as done e.g. in Glow

Source code in normflows/core.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class ClassCondFlow(nn.Module):
    """
    Class conditional normalizing Flow model, providing the
    class to be conditioned on only to the base distribution,
    as done e.g. in [Glow](https://arxiv.org/abs/1807.03039)
    """

    def __init__(self, q0, flows):
        """Constructor

        Args:
          q0: Base distribution; called as ``q0(num_samples, y)`` when
            sampling and as ``q0.log_prob(z, y)`` when evaluating densities
          flows: List of flows, applied in order when sampling and
            inverted in reverse order when evaluating densities
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)

    def forward_kld(self, x, y):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Computed as the negative batch mean of ``self.log_prob(x, y)``, so it
        shares the exact density computation used by ``log_prob``.

        Args:
          x: Batch sampled from target distribution
          y: Classes of x, passed on to the base distribution

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, y))

    def sample(self, num_samples=1, y=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples, y)
        for flow in self.flows:
            z, log_det = flow(z)
            # Change of variables: the forward log-determinant lowers log q.
            log_q -= log_det
        return z, log_q

    def log_prob(self, x, y):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x, passed on to the base distribution

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        # Undo the flows last-to-first to map x back to the base space.
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z, y)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Note: ``torch.load`` is pickle-based; only load trusted files.

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))

__init__(q0, flows)

Constructor

Parameters:

Name Type Description Default
q0

Base distribution

required
flows

List of flows

required
Source code in normflows/core.py
376
377
378
379
380
381
382
383
384
385
def __init__(self, q0, flows):
    """Store the base distribution and register the flow layers.

    Args:
      q0: Base distribution, the only component that receives the class
      flows: List of flows; wrapped in ``nn.ModuleList`` so their
        parameters are registered with this module
    """
    super().__init__()
    base_distribution = q0
    self.q0 = base_distribution
    self.flows = nn.ModuleList(flows)

forward_kld(x, y)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required
y

Classes of x, passed on to the base distribution

required

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def forward_kld(self, x, y):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Computed as the negative batch mean of ``self.log_prob(x, y)``, so it
    shares the exact density computation used by ``log_prob``.

    Args:
      x: Batch sampled from target distribution
      y: Classes of x, passed on to the base distribution

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    return -torch.mean(self.log_prob(x, y))

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
446
447
448
449
450
451
452
def load(self, path):
    """Restore this model's parameters and buffers from a file.

    Note: ``torch.load`` is pickle-based; only load trusted files.

    Args:
      path: Path including filename of a state dict written by ``save``
    """
    state = torch.load(path)
    self.load_state_dict(state)

log_prob(x, y)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
y

Classes of x

required

Returns:

Type Description

log probability

Source code in normflows/core.py
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
def log_prob(self, x, y):
    """Evaluate the model log-density of each sample in a batch.

    The flows are inverted from last to first, summing their inverse
    log-determinants, and the class-conditional base density is added.

    Args:
      x: Batch
      y: Classes of x, handed to the base distribution

    Returns:
      log probability, one entry per sample
    """
    latent = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for layer in reversed(self.flows):
        latent, layer_log_det = layer.inverse(latent)
        total += layer_log_det
    total += self.q0.log_prob(latent, y)
    return total

sample(num_samples=1, y=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
y

Classes to sample from, will be sampled uniformly if None

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def sample(self, num_samples=1, y=None):
    """Draw samples and their model log-densities.

    Base samples (conditioned on ``y``) are pushed through every flow in
    order; each layer's log-determinant is subtracted from the density.

    Args:
      num_samples: Number of samples to draw
      y: Classes to sample from, will be sampled uniformly if None

    Returns:
      Samples, log probability
    """
    samples, log_density = self.q0(num_samples, y)
    for transform in self.flows:
        samples, log_jac = transform(samples)
        log_density -= log_jac
    return samples, log_density

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
438
439
440
441
442
443
444
def save(self, path):
    """Save state dict of model

    Args:
      path: Path including filename where to save model
    """
    torch.save(self.state_dict(), path)

ConditionalNormalizingFlow

Bases: NormalizingFlow

Conditional normalizing flow model, providing condition, which is also called context, to both the base distribution and the flow layers

Source code in normflows/core.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class ConditionalNormalizingFlow(NormalizingFlow):
    """
    Conditional normalizing flow model, providing condition,
    which is also called context, to both the base distribution
    and the flow layers
    """
    def forward(self, z, context=None):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space
          context: Batch of conditions/context, passed to every flow layer

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z, context=context)
        return z

    def forward_and_log_det(self, z, context=None):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        Args:
          z: Batch in the latent space
          context: Batch of conditions/context

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian
        """
        # Accumulate in z's dtype (as log_prob does) so float64 models are
        # not silently downcast to the default float32 dtype.
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_d = flow(z, context=context)
            log_det += log_d
        return z, log_det

    def inverse(self, x, context=None):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution
          context: Batch of conditions/context

        Returns:
          Batch in the latent space
        """
        for flow in reversed(self.flows):
            x, _ = flow.inverse(x, context=context)
        return x

    def inverse_and_log_det(self, x, context=None):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        Args:
          x: Batch in the space of the target distribution
          context: Batch of conditions/context

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian
        """
        # Accumulate in x's dtype to avoid silent downcasting.
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            x, log_d = flow.inverse(x, context=context)
            log_det += log_d
        return x, log_det

    def sample(self, num_samples=1, context=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          context: Batch of conditions/context

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples, context=context)
        for flow in self.flows:
            z, log_det = flow(z, context=context)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x, context=None):
        """Get log probability for batch

        Args:
          x: Batch
          context: Batch of conditions/context

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z, context=context)
            log_q += log_det
        log_q += self.q0.log_prob(z, context=context)
        return log_q

    def forward_kld(self, x, context=None):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Computed as the negative batch mean of ``self.log_prob``, which keeps
        the accumulation in ``x``'s dtype.

        Args:
          x: Batch sampled from target distribution
          context: Batch of conditions/context

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, context=context))

    def reverse_kld(self, num_samples=1, context=None, beta=1.0, score_fn=True):
        """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          num_samples: Number of samples to draw from base distribution
          context: Batch of conditions/context
          beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
          score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

        Returns:
          Estimate of the reverse KL divergence averaged over latent samples
        """
        z, log_q_ = self.q0(num_samples, context=context)
        log_q = torch.zeros_like(log_q_)
        log_q += log_q_
        for flow in self.flows:
            z, log_det = flow(z, context=context)
            log_q -= log_det
        if not score_fn:
            # Recompute log q by mapping the samples back through the inverse
            # flows while utils.set_requires_grad(self, False) is in effect.
            z_ = z
            log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
            utils.set_requires_grad(self, False)
            try:
                for flow in reversed(self.flows):
                    z_, log_det = flow.inverse(z_, context=context)
                    log_q += log_det
                log_q += self.q0.log_prob(z_, context=context)
            finally:
                # Always restore the flag, even if a flow's inverse raises.
                utils.set_requires_grad(self, True)
        # self.p is the target distribution; presumably set by
        # NormalizingFlow.__init__ (not shown here) -- TODO confirm.
        log_p = self.p.log_prob(z, context=context)
        return torch.mean(log_q) - beta * torch.mean(log_p)

forward(z, context=None)

Transforms latent variable z to the flow variable x

Parameters:

Name Type Description Default
z

Batch in the latent space

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the space of the target distribution

Source code in normflows/core.py
222
223
224
225
226
227
228
229
230
231
232
233
234
def forward(self, z, context=None):
    """Map a batch of latent samples to the target space.

    Every flow layer is applied in registration order with the same
    ``context``; the per-layer log-determinants are discarded.

    Args:
      z: Batch in the latent space
      context: Batch of conditions/context, forwarded to every layer

    Returns:
      Batch in the space of the target distribution
    """
    current = z
    for layer in self.flows:
        current, _unused_log_det = layer(current, context=context)
    return current

forward_and_log_det(z, context=None)

Transforms latent variable z to the flow variable x and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
z

Batch in the latent space

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the space of the target distribution,

log determinant of the Jacobian

Source code in normflows/core.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def forward_and_log_det(self, z, context=None):
    """Transforms latent variable z to the flow variable x and
    computes log determinant of the Jacobian

    Args:
      z: Batch in the latent space
      context: Batch of conditions/context

    Returns:
      Batch in the space of the target distribution,
      log determinant of the Jacobian
    """
    # Accumulate in z's dtype (as log_prob does) so float64 models are not
    # silently downcast to the default float32 dtype.
    log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
    for flow in self.flows:
        z, log_d = flow(z, context=context)
        log_det += log_d
    return z, log_det

forward_kld(x, context=None)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def forward_kld(self, x, context=None):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Computed as the negative batch mean of ``self.log_prob``. Delegating
    keeps the accumulation in ``x``'s dtype rather than the default float32.

    Args:
      x: Batch sampled from target distribution
      context: Batch of conditions/context

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    return -torch.mean(self.log_prob(x, context=context))

inverse(x, context=None)

Transforms flow variable x to the latent variable z

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the latent space

Source code in normflows/core.py
254
255
256
257
258
259
260
261
262
263
264
265
266
def inverse(self, x, context=None):
    """Map a batch from the target space back to the latent space.

    The flow layers are inverted from last to first, each receiving the
    same ``context``; log-determinants are discarded.

    Args:
      x: Batch in the space of the target distribution
      context: Batch of conditions/context

    Returns:
      Batch in the latent space
    """
    for layer in reversed(self.flows):
        x, _unused_log_det = layer.inverse(x, context=context)
    return x

inverse_and_log_det(x, context=None)

Transforms flow variable x to the latent variable z and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the latent space, log determinant of the

Jacobian

Source code in normflows/core.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def inverse_and_log_det(self, x, context=None):
    """Transforms flow variable x to the latent variable z and
    computes log determinant of the Jacobian

    Args:
      x: Batch in the space of the target distribution
      context: Batch of conditions/context

    Returns:
      Batch in the latent space, log determinant of the
      Jacobian
    """
    # Accumulate in x's dtype (as log_prob does) to avoid silent downcasting.
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        x, log_d = flow.inverse(x, context=context)
        log_det += log_d
    return x, log_det

log_prob(x, context=None)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
context

Batch of conditions/context

None

Returns:

Type Description

log probability

Source code in normflows/core.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def log_prob(self, x, context=None):
    """Evaluate the conditional model log-density of each sample.

    The flows are inverted from last to first with the given ``context``,
    their inverse log-determinants are summed, and the base density of the
    resulting latent is added.

    Args:
      x: Batch
      context: Batch of conditions/context

    Returns:
      log probability, one entry per sample
    """
    latent = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for layer in reversed(self.flows):
        latent, layer_log_det = layer.inverse(latent, context=context)
        total += layer_log_det
    total += self.q0.log_prob(latent, context=context)
    return total

reverse_kld(num_samples=1, context=None, beta=1.0, score_fn=True)

Estimates reverse KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from base distribution

1
context

Batch of conditions/context

None
beta

Annealing parameter, see arXiv 1505.05770

1.0
score_fn

Flag whether to include score function in gradient, see arXiv 1703.09194

True

Returns:

Type Description

Estimate of the reverse KL divergence averaged over latent samples

Source code in normflows/core.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def reverse_kld(self, num_samples=1, context=None, beta=1.0, score_fn=True):
    """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      num_samples: Number of samples to draw from base distribution
      context: Batch of conditions/context
      beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
      score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

    Returns:
      Estimate of the reverse KL divergence averaged over latent samples
    """
    z, log_q_ = self.q0(num_samples, context=context)
    log_q = torch.zeros_like(log_q_)
    log_q += log_q_
    for flow in self.flows:
        z, log_det = flow(z, context=context)
        log_q -= log_det
    if not score_fn:
        # Recompute log q by mapping the samples back through the inverse
        # flows while utils.set_requires_grad(self, False) is in effect.
        z_ = z
        # Match z_'s dtype so float64 models are not downcast to float32.
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_, context=context)
                log_q += log_det
            log_q += self.q0.log_prob(z_, context=context)
        finally:
            # Always restore the flag, even if a flow's inverse raises.
            utils.set_requires_grad(self, True)
    # self.p is the target distribution; presumably set by the
    # NormalizingFlow constructor (not shown here) -- TODO confirm.
    log_p = self.p.log_prob(z, context=context)
    return torch.mean(log_q) - beta * torch.mean(log_p)

sample(num_samples=1, context=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
context

Batch of conditions/context

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
def sample(self, num_samples=1, context=None):
    """Draw context-conditioned samples with their model log-densities.

    Base samples are pushed through every flow in order, all with the same
    ``context``; each layer's log-determinant is subtracted from the density.

    Args:
      num_samples: Number of samples to draw
      context: Batch of conditions/context

    Returns:
      Samples, log probability
    """
    samples, log_density = self.q0(num_samples, context=context)
    for transform in self.flows:
        samples, log_jac = transform(samples, context=context)
        log_density -= log_jac
    return samples, log_density

MultiscaleFlow

Bases: Module

Normalizing Flow model with multiscale architecture, see RealNVP or Glow paper

Source code in normflows/core.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
class MultiscaleFlow(nn.Module):
    """
    Normalizing Flow model with multiscale architecture, see RealNVP or Glow paper
    """

    def __init__(self, q0, flows, merges, transform=None, class_cond=True):
        """Constructor

        Args:
          q0: List of base distributions, one per level
          flows: List of flows for each level (one list of flow layers per level)
          merges: List of merge/split operations (forward pass must do merge);
            ``merges[i - 1]`` joins level ``i`` into the running tensor
          transform: Initial transformation of inputs; applied last when
            sampling and inverted first when evaluating densities
          class_cond: Flag indicating whether the model has class
            conditional base distributions
        """
        super().__init__()
        self.q0 = nn.ModuleList(q0)
        self.num_levels = len(self.q0)
        self.flows = torch.nn.ModuleList([nn.ModuleList(flow) for flow in flows])
        self.merges = torch.nn.ModuleList(merges)
        self.transform = transform
        self.class_cond = class_cond

    def forward_kld(self, x, y=None):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution
          y: Batch of classes to condition on, if applicable

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, y))

    def forward(self, x, y=None):
        """Get negative log-likelihood for maximum likelihood training

        Args:
          x: Batch of data
          y: Batch of classes to condition on, if applicable

        Returns:
            Negative log-likelihood of the batch
        """
        return -self.log_prob(x, y)

    def forward_and_log_det(self, z):
        """Get observed variable x from list of latent variables z

        Args:
            z: List of latent variables, one per level

        Returns:
            Observed variable x, log determinant of Jacobian
        """
        log_det = torch.zeros(len(z[0]), dtype=z[0].dtype, device=z[0].device)
        for i in range(len(self.q0)):
            if i == 0:
                z_ = z[0]
            else:
                z_, log_det_ = self.merges[i - 1]([z_, z[i]])
                log_det += log_det_
            for flow in self.flows[i]:
                z_, log_det_ = flow(z_)
                log_det += log_det_
        if self.transform is not None:
            z_, log_det_ = self.transform(z_)
            log_det += log_det_
        return z_, log_det

    def inverse_and_log_det(self, x):
        """Get latent variable z from observed variable x

        Args:
            x: Observed variable

        Returns:
            List of latent variables z, log determinant of Jacobian
        """
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        if self.transform is not None:
            x, log_det_ = self.transform.inverse(x)
            log_det += log_det_
        z = [None] * len(self.q0)
        for i in range(len(self.q0) - 1, -1, -1):
            for flow in reversed(self.flows[i]):
                x, log_det_ = flow.inverse(x)
                log_det += log_det_
            if i == 0:
                z[i] = x
            else:
                [x, z[i]], log_det_ = self.merges[i - 1].inverse(x)
                log_det += log_det_
        return z, log_det

    def sample(self, num_samples=1, y=None, temperature=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None
          temperature: Temperature parameter for temp annealed sampling

        Returns:
          Samples, log probability
        """
        if temperature is not None:
            self.set_temperature(temperature)
        try:
            for i in range(len(self.q0)):
                if self.class_cond:
                    z_, log_q_ = self.q0[i](num_samples, y)
                else:
                    z_, log_q_ = self.q0[i](num_samples)
                if i == 0:
                    log_q = log_q_
                    z = z_
                else:
                    log_q += log_q_
                    z, log_det = self.merges[i - 1]([z, z_])
                    log_q -= log_det
                for flow in self.flows[i]:
                    z, log_det = flow(z)
                    log_q -= log_det
            if self.transform is not None:
                z, log_det = self.transform(z)
                log_q -= log_det
        finally:
            # Restore the base temperatures even if sampling raises, so a
            # failed call does not leave the model in an annealed state.
            if temperature is not None:
                self.reset_temperature()
        return z, log_q

    def log_prob(self, x, y=None):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x. Must be passed in if `class_cond` is True.

        Returns:
          log probability
        """
        log_q = 0
        z = x
        if self.transform is not None:
            z, log_det = self.transform.inverse(z)
            log_q += log_det
        for i in range(len(self.q0) - 1, -1, -1):
            for flow in reversed(self.flows[i]):
                z, log_det = flow.inverse(z)
                log_q += log_det
            if i > 0:
                # Split off this level's latent; the rest continues downward.
                [z, z_], log_det = self.merges[i - 1].inverse(z)
                log_q += log_det
            else:
                z_ = z
            if self.class_cond:
                log_q += self.q0[i].log_prob(z_, y)
            else:
                log_q += self.q0[i].log_prob(z_)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Note: ``torch.load`` is pickle-based; only load trusted files.

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))

    def set_temperature(self, temperature):
        """Set temperature for temperature annealed sampling

        Args:
          temperature: Temperature parameter

        Raises:
          NotImplementedError: If a base distribution has no ``temperature``
            attribute
        """
        for q0 in self.q0:
            if hasattr(q0, "temperature"):
                q0.temperature = temperature
            else:
                raise NotImplementedError(
                    "One base function does not "
                    "support temperature annealed sampling"
                )

    def reset_temperature(self):
        """
        Set temperature values of base distributions back to None
        """
        self.set_temperature(None)

__init__(q0, flows, merges, transform=None, class_cond=True)

Constructor

Args:

q0: List of base distributions
flows: List of flows for each level
merges: List of merge/split operations (forward pass must do merge)
transform: Initial transformation of inputs
class_cond: Flag indicating whether the model has class conditional base distributions

Source code in normflows/core.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def __init__(self, q0, flows, merges, transform=None, class_cond=True):
    """Register the per-level components of the multiscale flow.

    Args:
      q0: List of base distributions, one per level
      flows: List of flows for each level (one list of flow layers per level)
      merges: List of merge/split operations (forward pass must do merge)
      transform: Initial transformation of inputs
      class_cond: Flag indicating whether the model has class conditional
        base distributions
    """
    super().__init__()
    self.q0 = nn.ModuleList(q0)
    self.num_levels = len(self.q0)
    level_flows = [nn.ModuleList(level) for level in flows]
    self.flows = torch.nn.ModuleList(level_flows)
    self.merges = torch.nn.ModuleList(merges)
    self.transform = transform
    self.class_cond = class_cond

forward(x, y=None)

Get negative log-likelihood for maximum likelihood training

Parameters:

Name Type Description Default
x

Batch of data

required
y

Batch of classes to condition on, if applicable

None

Returns:

Type Description

Negative log-likelihood of the batch

Source code in normflows/core.py
492
493
494
495
496
497
498
499
500
501
502
def forward(self, x, y=None):
    """Return the per-sample negative log-likelihood (training loss).

    Args:
      x: Batch of data
      y: Batch of classes to condition on, if applicable

    Returns:
        Negative log-likelihood of the batch
    """
    log_likelihood = self.log_prob(x, y)
    return -log_likelihood

forward_and_log_det(z)

Get observed variable x from list of latent variables z

Parameters:

Name Type Description Default
z

List of latent variables

required

Returns:

Type Description

Observed variable x, log determinant of Jacobian

Source code in normflows/core.py
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
def forward_and_log_det(self, z):
    """Assemble the observed variable from its per-level latents.

    Level 0 starts the running tensor. Each later level ``i`` is first merged
    in with ``merges[i - 1]``. The level's flows are then applied, and the
    optional final transform comes last. All log-determinants are summed.

    Args:
        z: List of latent variables, one per level

    Returns:
        Observed variable x, log determinant of Jacobian
    """
    first = z[0]
    total = torch.zeros(len(first), dtype=first.dtype, device=first.device)
    for level in range(len(self.q0)):
        if level > 0:
            x, step_log_det = self.merges[level - 1]([x, z[level]])
            total += step_log_det
        else:
            x = first
        for layer in self.flows[level]:
            x, step_log_det = layer(x)
            total += step_log_det
    if self.transform is not None:
        x, step_log_det = self.transform(x)
        total += step_log_det
    return x, total

forward_kld(x, y=None)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required
y

Batch of classes to condition on, if applicable

None

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
480
481
482
483
484
485
486
487
488
489
490
def forward_kld(self, x, y=None):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Computed as the negative batch mean of ``self.log_prob(x, y)``.

    Args:
      x: Batch sampled from target distribution
      y: Batch of classes to condition on, if applicable

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    per_sample = self.log_prob(x, y)
    return -torch.mean(per_sample)

inverse_and_log_det(x)

Get latent variable z from observed variable x

Parameters:

Name Type Description Default
x

Observed variable

required

Returns:

Type Description

List of latent variables z, log determinant of Jacobian

Source code in normflows/core.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
def inverse_and_log_det(self, x):
    """Map the observed variable x back to the list of per-level latents

    Undoes ``self.transform`` (if set), then walks the levels from last to
    first. Each level's flows are inverted and ``self.merges[level - 1]``
    splits off that level's latent. Whatever remains after level 0 is
    ``z[0]``.

    Args:
        x: Observed variable

    Returns:
        List of latent variables z, log determinant of Jacobian
    """
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    if self.transform is not None:
        x, transform_ld = self.transform.inverse(x)
        log_det += transform_ld
    n_levels = len(self.q0)
    z = [None] * n_levels
    h = x
    for level in reversed(range(n_levels)):
        for flow in reversed(self.flows[level]):
            h, flow_ld = flow.inverse(h)
            log_det += flow_ld
        if level > 0:
            [h, z[level]], merge_ld = self.merges[level - 1].inverse(h)
            log_det += merge_ld
        else:
            z[0] = h
    return z, log_det

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
626
627
628
629
630
631
632
def load(self, path):
    """Restore the model's parameters from a file written by ``save``

    Args:
      path: Path including filename where to load model from
    """
    state_dict = torch.load(path)
    self.load_state_dict(state_dict)

log_prob(x, y=None)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
y

Classes of x. Must be passed in if class_cond is True.

None

Returns:

Type Description

log probability

Source code in normflows/core.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
def log_prob(self, x, y=None):
    """Log density of a batch under the multiscale flow

    Inverts ``self.transform`` (if set), then processes levels from last to
    first. For each level it inverts the flows, splits off that level's
    latent with ``self.merges[level - 1]`` (level 0 keeps the remainder), and
    adds the base log density of the latent.

    Args:
      x: Batch
      y: Classes of x. Must be passed in if `class_cond` is True.

    Returns:
      log probability
    """
    log_q = 0
    h = x
    if self.transform is not None:
        h, transform_ld = self.transform.inverse(h)
        log_q += transform_ld
    for level in reversed(range(len(self.q0))):
        for flow in reversed(self.flows[level]):
            h, flow_ld = flow.inverse(h)
            log_q += flow_ld
        if level > 0:
            [h, z_level], merge_ld = self.merges[level - 1].inverse(h)
            log_q += merge_ld
        else:
            z_level = h
        base = self.q0[level]
        if self.class_cond:
            log_q += base.log_prob(z_level, y)
        else:
            log_q += base.log_prob(z_level)
    return log_q

reset_temperature()

Set temperature values of base distributions back to None

Source code in normflows/core.py
649
650
651
652
653
def reset_temperature(self):
    """Undo temperature annealing.

    Delegates to ``set_temperature`` with ``None`` for every base
    distribution.
    """
    no_temperature = None
    self.set_temperature(no_temperature)

sample(num_samples=1, y=None, temperature=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
y

Classes to sample from, will be sampled uniformly if None

None
temperature

Temperature parameter for temp annealed sampling

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
def sample(self, num_samples=1, y=None, temperature=None):
    """Samples from flow-based approximate distribution

    Each level's base distribution is sampled. Every level after the first
    is merged into the running sample via ``self.merges[i - 1]``, and then
    that level's flows are applied. ``self.transform`` (if set) is applied
    last.

    Args:
      num_samples: Number of samples to draw
      y: Classes to sample from, will be sampled uniformly if None (only
        forwarded to the base distributions when ``self.class_cond``)
      temperature: Temperature parameter for temp annealed sampling. If
        given, it is set before sampling and reset afterwards, also when
        sampling raises an exception.

    Returns:
      Samples, log probability
    """
    if temperature is not None:
        self.set_temperature(temperature)
    try:
        for i in range(len(self.q0)):
            if self.class_cond:
                z_, log_q_ = self.q0[i](num_samples, y)
            else:
                z_, log_q_ = self.q0[i](num_samples)
            if i == 0:
                log_q = log_q_
                z = z_
            else:
                log_q += log_q_
                z, log_det = self.merges[i - 1]([z, z_])
                log_q -= log_det
            for flow in self.flows[i]:
                z, log_det = flow(z)
                log_q -= log_det
        if self.transform is not None:
            z, log_det = self.transform(z)
            log_q -= log_det
    finally:
        # Restore the base distributions even if sampling failed, so an
        # exception does not leave the model stuck at the annealing
        # temperature for all later calls.
        if temperature is not None:
            self.reset_temperature()
    return z, log_q

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
618
619
620
621
622
623
624
def save(self, path):
    """Write the model's state dict to disk with ``torch.save``

    Args:
      path: Path including filename where to save model
    """
    state_dict = self.state_dict()
    torch.save(state_dict, path)

set_temperature(temperature)

Set temperature for temperature-annealed sampling

Parameters:

Name Type Description Default
temperature

Temperature parameter

required
Source code in normflows/core.py
634
635
636
637
638
639
640
641
642
643
644
645
646
647
def set_temperature(self, temperature):
    """Set temperature for temperature-annealed sampling

    All base distributions are checked before any is modified. If one of
    them lacks a ``temperature`` attribute, nothing is changed and
    ``NotImplementedError`` is raised, so the model is never left with
    only some base distributions annealed.

    Args:
      temperature: Temperature parameter

    Raises:
      NotImplementedError: If a base distribution has no ``temperature``
        attribute
    """
    if not all(hasattr(q0, "temperature") for q0 in self.q0):
        raise NotImplementedError(
            "One base function does not "
            "support temperature annealed sampling"
        )
    for q0 in self.q0:
        q0.temperature = temperature

NormalizingFlow

Bases: Module

Normalizing Flow model to approximate target distribution

Source code in normflows/core.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class NormalizingFlow(nn.Module):
    """
    Normalizing Flow model to approximate target distribution
    """

    def __init__(self, q0, flows, p=None):
        """Constructor

        Args:
          q0: Base distribution
          flows: List of flows
          p: Target distribution
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)
        self.p = p

    def forward(self, z):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z)
        return z

    def forward_and_log_det(self, z):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        The accumulator takes the dtype and device of ``z``, so double
        precision inputs are not truncated to float32.

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian
        """
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_d = flow(z)
            log_det += log_d
        return z, log_det

    def inverse(self, x):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space
        """
        for flow in reversed(self.flows):
            x, _ = flow.inverse(x)
        return x

    def inverse_and_log_det(self, x):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        The accumulator takes the dtype and device of ``x``, so double
        precision inputs are not truncated to float32.

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian
        """
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            x, log_d = flow.inverse(x)
            log_det += log_d
        return x, log_det

    def forward_kld(self, x):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return -torch.mean(log_q)

    def _log_prob_frozen(self, z):
        """Evaluate log q(z) while ``utils.set_requires_grad(self, False)`` is in effect

        Used by the estimators that must not differentiate log q through
        the model parameters (``reverse_kld`` with ``score_fn=False`` and
        ``reverse_alpha_div`` with ``dreg=True``). ``requires_grad`` is
        switched back on even if the evaluation raises.

        NOTE(review): restoring with ``True`` is unconditional. If
        ``set_requires_grad`` touches every parameter, parameters the
        caller had frozen beforehand are unfrozen; confirm this is intended.

        Args:
          z: Batch in the space of the target distribution

        Returns:
          log probability of each element of the batch
        """
        utils.set_requires_grad(self, False)
        try:
            log_q = torch.zeros(len(z), dtype=z.dtype, device=z.device)
            for flow in reversed(self.flows):
                z, log_det = flow.inverse(z)
                log_q += log_det
            log_q += self.q0.log_prob(z)
        finally:
            utils.set_requires_grad(self, True)
        return log_q

    def reverse_kld(self, num_samples=1, beta=1.0, score_fn=True):
        """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          num_samples: Number of samples to draw from base distribution
          beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
          score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

        Returns:
          Estimate of the reverse KL divergence averaged over latent samples
        """
        z, log_q_ = self.q0(num_samples)
        log_q = torch.zeros_like(log_q_)
        log_q += log_q_
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        if not score_fn:
            # Recompute log q(z) with the parameters frozen to drop the
            # score-function term from the gradient
            log_q = self._log_prob_frozen(z)
        log_p = self.p.log_prob(z)
        return torch.mean(log_q) - beta * torch.mean(log_p)

    def reverse_alpha_div(self, num_samples=1, alpha=1, dreg=False):
        """Alpha divergence when sampling from q

        Args:
          num_samples: Number of samples to draw
          alpha: Order of the alpha divergence
          dreg: Flag whether to use Double Reparametrized Gradient estimator, see [arXiv 1810.04152](https://arxiv.org/abs/1810.04152)

        Returns:
          Alpha divergence
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.p.log_prob(z)
        if dreg:
            w_const = torch.exp(log_p - log_q).detach()
            log_q = self._log_prob_frozen(z)
            # log w directly, instead of torch.log(torch.exp(...)), which
            # overflows to inf (or underflows to -inf) for large |log_p - log_q|
            log_w = log_p - log_q
            w_alpha = w_const**alpha
            w_alpha = w_alpha / torch.mean(w_alpha)
            weights = (1 - alpha) * w_alpha + alpha * w_alpha**2
            loss = -alpha * torch.mean(weights * log_w)
        else:
            # NOTE: with the default alpha == 1, np.sign(0) == 0 and this
            # branch evaluates to 0
            loss = np.sign(alpha - 1) * torch.logsumexp(alpha * (log_p - log_q), 0)
        return loss

    def sample(self, num_samples=1):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x):
        """Get log probability for batch

        Args:
          x: Batch

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))

__init__(q0, flows, p=None)

Constructor

Parameters:

Name Type Description Default
q0

Base distribution

required
flows

List of flows

required
p

Target distribution

None
Source code in normflows/core.py
14
15
16
17
18
19
20
21
22
23
24
25
def __init__(self, q0, flows, p=None):
    """Build the flow model from its components

    Submodules are registered in the order q0, flows, p.

    Args:
      q0: Base distribution
      flows: List of flows, wrapped in an ``nn.ModuleList``
      p: Target distribution (optional; needed by the reverse objectives)
    """
    super().__init__()
    base = q0
    self.q0 = base
    flow_list = nn.ModuleList(flows)
    self.flows = flow_list
    target = p
    self.p = target

forward(z)

Transforms latent variable z to the flow variable x

Parameters:

Name Type Description Default
z

Batch in the latent space

required

Returns:

Type Description

Batch in the space of the target distribution

Source code in normflows/core.py
27
28
29
30
31
32
33
34
35
36
37
38
def forward(self, z):
    """Push a latent batch through every flow, in order

    The flows' log-determinants are discarded.

    Args:
      z: Batch in the latent space

    Returns:
      Batch in the space of the target distribution
    """
    x = z
    for flow in self.flows:
        x, _discarded = flow(x)
    return x

forward_and_log_det(z)

Transforms latent variable z to the flow variable x and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
z

Batch in the latent space

required

Returns:

Type Description

Batch in the space of the target distribution,

log determinant of the Jacobian

Source code in normflows/core.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def forward_and_log_det(self, z):
    """Transforms latent variable z to the flow variable x and
    computes log determinant of the Jacobian

    The accumulator takes the dtype and device of ``z`` (matching
    ``log_prob``), so double precision inputs are not truncated to float32.

    Args:
      z: Batch in the latent space

    Returns:
      Batch in the space of the target distribution,
      log determinant of the Jacobian
    """
    log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
    for flow in self.flows:
        z, log_d = flow(z)
        log_det += log_d
    return z, log_det

forward_kld(x)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def forward_kld(self, x):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    The log-density accumulator takes the dtype and device of ``x``
    (matching ``log_prob``), so double precision data is not truncated to
    float32.

    Args:
      x: Batch sampled from target distribution

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    z = x
    for flow in reversed(self.flows):
        z, log_det = flow.inverse(z)
        log_q += log_det
    log_q += self.q0.log_prob(z)
    return -torch.mean(log_q)

inverse(x)

Transforms flow variable x to the latent variable z

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required

Returns:

Type Description

Batch in the latent space

Source code in normflows/core.py
57
58
59
60
61
62
63
64
65
66
67
68
def inverse(self, x):
    """Pull a batch back through the flows, last flow first

    The log-determinants are discarded.

    Args:
      x: Batch in the space of the target distribution

    Returns:
      Batch in the latent space
    """
    z = x
    for flow in reversed(self.flows):
        z, _discarded = flow.inverse(z)
    return z

inverse_and_log_det(x)

Transforms flow variable x to the latent variable z and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required

Returns:

Type Description

Batch in the latent space, log determinant of the

Jacobian

Source code in normflows/core.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def inverse_and_log_det(self, x):
    """Transforms flow variable x to the latent variable z and
    computes log determinant of the Jacobian

    The accumulator takes the dtype and device of ``x`` (matching
    ``log_prob``), so double precision inputs are not truncated to float32.

    Args:
      x: Batch in the space of the target distribution

    Returns:
      Batch in the latent space, log determinant of the
      Jacobian
    """
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        x, log_d = flow.inverse(x)
        log_det += log_d
    return x, log_det

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
207
208
209
210
211
212
213
def load(self, path):
    """Restore the model's parameters from a file written by ``save``

    Args:
      path: Path including filename where to load model from
    """
    state_dict = torch.load(path)
    self.load_state_dict(state_dict)

log_prob(x)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required

Returns:

Type Description

log probability

Source code in normflows/core.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def log_prob(self, x):
    """Log density of each batch element under the flow model

    The log-determinants of the inverted flows (last flow first) are summed
    and the base distribution's log density of the resulting latent is
    added.

    Args:
      x: Batch

    Returns:
      log probability
    """
    h = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        h, ld = flow.inverse(h)
        total += ld
    total += self.q0.log_prob(h)
    return total

reverse_alpha_div(num_samples=1, alpha=1, dreg=False)

Alpha divergence when sampling from q

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
dreg

Flag whether to use Double Reparametrized Gradient estimator, see arXiv 1810.04152

False

Returns:

Type Description

Alpha divergence

Source code in normflows/core.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def reverse_alpha_div(self, num_samples=1, alpha=1, dreg=False):
    """Alpha divergence when sampling from q

    Args:
      num_samples: Number of samples to draw
      alpha: Order of the alpha divergence
      dreg: Flag whether to use Double Reparametrized Gradient estimator, see [arXiv 1810.04152](https://arxiv.org/abs/1810.04152)

    Returns:
      Alpha divergence
    """
    z, log_q = self.q0(num_samples)
    for flow in self.flows:
        z, log_det = flow(z)
        log_q -= log_det
    log_p = self.p.log_prob(z)
    if dreg:
        w_const = torch.exp(log_p - log_q).detach()
        # Re-evaluate log q(z) with the model's gradients switched off; the
        # accumulator follows z's dtype, and requires_grad is restored even
        # if the re-evaluation raises.
        z_ = z
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_)
                log_q += log_det
            log_q += self.q0.log_prob(z_)
        finally:
            utils.set_requires_grad(self, True)
        # log w directly, instead of torch.log(torch.exp(...)), which
        # overflows to inf (or underflows to -inf) for large |log_p - log_q|
        log_w = log_p - log_q
        w_alpha = w_const**alpha
        w_alpha = w_alpha / torch.mean(w_alpha)
        weights = (1 - alpha) * w_alpha + alpha * w_alpha**2
        loss = -alpha * torch.mean(weights * log_w)
    else:
        # NOTE: with the default alpha == 1, np.sign(0) == 0 and this
        # branch evaluates to 0
        loss = np.sign(alpha - 1) * torch.logsumexp(alpha * (log_p - log_q), 0)
    return loss

reverse_kld(num_samples=1, beta=1.0, score_fn=True)

Estimates reverse KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from base distribution

1
beta

Annealing parameter, see arXiv 1505.05770

1.0
score_fn

Flag whether to include score function in gradient, see arXiv 1703.09194

True

Returns:

Type Description

Estimate of the reverse KL divergence averaged over latent samples

Source code in normflows/core.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def reverse_kld(self, num_samples=1, beta=1.0, score_fn=True):
    """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      num_samples: Number of samples to draw from base distribution
      beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
      score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

    Returns:
      Estimate of the reverse KL divergence averaged over latent samples
    """
    z, log_q_ = self.q0(num_samples)
    log_q = torch.zeros_like(log_q_)
    log_q += log_q_
    for flow in self.flows:
        z, log_det = flow(z)
        log_q -= log_det
    if not score_fn:
        # Recompute log q(z) with the model's gradients switched off. The
        # accumulator follows z's dtype, and requires_grad is restored even
        # if the re-evaluation raises.
        z_ = z
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_)
                log_q += log_det
            log_q += self.q0.log_prob(z_)
        finally:
            utils.set_requires_grad(self, True)
    log_p = self.p.log_prob(z)
    return torch.mean(log_q) - beta * torch.mean(log_p)

sample(num_samples=1)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def sample(self, num_samples=1):
    """Draw samples from the base distribution and push them through the flows

    Each flow's log-determinant is subtracted from the base log density.

    Args:
      num_samples: Number of samples to draw

    Returns:
      Samples, log probability
    """
    samples, log_density = self.q0(num_samples)
    for flow in self.flows:
        samples, ld = flow(samples)
        log_density -= ld
    return samples, log_density

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
199
200
201
202
203
204
205
def save(self, path):
    """Write the model's state dict to disk with ``torch.save``

    Args:
      path: Path including filename where to save model
    """
    state_dict = self.state_dict()
    torch.save(state_dict, path)

NormalizingFlowVAE

Bases: Module

VAE using normalizing flows to express approximate distribution

Source code in normflows/core.py
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
class NormalizingFlowVAE(nn.Module):
    """
    VAE using normalizing flows to express approximate distribution
    """

    def __init__(self, prior, q0=None, flows=None, decoder=None):
        """Constructor of normalizing flow model

        Args:
          prior: Prior distribution of the VAE, e.g. Gaussian
          decoder: Optional decoder
          flows: Flows to transform output of base encoder (None for no flows)
          q0: Base Encoder; if None, a new ``distributions.Dirac()`` is
            created for this instance
        """
        super().__init__()
        if q0 is None:
            # Build the default encoder per instance. As a default argument,
            # ``distributions.Dirac()`` was created once at definition time
            # and the same module object was shared by every VAE.
            q0 = distributions.Dirac()
        self.prior = prior
        self.decoder = decoder
        self.flows = nn.ModuleList(flows)
        self.q0 = q0

    def forward(self, x, num_samples=1):
        """Takes data batch, samples num_samples for each data point from base distribution

        Args:
          x: data batch
          num_samples: number of samples to draw for each data point

        Returns:
          latent variables for each batch and sample, log_q, and log_p
        """
        z, log_q = self.q0(x, num_samples=num_samples)
        # Flatten batch and sample dim
        z = z.view(-1, *z.size()[2:])
        log_q = log_q.view(-1, *log_q.size()[2:])
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.prior.log_prob(z)
        if self.decoder is not None:
            log_p += self.decoder.log_prob(x, z)
        # Separate batch and sample dimension again
        z = z.view(-1, num_samples, *z.size()[1:])
        log_q = log_q.view(-1, num_samples, *log_q.size()[1:])
        log_p = log_p.view(-1, num_samples, *log_p.size()[1:])
        return z, log_q, log_p

__init__(prior, q0=distributions.Dirac(), flows=None, decoder=None)

Constructor of normalizing flow model

Parameters:

Name Type Description Default
prior

Prior distribution of the VAE, e.g. Gaussian

required
decoder

Optional decoder

None
flows

Flows to transform output of base encoder

None
q0

Base Encoder

Dirac()
Source code in normflows/core.py
661
662
663
664
665
666
667
668
669
670
671
672
673
674
def __init__(self, prior, q0=None, flows=None, decoder=None):
    """Constructor of normalizing flow model

    Args:
      prior: Prior distribution of the VAE, e.g. Gaussian
      decoder: Optional decoder
      flows: Flows to transform output of base encoder (None for no flows)
      q0: Base Encoder; if None, a new ``distributions.Dirac()`` is created
        for this instance
    """
    super().__init__()
    if q0 is None:
        # Build the default encoder per instance. As a default argument,
        # ``distributions.Dirac()`` was created once at definition time and
        # the same module object was shared by every VAE.
        q0 = distributions.Dirac()
    self.prior = prior
    self.decoder = decoder
    self.flows = nn.ModuleList(flows)
    self.q0 = q0

forward(x, num_samples=1)

Takes data batch, samples num_samples for each data point from base distribution

Parameters:

Name Type Description Default
x

data batch

required
num_samples

number of samples to draw for each data point

1

Returns:

Type Description

latent variables for each batch and sample, log_q, and log_p

Source code in normflows/core.py
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
def forward(self, x, num_samples=1):
    """Encode a data batch, draw ``num_samples`` latents per point and flow them

    Args:
      x: data batch
      num_samples: number of samples to draw for each data point

    Returns:
      latent variables for each batch and sample, log_q, and log_p
    """
    z, log_q = self.q0(x, num_samples=num_samples)
    # Merge the two leading (batch, sample) axes so the flows and the prior
    # see a plain batch of batch * num_samples points
    flat_z = z.view(-1, *z.shape[2:])
    flat_log_q = log_q.view(-1, *log_q.shape[2:])
    for flow in self.flows:
        flat_z, log_det = flow(flat_z)
        flat_log_q -= log_det
    log_p = self.prior.log_prob(flat_z)
    if self.decoder is not None:
        log_p += self.decoder.log_prob(x, flat_z)

    def unflatten(t):
        # Split the merged leading axis back into (batch, num_samples)
        return t.view(-1, num_samples, *t.shape[1:])

    return unflatten(flat_z), unflatten(flat_log_q), unflatten(log_p)

core

ClassCondFlow

Bases: Module

Class conditional normalizing Flow model, providing the class to be conditioned on only to the base distribution, as done e.g. in Glow

Source code in normflows/core.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
class ClassCondFlow(nn.Module):
    """
    Class-conditional normalizing flow model. The class label is passed
    only to the base distribution; the flow layers themselves are
    unconditional, as done e.g. in [Glow](https://arxiv.org/abs/1807.03039)
    """

    def __init__(self, q0, flows):
        """Constructor

        Args:
          q0: Class-conditional base distribution
          flows: List of flows, wrapped in an ``nn.ModuleList``
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)

    def _conditional_log_q(self, x, y):
        """Pull x back through the flows (last flow first) and add the base
        distribution's log density conditioned on y."""
        h = x
        total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            h, ld = flow.inverse(h)
            total += ld
        total += self.q0.log_prob(h, y)
        return total

    def forward_kld(self, x, y):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution
          y: Classes of x, passed to the base distribution

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self._conditional_log_q(x, y))

    def sample(self, num_samples=1, y=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None

        Returns:
          Samples, log probability
        """
        samples, log_density = self.q0(num_samples, y)
        for flow in self.flows:
            samples, ld = flow(samples)
            log_density -= ld
        return samples, log_density

    def log_prob(self, x, y):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x

        Returns:
          log probability
        """
        return self._conditional_log_q(x, y)

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        state_dict = self.state_dict()
        torch.save(state_dict, path)

    def load(self, path):
        """Load model from state dict

        Args:
          path: Path including filename where to load model from
        """
        state_dict = torch.load(path)
        self.load_state_dict(state_dict)

__init__(q0, flows)

Constructor

Parameters:

Name Type Description Default
q0

Base distribution

required
flows

List of flows

required
Source code in normflows/core.py
376
377
378
379
380
381
382
383
384
385
def __init__(self, q0, flows):
    """Build the class-conditional flow from its components

    Submodules are registered in the order q0, flows.

    Args:
      q0: Class-conditional base distribution
      flows: List of flows, wrapped in an ``nn.ModuleList``
    """
    super().__init__()
    base = q0
    self.q0 = base
    flow_list = nn.ModuleList(flows)
    self.flows = flow_list

forward_kld(x, y)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def forward_kld(self, x, y):
    """Forward KL divergence estimate for labelled data, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      x: Batch sampled from target distribution
      y: Classes of x, passed to the base distribution's ``log_prob``

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    h = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        h, ld = flow.inverse(h)
        total += ld
    total += self.q0.log_prob(h, y)
    return -torch.mean(total)

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
446
447
448
449
450
451
452
def load(self, path):
    """Restore the model's parameters from a file written by ``save``

    Args:
      path: Path including filename where to load model from
    """
    state_dict = torch.load(path)
    self.load_state_dict(state_dict)

log_prob(x, y)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
y

Classes of x

required

Returns:

Type Description

log probability

Source code in normflows/core.py
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
def log_prob(self, x, y):
    """Evaluate the log density of a batch under the class-conditional flow.

    Each sample is inverted through the flows (last layer first), the
    log-determinants are accumulated, and the base distribution's
    class-conditional log density of the resulting latent is added.

    Args:
      x: Batch
      y: Classes of x, handed to the base distribution

    Returns:
      log probability, one entry per element of the batch
    """
    latent = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for layer in reversed(self.flows):
        latent, ld = layer.inverse(latent)
        total += ld
    total += self.q0.log_prob(latent, y)
    return total

sample(num_samples=1, y=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
y

Classes to sample from, will be sampled uniformly if None

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def sample(self, num_samples=1, y=None):
    """Draw samples by pushing base samples through every flow layer.

    Args:
      num_samples: Number of samples to draw
      y: Classes to sample from, will be sampled uniformly if None

    Returns:
      Samples, log probability
    """
    samples, log_q = self.q0(num_samples, y)
    for layer in self.flows:
        samples, ld = layer(samples)
        # Change of variables: each forward layer subtracts its log-det
        log_q -= ld
    return samples, log_q

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
438
439
440
441
442
443
444
def save(self, path):
    """Save state dict of model

    Only the state dict (parameters and buffers) is written, not the
    module object itself.

    Args:
      path: Path including filename where to save model
    """
    torch.save(self.state_dict(), path)

ConditionalNormalizingFlow

Bases: NormalizingFlow

Conditional normalizing flow model, providing condition, which is also called context, to both the base distribution and the flow layers

Source code in normflows/core.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class ConditionalNormalizingFlow(NormalizingFlow):
    """
    Conditional normalizing flow model, providing condition,
    which is also called context, to both the base distribution
    and the flow layers
    """

    def forward(self, z, context=None):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space
          context: Batch of conditions/context

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z, context=context)
        return z

    def forward_and_log_det(self, z, context=None):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        Args:
          z: Batch in the latent space
          context: Batch of conditions/context

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian
        """
        # Match the input dtype so float64 log-dets are not silently
        # downcast by the in-place accumulation below
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_d = flow(z, context=context)
            log_det += log_d
        return z, log_det

    def inverse(self, x, context=None):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution
          context: Batch of conditions/context

        Returns:
          Batch in the latent space
        """
        for flow in reversed(self.flows):
            x, _ = flow.inverse(x, context=context)
        return x

    def inverse_and_log_det(self, x, context=None):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        Args:
          x: Batch in the space of the target distribution
          context: Batch of conditions/context

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian
        """
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            x, log_d = flow.inverse(x, context=context)
            log_det += log_d
        return x, log_det

    def sample(self, num_samples=1, context=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          context: Batch of conditions/context

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples, context=context)
        for flow in self.flows:
            z, log_det = flow(z, context=context)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x, context=None):
        """Get log probability for batch

        Args:
          x: Batch
          context: Batch of conditions/context

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z, context=context)
            log_q += log_det
        log_q += self.q0.log_prob(z, context=context)
        return log_q

    def forward_kld(self, x, context=None):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Computed as the negative mean of ``log_prob(x, context)``.

        Args:
          x: Batch sampled from target distribution
          context: Batch of conditions/context

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, context=context))

    def reverse_kld(self, num_samples=1, context=None, beta=1.0, score_fn=True):
        """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          num_samples: Number of samples to draw from base distribution
          context: Batch of conditions/context
          beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
          score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

        Returns:
          Estimate of the reverse KL divergence averaged over latent samples
        """
        z, log_q_ = self.q0(num_samples, context=context)
        # Work on a fresh tensor so the in-place updates below leave log_q_ untouched
        log_q = torch.zeros_like(log_q_)
        log_q += log_q_
        for flow in self.flows:
            z, log_det = flow(z, context=context)
            log_q -= log_det
        if not score_fn:
            # Recompute log_q by inverting the samples with parameter
            # gradients switched off
            z_ = z
            log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
            utils.set_requires_grad(self, False)
            try:
                for flow in reversed(self.flows):
                    z_, log_det = flow.inverse(z_, context=context)
                    log_q += log_det
                log_q += self.q0.log_prob(z_, context=context)
            finally:
                # Always re-enable gradients, even if the inversion fails.
                # NOTE(review): presumably this enables grads for *all*
                # parameters, including ones frozen before the call.
                utils.set_requires_grad(self, True)
        log_p = self.p.log_prob(z, context=context)
        return torch.mean(log_q) - beta * torch.mean(log_p)

forward(z, context=None)

Transforms latent variable z to the flow variable x

Parameters:

Name Type Description Default
z

Batch in the latent space

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the space of the target distribution

Source code in normflows/core.py
222
223
224
225
226
227
228
229
230
231
232
233
234
def forward(self, z, context=None):
    """Push latent samples through every flow layer, in order.

    Args:
      z: Batch in the latent space
      context: Batch of conditions/context, forwarded to each flow

    Returns:
      Batch in the space of the target distribution
    """
    out = z
    for layer in self.flows:
        out, _ = layer(out, context=context)
    return out

forward_and_log_det(z, context=None)

Transforms latent variable z to the flow variable x and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
z

Batch in the latent space

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the space of the target distribution,

log determinant of the Jacobian

Source code in normflows/core.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def forward_and_log_det(self, z, context=None):
    """Transforms latent variable z to the flow variable x and
    computes log determinant of the Jacobian

    Args:
      z: Batch in the latent space
      context: Batch of conditions/context

    Returns:
      Batch in the space of the target distribution,
      log determinant of the Jacobian
    """
    # Match the input dtype (as log_prob does) so float64 log-dets are not
    # silently downcast to float32 by the in-place accumulation
    log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
    for flow in self.flows:
        z, log_d = flow(z, context=context)
        log_det += log_d
    return z, log_det

forward_kld(x, context=None)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def forward_kld(self, x, context=None):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      x: Batch sampled from target distribution
      context: Batch of conditions/context

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    # Same dtype as x (consistent with log_prob); otherwise float64 inputs
    # would be accumulated in float32
    log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    z = x
    for flow in reversed(self.flows):
        z, log_det = flow.inverse(z, context=context)
        log_q += log_det
    log_q += self.q0.log_prob(z, context=context)
    return -torch.mean(log_q)

inverse(x, context=None)

Transforms flow variable x to the latent variable z

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the latent space

Source code in normflows/core.py
254
255
256
257
258
259
260
261
262
263
264
265
266
def inverse(self, x, context=None):
    """Map a batch from the target space back into the latent space.

    The flows are undone last-to-first, each receiving the context.

    Args:
      x: Batch in the space of the target distribution
      context: Batch of conditions/context

    Returns:
      Batch in the latent space
    """
    latent = x
    for layer in reversed(self.flows):
        latent, _ = layer.inverse(latent, context=context)
    return latent

inverse_and_log_det(x, context=None)

Transforms flow variable x to the latent variable z and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required
context

Batch of conditions/context

None

Returns:

Type Description

Batch in the latent space, log determinant of the

Jacobian

Source code in normflows/core.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def inverse_and_log_det(self, x, context=None):
    """Transforms flow variable x to the latent variable z and
    computes log determinant of the Jacobian

    Args:
      x: Batch in the space of the target distribution
      context: Batch of conditions/context

    Returns:
      Batch in the latent space, log determinant of the
      Jacobian
    """
    # Keep the accumulator in x's dtype so float64 log-dets are not
    # downcast by the in-place additions
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        x, log_d = flow.inverse(x, context=context)
        log_det += log_d
    return x, log_det

log_prob(x, context=None)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
context

Batch of conditions/context

None

Returns:

Type Description

log probability

Source code in normflows/core.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def log_prob(self, x, context=None):
    """Compute the conditional log density of each sample in a batch.

    Args:
      x: Batch
      context: Batch of conditions/context, passed to the flows and to
        the base distribution

    Returns:
      log probability
    """
    latent = x
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for layer in reversed(self.flows):
        latent, ld = layer.inverse(latent, context=context)
        total += ld
    total += self.q0.log_prob(latent, context=context)
    return total

reverse_kld(num_samples=1, context=None, beta=1.0, score_fn=True)

Estimates reverse KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from base distribution

1
context

Batch of conditions/context

None
beta

Annealing parameter, see arXiv 1505.05770

1.0
score_fn

Flag whether to include score function in gradient, see arXiv 1703.09194

True

Returns:

Type Description

Estimate of the reverse KL divergence averaged over latent samples

Source code in normflows/core.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def reverse_kld(self, num_samples=1, context=None, beta=1.0, score_fn=True):
    """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      num_samples: Number of samples to draw from base distribution
      context: Batch of conditions/context
      beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
      score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

    Returns:
      Estimate of the reverse KL divergence averaged over latent samples
    """
    z, log_q_ = self.q0(num_samples, context=context)
    # Work on a fresh tensor so the in-place updates below leave log_q_ untouched
    log_q = torch.zeros_like(log_q_)
    log_q += log_q_
    for flow in self.flows:
        z, log_det = flow(z, context=context)
        log_q -= log_det
    if not score_fn:
        # Recompute log_q by inverting the samples with parameter
        # gradients switched off
        z_ = z
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_, context=context)
                log_q += log_det
            log_q += self.q0.log_prob(z_, context=context)
        finally:
            # Always re-enable gradients, even if the inversion fails.
            # NOTE(review): presumably this enables grads for *all*
            # parameters, including ones frozen before the call.
            utils.set_requires_grad(self, True)
    log_p = self.p.log_prob(z, context=context)
    return torch.mean(log_q) - beta * torch.mean(log_p)

sample(num_samples=1, context=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
context

Batch of conditions/context

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
def sample(self, num_samples=1, context=None):
    """Draw conditional samples and their log densities.

    Base samples drawn with the given context are pushed through each
    flow in order; every layer's log-determinant is subtracted from the
    running log density.

    Args:
      num_samples: Number of samples to draw
      context: Batch of conditions/context

    Returns:
      Samples, log probability
    """
    samples, log_q = self.q0(num_samples, context=context)
    for layer in self.flows:
        samples, ld = layer(samples, context=context)
        log_q -= ld
    return samples, log_q

MultiscaleFlow

Bases: Module

Normalizing Flow model with multiscale architecture, see the RealNVP or Glow paper

Source code in normflows/core.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
class MultiscaleFlow(nn.Module):
    """
    Normalizing Flow model with multiscale architecture, see RealNVP or Glow paper
    """

    def __init__(self, q0, flows, merges, transform=None, class_cond=True):
        """Constructor

        Args:
          q0: List of base distributions, one per level
          flows: List of flows for each level
          merges: List of merge/split operations (forward pass must do merge)
          transform: Initial transformation of inputs
          class_cond: Flag indicating whether the model has class
            conditional base distributions
        """
        super().__init__()
        self.q0 = nn.ModuleList(q0)
        self.num_levels = len(self.q0)
        self.flows = torch.nn.ModuleList([nn.ModuleList(flow) for flow in flows])
        self.merges = torch.nn.ModuleList(merges)
        self.transform = transform
        self.class_cond = class_cond

    def forward_kld(self, x, y=None):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution
          y: Batch of classes to condition on, if applicable

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, y))

    def forward(self, x, y=None):
        """Get negative log-likelihood for maximum likelihood training

        Args:
          x: Batch of data
          y: Batch of classes to condition on, if applicable

        Returns:
            Negative log-likelihood of the batch
        """
        return -self.log_prob(x, y)

    def forward_and_log_det(self, z):
        """Get observed variable x from list of latent variables z

        Args:
            z: List of latent variables, one per level

        Returns:
            Observed variable x, log determinant of Jacobian
        """
        log_det = torch.zeros(len(z[0]), dtype=z[0].dtype, device=z[0].device)
        for i in range(len(self.q0)):
            if i == 0:
                z_ = z[0]
            else:
                # Merge the running tensor with this level's latent variable
                z_, log_det_ = self.merges[i - 1]([z_, z[i]])
                log_det += log_det_
            for flow in self.flows[i]:
                z_, log_det_ = flow(z_)
                log_det += log_det_
        if self.transform is not None:
            z_, log_det_ = self.transform(z_)
            log_det += log_det_
        return z_, log_det

    def inverse_and_log_det(self, x):
        """Get latent variable z from observed variable x

        Args:
            x: Observed variable

        Returns:
            List of latent variables z, log determinant of Jacobian
        """
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        if self.transform is not None:
            x, log_det_ = self.transform.inverse(x)
            log_det += log_det_
        z = [None] * len(self.q0)
        for i in range(len(self.q0) - 1, -1, -1):
            for flow in reversed(self.flows[i]):
                x, log_det_ = flow.inverse(x)
                log_det += log_det_
            if i == 0:
                z[i] = x
            else:
                # Split off this level's latent variable
                [x, z[i]], log_det_ = self.merges[i - 1].inverse(x)
                log_det += log_det_
        return z, log_det

    def sample(self, num_samples=1, y=None, temperature=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None
          temperature: Temperature parameter for temp annealed sampling;
            if given, base distribution temperatures are reset to None
            afterwards, even when sampling raises

        Returns:
          Samples, log probability
        """
        if temperature is not None:
            self.set_temperature(temperature)
        try:
            for i in range(len(self.q0)):
                if self.class_cond:
                    z_, log_q_ = self.q0[i](num_samples, y)
                else:
                    z_, log_q_ = self.q0[i](num_samples)
                if i == 0:
                    log_q = log_q_
                    z = z_
                else:
                    log_q += log_q_
                    z, log_det = self.merges[i - 1]([z, z_])
                    log_q -= log_det
                for flow in self.flows[i]:
                    z, log_det = flow(z)
                    log_q -= log_det
            if self.transform is not None:
                z, log_det = self.transform(z)
                log_q -= log_det
        finally:
            # Never leak a temporary temperature into later calls
            if temperature is not None:
                self.reset_temperature()
        return z, log_q

    def log_prob(self, x, y=None):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x. Must be passed in if `class_cond` is True.

        Returns:
          log probability
        """
        log_q = 0
        z = x
        if self.transform is not None:
            z, log_det = self.transform.inverse(z)
            log_q += log_det
        for i in range(len(self.q0) - 1, -1, -1):
            for j in range(len(self.flows[i]) - 1, -1, -1):
                z, log_det = self.flows[i][j].inverse(z)
                log_q += log_det
            if i > 0:
                [z, z_], log_det = self.merges[i - 1].inverse(z)
                log_q += log_det
            else:
                z_ = z
            if self.class_cond:
                log_q += self.q0[i].log_prob(z_, y)
            else:
                log_q += self.q0[i].log_prob(z_)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """Load model from state dict

        Warning: ``torch.load`` deserializes the file with pickle (how
        restricted that is depends on the PyTorch version / ``weights_only``
        default), so only load checkpoints from trusted sources.

        Args:
          path: Path including filename where to load model from
          map_location: Optional device remapping forwarded to
            ``torch.load``, e.g. ``"cpu"``; ``None`` (default) keeps the
            stored locations
        """
        self.load_state_dict(torch.load(path, map_location=map_location))

    def set_temperature(self, temperature):
        """Set temperature for temperature annealed sampling

        Every base distribution is checked before any is modified, so a
        failure leaves all temperatures unchanged.

        Args:
          temperature: Temperature parameter

        Raises:
          NotImplementedError: If a base distribution has no
            ``temperature`` attribute
        """
        if not all(hasattr(q0, "temperature") for q0 in self.q0):
            raise NotImplementedError(
                "One base function does not "
                "support temperature annealed sampling"
            )
        for q0 in self.q0:
            q0.temperature = temperature

    def reset_temperature(self):
        """
        Set temperature values of base distributions back to None
        """
        self.set_temperature(None)

__init__(q0, flows, merges, transform=None, class_cond=True)

Constructor

Args:

q0: List of base distributions, one per level
flows: List of flows for each level
merges: List of merge/split operations (forward pass must do merge)
transform: Initial transformation of inputs
class_cond: Flag indicating whether the model has class conditional base distributions

Source code in normflows/core.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def __init__(self, q0, flows, merges, transform=None, class_cond=True):
    """Constructor

    Args:
      q0: List of base distributions, one per level
      flows: List of flows for each level
      merges: List of merge/split operations (forward pass must do merge)
      transform: Initial transformation of inputs
      class_cond: Flag indicating whether the model has class
        conditional base distributions
    """
    super().__init__()
    self.q0 = nn.ModuleList(q0)
    # The number of levels is given by the number of base distributions
    self.num_levels = len(self.q0)
    self.flows = torch.nn.ModuleList([nn.ModuleList(flow) for flow in flows])
    self.merges = torch.nn.ModuleList(merges)
    self.transform = transform
    self.class_cond = class_cond

forward(x, y=None)

Get negative log-likelihood for maximum likelihood training

Parameters:

Name Type Description Default
x

Batch of data

required
y

Batch of classes to condition on, if applicable

None

Returns:

Type Description

Negative log-likelihood of the batch

Source code in normflows/core.py
492
493
494
495
496
497
498
499
500
501
502
def forward(self, x, y=None):
    """Return the per-sample negative log-likelihood, for maximum likelihood training.

    Args:
      x: Batch of data
      y: Batch of classes to condition on, if applicable

    Returns:
        Negative log-likelihood of the batch
    """
    log_likelihood = self.log_prob(x, y)
    return -log_likelihood

forward_and_log_det(z)

Get observed variable x from list of latent variables z

Parameters:

Name Type Description Default
z

List of latent variables

required

Returns:

Type Description

Observed variable x, log determinant of Jacobian

Source code in normflows/core.py
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
def forward_and_log_det(self, z):
    """Map the per-level latent variables to an observation.

    Level 0 starts from ``z[0]``; every later level first merges the
    running tensor with ``z[level]`` via ``merges[level - 1]`` and then
    applies that level's flows. The optional transform is applied last.

    Args:
        z: List of latent variables

    Returns:
        Observed variable x, log determinant of Jacobian
    """
    log_det = torch.zeros(len(z[0]), dtype=z[0].dtype, device=z[0].device)
    for level in range(len(self.q0)):
        if level:
            cur, ld = self.merges[level - 1]([cur, z[level]])
            log_det += ld
        else:
            cur = z[0]
        for layer in self.flows[level]:
            cur, ld = layer(cur)
            log_det += ld
    if self.transform is not None:
        cur, ld = self.transform(cur)
        log_det += ld
    return cur, log_det

forward_kld(x, y=None)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required
y

Batch of classes to condition on, if applicable

None

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
480
481
482
483
484
485
486
487
488
489
490
def forward_kld(self, x, y=None):
    """Forward KL estimate as the negative mean log-likelihood, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      x: Batch sampled from target distribution
      y: Batch of classes to condition on, if applicable

    Returns:
      Estimate of forward KL divergence averaged over batch
    """
    log_likelihood = self.log_prob(x, y)
    return -torch.mean(log_likelihood)

inverse_and_log_det(x)

Get latent variable z from observed variable x

Parameters:

Name Type Description Default
x

Observed variable

required

Returns:

Type Description

List of latent variables z, log determinant of Jacobian

Source code in normflows/core.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
def inverse_and_log_det(self, x):
    """Split an observation into its per-level latent variables.

    The optional transform is undone first. Levels are then processed from
    the last to the first: each level's flows are inverted (last layer
    first), and for every level above 0 the merge is inverted to split off
    that level's latent variable.

    Args:
        x: Observed variable

    Returns:
        List of latent variables z, log determinant of Jacobian
    """
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    if self.transform is not None:
        x, ld = self.transform.inverse(x)
        log_det += ld
    n_levels = len(self.q0)
    z = [None for _ in range(n_levels)]
    for level in reversed(range(n_levels)):
        for layer in reversed(self.flows[level]):
            x, ld = layer.inverse(x)
            log_det += ld
        if level == 0:
            z[level] = x
        else:
            (x, z[level]), ld = self.merges[level - 1].inverse(x)
            log_det += ld
    return z, log_det

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
626
627
628
629
630
631
632
def load(self, path, map_location=None):
    """Load model from state dict

    Warning: ``torch.load`` deserializes the file with pickle (how
    restricted that is depends on the PyTorch version / ``weights_only``
    default), so only load checkpoints from trusted sources.

    Args:
      path: Path including filename where to load model from
      map_location: Optional device remapping forwarded to ``torch.load``,
        e.g. ``"cpu"`` to load a checkpoint saved on a GPU on a machine
        without one; ``None`` (default) keeps the stored locations
    """
    self.load_state_dict(torch.load(path, map_location=map_location))

log_prob(x, y=None)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required
y

Classes of x. Must be passed in if class_cond is True.

None

Returns:

Type Description

log probability

Source code in normflows/core.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
def log_prob(self, x, y=None):
    """Compute the log density of each sample under the multiscale flow.

    The optional transform is undone first. Levels are then processed from
    the last to the first: each level's flows are inverted, the merge (for
    levels above 0) is inverted to split off that level's latent, and the
    level's base log density is added. The base distributions receive ``y``
    only when ``class_cond`` is set.

    Args:
      x: Batch
      y: Classes of x. Must be passed in if `class_cond` is True.

    Returns:
      log probability
    """
    total = 0
    cur = x
    if self.transform is not None:
        cur, ld = self.transform.inverse(cur)
        total += ld
    for level in reversed(range(len(self.q0))):
        for layer in reversed(self.flows[level]):
            cur, ld = layer.inverse(cur)
            total += ld
        if level > 0:
            (cur, z_level), ld = self.merges[level - 1].inverse(cur)
            total += ld
        else:
            z_level = cur
        base = self.q0[level]
        total += (
            base.log_prob(z_level, y) if self.class_cond else base.log_prob(z_level)
        )
    return total

reset_temperature()

Set temperature values of base distributions back to None

Source code in normflows/core.py
649
650
651
652
653
def reset_temperature(self):
    """Undo temperature annealing by passing ``None`` to ``set_temperature``."""
    no_temperature = None
    self.set_temperature(no_temperature)

sample(num_samples=1, y=None, temperature=None)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
y

Classes to sample from, will be sampled uniformly if None

None
temperature

Temperature parameter for temp annealed sampling

None

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
def sample(self, num_samples=1, y=None, temperature=None):
    """Samples from flow-based approximate distribution

    Args:
      num_samples: Number of samples to draw
      y: Classes to sample from, will be sampled uniformly if None
      temperature: Temperature parameter for temp annealed sampling; if
        given, base distribution temperatures are reset to None
        afterwards, even when sampling raises

    Returns:
      Samples, log probability
    """
    if temperature is not None:
        self.set_temperature(temperature)
    try:
        for i in range(len(self.q0)):
            if self.class_cond:
                z_, log_q_ = self.q0[i](num_samples, y)
            else:
                z_, log_q_ = self.q0[i](num_samples)
            if i == 0:
                log_q = log_q_
                z = z_
            else:
                # Merge this level's base sample into the running tensor
                log_q += log_q_
                z, log_det = self.merges[i - 1]([z, z_])
                log_q -= log_det
            for flow in self.flows[i]:
                z, log_det = flow(z)
                log_q -= log_det
        if self.transform is not None:
            z, log_det = self.transform(z)
            log_q -= log_det
    finally:
        # Never leak a temporary temperature into later calls
        if temperature is not None:
            self.reset_temperature()
    return z, log_q

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
618
619
620
621
622
623
624
def save(self, path):
    """Write the model's state dict (parameters and buffers) to ``path``.

    Args:
      path: Path including filename where to save model
    """
    state = self.state_dict()
    torch.save(state, path)

set_temperature(temperature)

Set temperature for temperature annealed sampling

Parameters:

Name Type Description Default
temperature

Temperature parameter

required
Source code in normflows/core.py
634
635
636
637
638
639
640
641
642
643
644
645
646
647
def set_temperature(self, temperature):
    """Set temperature for temperature-annealed sampling

    Args:
      temperature: Temperature parameter

    Raises:
      NotImplementedError: If any base distribution in ``self.q0`` has no
        ``temperature`` attribute; in that case no base is modified.
    """
    # Validate every base first so a failure cannot leave the model with only
    # some of its base distributions switched to the new temperature
    if not all(hasattr(q0, "temperature") for q0 in self.q0):
        raise NotImplementedError(
            "One base function does not "
            "support temperature annealed sampling"
        )
    for q0 in self.q0:
        q0.temperature = temperature

NormalizingFlow

Bases: Module

Normalizing Flow model to approximate target distribution

Source code in normflows/core.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class NormalizingFlow(nn.Module):
    """
    Normalizing Flow model to approximate target distribution
    """

    def __init__(self, q0, flows, p=None):
        """Constructor

        Args:
          q0: Base distribution
          flows: List of flows
          p: Target distribution
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)
        self.p = p

    def forward(self, z):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z)
        return z

    def forward_and_log_det(self, z):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian
        """
        # Accumulate in z's dtype so float64 models keep full precision
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_d = flow(z)
            log_det += log_d
        return z, log_det

    def inverse(self, x):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space
        """
        for flow in reversed(self.flows):
            x, _ = flow.inverse(x)
        return x

    def inverse_and_log_det(self, x):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian
        """
        # Accumulate in x's dtype so float64 models keep full precision
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            x, log_d = flow.inverse(x)
            log_det += log_d
        return x, log_det

    def forward_kld(self, x):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution

        Returns:
          Estimate of forward KL divergence averaged over batch, computed as
          the negative mean log probability of x under the model
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return -torch.mean(log_q)

    def reverse_kld(self, num_samples=1, beta=1.0, score_fn=True):
        """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          num_samples: Number of samples to draw from base distribution
          beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
          score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

        Returns:
          Estimate of the reverse KL divergence averaged over latent samples
        """
        z, log_q_ = self.q0(num_samples)
        # Accumulate into a fresh tensor instead of the base's returned one
        log_q = torch.zeros_like(log_q_)
        log_q += log_q_
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        if not score_fn:
            # Recompute log q(z) by inverting the flow while
            # utils.set_requires_grad(self, False) is in effect
            z_ = z
            log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
            utils.set_requires_grad(self, False)
            try:
                for flow in reversed(self.flows):
                    z_, log_det = flow.inverse(z_)
                    log_q += log_det
                log_q += self.q0.log_prob(z_)
            finally:
                # Re-enable gradients even if an inverse pass raises
                utils.set_requires_grad(self, True)
        log_p = self.p.log_prob(z)
        return torch.mean(log_q) - beta * torch.mean(log_p)

    def reverse_alpha_div(self, num_samples=1, alpha=1, dreg=False):
        """Alpha divergence when sampling from q

        Args:
          num_samples: Number of samples to draw
          alpha: Alpha parameter of the divergence
          dreg: Flag whether to use Double Reparametrized Gradient estimator, see [arXiv 1810.04152](https://arxiv.org/abs/1810.04152)

        Returns:
          Alpha divergence
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.p.log_prob(z)
        if dreg:
            # Detached importance log-weights for the self-normalized weights
            log_w_const = (log_p - log_q).detach()
            z_ = z
            log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
            utils.set_requires_grad(self, False)
            try:
                for flow in reversed(self.flows):
                    z_, log_det = flow.inverse(z_)
                    log_q += log_det
                log_q += self.q0.log_prob(z_)
            finally:
                # Re-enable gradients even if an inverse pass raises
                utils.set_requires_grad(self, True)
            # log(exp(log_p - log_q)) without the overflow risk of exp()
            log_w = log_p - log_q
            # w_const**alpha / mean(w_const**alpha), evaluated in log space
            log_w_alpha = alpha * log_w_const
            w_alpha = log_w_alpha.numel() * torch.exp(
                log_w_alpha - torch.logsumexp(log_w_alpha.flatten(), 0)
            )
            weights = (1 - alpha) * w_alpha + alpha * w_alpha**2
            loss = -alpha * torch.mean(weights * log_w)
        else:
            loss = np.sign(alpha - 1) * torch.logsumexp(alpha * (log_p - log_q), 0)
        return loss

    def sample(self, num_samples=1):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x):
        """Get log probability for batch

        Args:
          x: Batch

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """Load model from state dict

        Args:
          path: Path including filename where to load model from
          map_location: Optional device remapping forwarded to ``torch.load``
            (e.g. ``"cpu"`` to load a GPU checkpoint on a CPU-only machine)
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location=map_location))

__init__(q0, flows, p=None)

Constructor

Parameters:

Name Type Description Default
q0

Base distribution

required
flows

List of flows

required
p

Target distribution

None
Source code in normflows/core.py
14
15
16
17
18
19
20
21
22
23
24
25
def __init__(self, q0, flows, p=None):
    """Build the flow model from a base distribution and a stack of flows.

    Args:
      q0: Base distribution
      flows: List of flows, wrapped in an ``nn.ModuleList``
      p: Target distribution (optional)
    """
    super().__init__()
    # Attributes are assigned left to right: q0, then flows, then p
    self.q0, self.flows, self.p = q0, nn.ModuleList(flows), p

forward(z)

Transforms latent variable z to the flow variable x

Parameters:

Name Type Description Default
z

Batch in the latent space

required

Returns:

Type Description

Batch in the space of the target distribution

Source code in normflows/core.py
27
28
29
30
31
32
33
34
35
36
37
38
def forward(self, z):
    """Push a latent batch through every flow, in order.

    Args:
      z: Batch in the latent space

    Returns:
      Batch in the space of the target distribution
    """
    out = z
    for layer in self.flows:
        # Each flow returns (output, log_det); only the output is kept here
        out, _unused_log_det = layer(out)
    return out

forward_and_log_det(z)

Transforms latent variable z to the flow variable x and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
z

Batch in the latent space

required

Returns:

Type Description

Batch in the space of the target distribution,

log determinant of the Jacobian

Source code in normflows/core.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def forward_and_log_det(self, z):
    """Transforms latent variable z to the flow variable x and
    computes log determinant of the Jacobian

    Args:
      z: Batch in the latent space

    Returns:
      Batch in the space of the target distribution,
      log determinant of the Jacobian
    """
    # Accumulate in z's dtype so float64 models keep full precision
    log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
    for flow in self.flows:
        z, log_d = flow(z)
        log_det += log_d
    return z, log_det

forward_kld(x)

Estimates forward KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
x

Batch sampled from target distribution

required

Returns:

Type Description

Estimate of forward KL divergence averaged over batch

Source code in normflows/core.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def forward_kld(self, x):
    """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      x: Batch sampled from target distribution

    Returns:
      Estimate of forward KL divergence averaged over batch, computed as
      the negative mean log probability of x under the model
    """
    # Accumulate in x's dtype so float64 models keep full precision
    log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    z = x
    for flow in reversed(self.flows):
        z, log_det = flow.inverse(z)
        log_q += log_det
    log_q += self.q0.log_prob(z)
    return -torch.mean(log_q)

inverse(x)

Transforms flow variable x to the latent variable z

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required

Returns:

Type Description

Batch in the latent space

Source code in normflows/core.py
57
58
59
60
61
62
63
64
65
66
67
68
def inverse(self, x):
    """Pull a batch back to the latent space by inverting the flows
    from last to first.

    Args:
      x: Batch in the space of the target distribution

    Returns:
      Batch in the latent space
    """
    for layer in reversed(self.flows):
        x, _unused_log_det = layer.inverse(x)
    return x

inverse_and_log_det(x)

Transforms flow variable x to the latent variable z and computes log determinant of the Jacobian

Parameters:

Name Type Description Default
x

Batch in the space of the target distribution

required

Returns:

Type Description

Batch in the latent space, log determinant of the

Jacobian

Source code in normflows/core.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def inverse_and_log_det(self, x):
    """Transforms flow variable x to the latent variable z and
    computes log determinant of the Jacobian

    Args:
      x: Batch in the space of the target distribution

    Returns:
      Batch in the latent space, log determinant of the
      Jacobian
    """
    # Accumulate in x's dtype so float64 models keep full precision
    log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    for flow in reversed(self.flows):
        x, log_d = flow.inverse(x)
        log_det += log_d
    return x, log_det

load(path)

Load model from state dict

Parameters:

Name Type Description Default
path

Path including filename where to load model from

required
Source code in normflows/core.py
207
208
209
210
211
212
213
def load(self, path, map_location=None):
    """Load model from state dict

    Args:
      path: Path including filename where to load model from
      map_location: Optional device remapping forwarded to ``torch.load``
        (e.g. ``"cpu"`` to load a GPU checkpoint on a CPU-only machine)
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    self.load_state_dict(torch.load(path, map_location=map_location))

log_prob(x)

Get log probability for batch

Parameters:

Name Type Description Default
x

Batch

required

Returns:

Type Description

log probability

Source code in normflows/core.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def log_prob(self, x):
    """Evaluate the model's log density for each element of a batch.

    The flows are inverted from last to first and their log-determinants are
    summed; the base density of the resulting latent batch is added last.

    Args:
      x: Batch

    Returns:
      log probability
    """
    total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
    latent = x
    for layer in reversed(self.flows):
        latent, layer_log_det = layer.inverse(latent)
        total += layer_log_det
    # In-place add keeps the result in x's dtype
    total += self.q0.log_prob(latent)
    return total

reverse_alpha_div(num_samples=1, alpha=1, dreg=False)

Alpha divergence when sampling from q

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1
alpha

Alpha parameter of the divergence

1
dreg

Flag whether to use Double Reparametrized Gradient estimator, see arXiv 1810.04152

False

Returns:

Type Description

Alpha divergence

Source code in normflows/core.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def reverse_alpha_div(self, num_samples=1, alpha=1, dreg=False):
    """Alpha divergence when sampling from q

    Args:
      num_samples: Number of samples to draw
      alpha: Alpha parameter of the divergence
      dreg: Flag whether to use Double Reparametrized Gradient estimator, see [arXiv 1810.04152](https://arxiv.org/abs/1810.04152)

    Returns:
      Alpha divergence
    """
    z, log_q = self.q0(num_samples)
    for flow in self.flows:
        z, log_det = flow(z)
        log_q -= log_det
    log_p = self.p.log_prob(z)
    if dreg:
        # Detached importance log-weights for the self-normalized weights
        log_w_const = (log_p - log_q).detach()
        z_ = z
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_)
                log_q += log_det
            log_q += self.q0.log_prob(z_)
        finally:
            # Re-enable gradients even if an inverse pass raises
            utils.set_requires_grad(self, True)
        # log(exp(log_p - log_q)) without the overflow risk of exp()
        log_w = log_p - log_q
        # w_const**alpha / mean(w_const**alpha), evaluated in log space
        log_w_alpha = alpha * log_w_const
        w_alpha = log_w_alpha.numel() * torch.exp(
            log_w_alpha - torch.logsumexp(log_w_alpha.flatten(), 0)
        )
        weights = (1 - alpha) * w_alpha + alpha * w_alpha**2
        loss = -alpha * torch.mean(weights * log_w)
    else:
        loss = np.sign(alpha - 1) * torch.logsumexp(alpha * (log_p - log_q), 0)
    return loss

reverse_kld(num_samples=1, beta=1.0, score_fn=True)

Estimates reverse KL divergence, see arXiv 1912.02762

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from base distribution

1
beta

Annealing parameter, see arXiv 1505.05770

1.0
score_fn

Flag whether to include score function in gradient, see arXiv 1703.09194

True

Returns:

Type Description

Estimate of the reverse KL divergence averaged over latent samples

Source code in normflows/core.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def reverse_kld(self, num_samples=1, beta=1.0, score_fn=True):
    """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

    Args:
      num_samples: Number of samples to draw from base distribution
      beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
      score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

    Returns:
      Estimate of the reverse KL divergence averaged over latent samples
    """
    z, log_q_ = self.q0(num_samples)
    # Accumulate into a fresh tensor instead of the base's returned one
    log_q = torch.zeros_like(log_q_)
    log_q += log_q_
    for flow in self.flows:
        z, log_det = flow(z)
        log_q -= log_det
    if not score_fn:
        # Recompute log q(z) by inverting the flow while
        # utils.set_requires_grad(self, False) is in effect
        z_ = z
        log_q = torch.zeros(len(z_), dtype=z_.dtype, device=z_.device)
        utils.set_requires_grad(self, False)
        try:
            for flow in reversed(self.flows):
                z_, log_det = flow.inverse(z_)
                log_q += log_det
            log_q += self.q0.log_prob(z_)
        finally:
            # Re-enable gradients even if an inverse pass raises
            utils.set_requires_grad(self, True)
    log_p = self.p.log_prob(z)
    return torch.mean(log_q) - beta * torch.mean(log_p)

sample(num_samples=1)

Samples from flow-based approximate distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1

Returns:

Type Description

Samples, log probability

Source code in normflows/core.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def sample(self, num_samples=1):
    """Draw samples from the base distribution and push them through the flows.

    Args:
      num_samples: Number of samples to draw

    Returns:
      Samples, log probability (base log density minus the summed
      forward log-determinants)
    """
    samples, log_density = self.q0(num_samples)
    for layer in self.flows:
        samples, layer_log_det = layer(samples)
        log_density -= layer_log_det
    return samples, log_density

save(path)

Save state dict of model

Parameters:

Name Type Description Default
path

Path including filename where to save model

required
Source code in normflows/core.py
199
200
201
202
203
204
205
def save(self, path):
    """Write the model's ``state_dict()`` to disk with ``torch.save``.

    Args:
      path: Path including filename where to save model
    """
    state = self.state_dict()
    torch.save(state, path)

NormalizingFlowVAE

Bases: Module

VAE using normalizing flows to express approximate distribution

Source code in normflows/core.py
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
class NormalizingFlowVAE(nn.Module):
    """
    VAE using normalizing flows to express approximate distribution
    """

    def __init__(self, prior, q0=None, flows=None, decoder=None):
        """Constructor of normalizing flow model

        Args:
          prior: Prior distribution of the VAE, e.g. Gaussian
          q0: Base encoder; a new ``distributions.Dirac()`` is created for
            this model if None
          flows: Flows to transform output of base encoder (None means no flows)
          decoder: Optional decoder
        """
        super().__init__()
        if q0 is None:
            # Built per instance: a module used as a default argument would be
            # created once at definition time and shared by every model
            q0 = distributions.Dirac()
        self.prior = prior
        self.decoder = decoder
        self.flows = nn.ModuleList(flows)
        self.q0 = q0

    def forward(self, x, num_samples=1):
        """Takes data batch, samples num_samples for each data point from base distribution

        Args:
          x: data batch
          num_samples: number of samples to draw for each data point

        Returns:
          latent variables for each batch and sample, log_q, and log_p; each
          is returned with leading dimensions (batch, num_samples)
        """
        z, log_q = self.q0(x, num_samples=num_samples)
        # Flatten batch and sample dim
        z = z.view(-1, *z.size()[2:])
        log_q = log_q.view(-1, *log_q.size()[2:])
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.prior.log_prob(z)
        if self.decoder is not None:
            log_p += self.decoder.log_prob(x, z)
        # Separate batch and sample dimension again
        z = z.view(-1, num_samples, *z.size()[1:])
        log_q = log_q.view(-1, num_samples, *log_q.size()[1:])
        log_p = log_p.view(-1, num_samples, *log_p.size()[1:])
        return z, log_q, log_p

__init__(prior, q0=distributions.Dirac(), flows=None, decoder=None)

Constructor of normalizing flow model

Parameters:

Name Type Description Default
prior

Prior distribution of the VAE, e.g. Gaussian

required
decoder

Optional decoder

None
flows

Flows to transform output of base encoder

None
q0

Base Encoder

Dirac()
Source code in normflows/core.py
661
662
663
664
665
666
667
668
669
670
671
672
673
674
def __init__(self, prior, q0=None, flows=None, decoder=None):
    """Constructor of normalizing flow model

    Args:
      prior: Prior distribution of the VAE, e.g. Gaussian
      q0: Base encoder; a new ``distributions.Dirac()`` is created for
        this model if None
      flows: Flows to transform output of base encoder (None means no flows)
      decoder: Optional decoder
    """
    super().__init__()
    if q0 is None:
        # Built per instance: a module used as a default argument would be
        # created once at definition time and shared by every model
        q0 = distributions.Dirac()
    self.prior = prior
    self.decoder = decoder
    self.flows = nn.ModuleList(flows)
    self.q0 = q0

forward(x, num_samples=1)

Takes data batch, samples num_samples for each data point from base distribution

Parameters:

Name Type Description Default
x

data batch

required
num_samples

number of samples to draw for each data point

1

Returns:

Type Description

latent variables for each batch and sample, log_q, and log_p

Source code in normflows/core.py
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
def forward(self, x, num_samples=1):
    """Encode a data batch, drawing num_samples latents per data point.

    Args:
      x: data batch
      num_samples: number of samples to draw for each data point

    Returns:
      latent variables for each batch and sample, log_q, and log_p, each
      reshaped to leading dimensions (batch, num_samples)
    """
    latent, log_q = self.q0(x, num_samples=num_samples)
    # Merge the two leading (batch, sample) dimensions into one
    latent = latent.view(-1, *latent.size()[2:])
    log_q = log_q.view(-1, *log_q.size()[2:])
    for layer in self.flows:
        latent, layer_log_det = layer(latent)
        log_q -= layer_log_det
    log_p = self.prior.log_prob(latent)
    if self.decoder is not None:
        log_p += self.decoder.log_prob(x, latent)

    def split(t):
        # Undo the merge: restore separate batch and sample dimensions
        return t.view(-1, num_samples, *t.size()[1:])

    return split(latent), split(log_q), split(log_p)

distributions

base

AffineGaussian

Bases: BaseDistribution

Diagonal Gaussian with an affine constant transformation applied to it; can be class conditional or not

Source code in normflows/distributions/base.py
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
class AffineGaussian(BaseDistribution):
    """
    Diagonal Gaussian with an affine constant transformation applied to it,
    can be class conditional or not
    """

    def __init__(self, shape, affine_shape, num_classes=None):
        """Constructor

        Args:
          shape: Shape of the variables
          affine_shape: Shape of the parameters in the affine transformation
          num_classes: Number of classes if the base is class conditional, None otherwise
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        if isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        self.d = np.prod(shape)
        # Axes summed over in the log density: every axis after the batch axis
        self.sum_dim = list(range(1, self.n_dim + 1))
        self.affine_shape = affine_shape
        self.num_classes = num_classes
        self.class_cond = num_classes is not None
        # Affine transformation
        if self.class_cond:
            self.transform = flows.CCAffineConst(self.affine_shape, self.num_classes)
        else:
            self.transform = flows.AffineConstFlow(self.affine_shape)
        # Temperature parameter for annealed sampling
        self.temperature = None

    def _one_hot(self, y):
        """Turn a 1D tensor of integer labels into one-hot rows; inputs with
        any other number of dims are returned unchanged (assumed one-hot)."""
        if y.dim() != 1:
            return y
        y_onehot = torch.zeros(
            (len(y), self.num_classes),
            dtype=self.transform.s.dtype,
            device=self.transform.s.device,
        )
        y_onehot.scatter_(1, y[:, None], 1)
        return y_onehot

    def _log_scale(self):
        """Log of the sampling temperature, or 0.0 if no temperature is set."""
        if self.temperature is not None:
            return np.log(self.temperature)
        return 0.0

    def forward(self, num_samples=1, y=None):
        """Sample from the distribution and compute the samples' log probability

        Args:
          num_samples: Number of samples to draw; replaced by ``len(y)`` when
            the base is class conditional and ``y`` is given
          y: Class labels (integer labels or one-hot rows), sampled uniformly
            if None; only used if the base is class conditional

        Returns:
          Samples, log probability
        """
        dtype = self.transform.s.dtype
        device = self.transform.s.device
        if self.class_cond:
            if y is not None:
                num_samples = len(y)
            else:
                y = torch.randint(self.num_classes, (num_samples,), device=device)
            y = self._one_hot(y)
        log_scale = self._log_scale()
        # Sample from a standard normal scaled by the temperature
        eps = torch.randn((num_samples,) + self.shape, dtype=dtype, device=device)
        z = np.exp(log_scale) * eps
        # Get log prob
        log_p = (
            -0.5 * self.d * np.log(2 * np.pi)
            - self.d * log_scale
            - 0.5 * torch.sum(torch.pow(eps, 2), dim=self.sum_dim)
        )
        # Apply transform
        if self.class_cond:
            z, log_det = self.transform(z, y)
        else:
            z, log_det = self.transform(z)
        log_p -= log_det
        return z, log_p

    def log_prob(self, z, y=None):
        """Compute the log probability of a batch of samples

        Args:
          z: Batch of samples
          y: Class labels (integer labels or one-hot rows); required if the
            base is class conditional, ignored otherwise

        Returns:
          log probability for each batch element

        Raises:
          ValueError: If the base is class conditional and ``y`` is None
        """
        if self.class_cond:
            if y is None:
                raise ValueError(
                    "Class labels y are required for a class conditional "
                    "base distribution"
                )
            y = self._one_hot(y)
        log_scale = self._log_scale()
        # Invert the affine transform, then undo the temperature scaling
        if self.class_cond:
            z, log_p = self.transform.inverse(z, y)
        else:
            z, log_p = self.transform.inverse(z)
        z = z / np.exp(log_scale)
        log_p = (
            log_p
            - self.d * log_scale
            - 0.5 * self.d * np.log(2 * np.pi)
            - 0.5 * torch.sum(torch.pow(z, 2), dim=self.sum_dim)
        )
        return log_p
__init__(shape, affine_shape, num_classes=None)

Constructor

Parameters:

Name Type Description Default
shape

Shape of the variables

required
affine_shape

Shape of the parameters in the affine transformation

required
num_classes

Number of classes if the base is class conditional, None otherwise

None
Source code in normflows/distributions/base.py
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def __init__(self, shape, affine_shape, num_classes=None):
    """Set up the Gaussian and its (optionally class conditional) affine transform.

    Args:
      shape: Shape of the variables (an int is treated as a 1D shape)
      affine_shape: Shape of the parameters in the affine transformation
      num_classes: Number of classes if the base is class conditional, None otherwise
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    self.n_dim = len(shape)
    self.d = np.prod(shape)
    # Axes 1..n_dim, i.e. every axis after the leading one
    self.sum_dim = list(range(1, self.n_dim + 1))
    self.affine_shape = affine_shape
    self.num_classes = num_classes
    self.class_cond = num_classes is not None
    if self.class_cond:
        transform = flows.CCAffineConst(self.affine_shape, self.num_classes)
    else:
        transform = flows.AffineConstFlow(self.affine_shape)
    self.transform = transform
    # No temperature by default; it is set for annealed sampling
    self.temperature = None

BaseDistribution

Bases: Module

Base distribution of a flow-based model. Its parameters do not depend on the target variable (as they would for a VAE encoder).

Source code in normflows/distributions/base.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class BaseDistribution(nn.Module):
    """
    Abstract base distribution of a flow-based model.

    Its parameters do not depend on the target variable (unlike, e.g., a VAE
    encoder). Subclasses override ``forward`` and ``log_prob``.
    """

    def forward(self, num_samples=1):
        """Draw samples and return them together with their log probability.

        Args:
          num_samples: Number of samples to draw from the distribution

        Returns:
          Samples drawn from the distribution, log probability
        """
        raise NotImplementedError()

    def log_prob(self, z):
        """Evaluate the log probability of every element of a batch.

        Args:
          z: Batch of random variables to determine log probability for

        Returns:
          log probability for each batch element
        """
        raise NotImplementedError()

    def sample(self, num_samples=1, **kwargs):
        """Draw samples only, discarding the log probability.

        Args:
          num_samples: Number of samples to draw from the distribution
          **kwargs: Forwarded unchanged to ``forward``

        Returns:
          Samples drawn from the distribution
        """
        draws, _log_p = self.forward(num_samples, **kwargs)
        return draws
forward(num_samples=1)

Samples from base distribution and calculates log probability

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from the distribution

1

Returns:

Type Description

Samples drawn from the distribution, log probability

Source code in normflows/distributions/base.py
17
18
19
20
21
22
23
24
25
26
def forward(self, num_samples=1):
    """Draw samples and return them together with their log probability.

    Abstract here; concrete base distributions override it.

    Args:
      num_samples: Number of samples to draw from the distribution

    Returns:
      Samples drawn from the distribution, log probability
    """
    raise NotImplementedError()
log_prob(z)

Calculate log probability of batch of samples

Parameters:

Name Type Description Default
z

Batch of random variables to determine log probability for

required

Returns:

Type Description

log probability for each batch element

Source code in normflows/distributions/base.py
28
29
30
31
32
33
34
35
36
37
def log_prob(self, z):
    """Evaluate the log probability of every element of a batch.

    Abstract here; concrete base distributions override it.

    Args:
      z: Batch of random variables to determine log probability for

    Returns:
      log probability for each batch element
    """
    raise NotImplementedError()
sample(num_samples=1, **kwargs)

Samples from base distribution

Parameters:

Name Type Description Default
num_samples

Number of samples to draw from the distribution

1

Returns:

Type Description

Samples drawn from the distribution

Source code in normflows/distributions/base.py
39
40
41
42
43
44
45
46
47
48
49
def sample(self, num_samples=1, **kwargs):
    """Draw samples only, discarding the log probability.

    Args:
      num_samples: Number of samples to draw from the distribution
      **kwargs: Forwarded unchanged to ``forward``

    Returns:
      Samples drawn from the distribution
    """
    draws, _log_p = self.forward(num_samples, **kwargs)
    return draws

ClassCondDiagGaussian

Bases: BaseDistribution

Class conditional multivariate Gaussian distribution with diagonal covariance matrix

Source code in normflows/distributions/base.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
class ClassCondDiagGaussian(BaseDistribution):
    """
    Class conditional multivariate Gaussian distribution with diagonal covariance matrix
    """

    def __init__(self, shape, num_classes):
        """Constructor

        Args:
          shape: Tuple with shape of data, if int shape has one dimension
          num_classes: Number of classes to condition on
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        elif isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        # Moves the trailing axis produced by ``param @ onehot`` to the front
        self.perm = [self.n_dim, *range(self.n_dim)]
        self.d = np.prod(shape)
        self.num_classes = num_classes
        param_shape = (*self.shape, num_classes)
        self.loc = nn.Parameter(torch.zeros(param_shape))
        self.log_scale = nn.Parameter(torch.zeros(param_shape))
        self.temperature = None  # Temperature parameter for annealed sampling

    def _class_params(self, y):
        """Select per-sample mean and log scale for the classes in ``y``.

        Args:
          y: Either a 1D tensor of integer class labels, or a 2D tensor with one
            row per sample and ``num_classes`` columns (e.g. one-hot rows)

        Returns:
          Tuple ``(loc, log_scale)``, each of shape ``(batch, *shape)``; the
          log scale includes the temperature shift if a temperature is set
        """
        if y.dim() == 1:
            weights = torch.zeros(
                (self.num_classes, len(y)),
                dtype=self.loc.dtype,
                device=self.loc.device,
            )
            weights.scatter_(0, y[None], 1)
        else:
            weights = y.t()
        # (*shape, num_classes) @ (num_classes, batch) -> (*shape, batch)
        loc = (self.loc @ weights).permute(*self.perm)
        log_scale = (self.log_scale @ weights).permute(*self.perm)
        if self.temperature is not None:
            log_scale = np.log(self.temperature) + log_scale
        return loc, log_scale

    def forward(self, num_samples=1, y=None):
        """Draw samples and their log probability.

        Args:
          num_samples: Number of samples, ignored when ``y`` is given
          y: Classes to sample from; drawn uniformly at random if None

        Returns:
          Samples and their log probability
        """
        if y is None:
            y = torch.randint(self.num_classes, (num_samples,), device=self.loc.device)
        else:
            num_samples = len(y)
        loc, log_scale = self._class_params(y)
        eps = torch.randn(
            (num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
        )
        z = loc + torch.exp(log_scale) * eps
        reduce_dims = list(range(1, self.n_dim + 1))
        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(eps, 2), reduce_dims
        )
        return z, log_p

    def log_prob(self, z, y):
        """Log probability of ``z`` given the classes ``y``.

        Args:
          z: Batch of samples
          y: Class labels (1D) or class weight rows (2D) for each sample

        Returns:
          Log probability for each batch element
        """
        loc, log_scale = self._class_params(y)
        standardized = (z - loc) / torch.exp(log_scale)
        reduce_dims = list(range(1, self.n_dim + 1))
        return -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(standardized, 2), reduce_dims
        )
__init__(shape, num_classes)

Constructor

Parameters:

Name Type Description Default
shape

Tuple with shape of data, if int shape has one dimension

required
num_classes

Number of classes to condition on

required
Source code in normflows/distributions/base.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def __init__(self, shape, num_classes):
    """Build a class-conditional diagonal Gaussian.

    Args:
      shape: Shape of one data point; an int is treated as a one-dimensional
        shape and a list is converted to a tuple
      num_classes: Number of classes the distribution is conditioned on
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    self.n_dim = len(shape)
    # Permutation that moves the last axis to the front
    self.perm = [self.n_dim, *range(self.n_dim)]
    self.d = np.prod(shape)
    self.num_classes = num_classes
    param_shape = (*self.shape, num_classes)
    self.loc = nn.Parameter(torch.zeros(param_shape))
    self.log_scale = nn.Parameter(torch.zeros(param_shape))
    # Temperature parameter for annealed sampling
    self.temperature = None

ConditionalDiagGaussian

Bases: BaseDistribution

Conditional multivariate Gaussian distribution with diagonal covariance matrix whose parameters are computed by a context encoder; the context is the variable the distribution is conditioned on.

Source code in normflows/distributions/base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class ConditionalDiagGaussian(BaseDistribution):
    """
    Conditional multivariate Gaussian distribution with diagonal
    covariance matrix, parameters are obtained by a context encoder,
    context meaning the variable to condition on
    """
    def __init__(self, shape, context_encoder):
        """Constructor

        Args:
          shape: Tuple with shape of data, if int shape has one dimension
          context_encoder: Computes mean and log of the standard deviation
          of the Gaussian, mean is the first half of the last dimension
          of the encoder output, log of the standard deviation the second
          half
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        elif isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        self.d = np.prod(shape)
        self.context_encoder = context_encoder

    def _encode(self, context):
        """Run the context encoder and split its output in two halves.

        Returns:
          Tuple ``(mean, log_scale)`` taken from the first and second half of
          the last dimension of the encoder output
        """
        params = self.context_encoder(context)
        half = params.shape[-1] // 2
        return params[..., :half], params[..., half:]

    def forward(self, num_samples=1, context=None):
        """Draw samples given ``context`` and return them with their log probability."""
        mean, log_scale = self._encode(context)
        eps = torch.randn(
            (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device
        )
        z = mean + torch.exp(log_scale) * eps
        reduce_dims = list(range(1, self.n_dim + 1))
        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(eps, 2), reduce_dims
        )
        return z, log_p

    def log_prob(self, z, context=None):
        """Log probability of each element of ``z`` given ``context``."""
        mean, log_scale = self._encode(context)
        standardized = (z - mean) / torch.exp(log_scale)
        reduce_dims = list(range(1, self.n_dim + 1))
        return -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(standardized, 2), reduce_dims
        )
__init__(shape, context_encoder)

Constructor

Parameters:

Name Type Description Default
shape

Tuple with shape of data, if int shape has one dimension

required
context_encoder

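Computes mean and log of the standard deviation of the Gaussian; the mean is the first half of the last dimension of the encoder output, the log of the standard deviation the second half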

required
Source code in normflows/distributions/base.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def __init__(self, shape, context_encoder):
    """Build a Gaussian whose parameters are predicted from a context.

    Args:
      shape: Shape of one data point; an int is treated as a one-dimensional
        shape and a list is converted to a tuple
      context_encoder: Callable mapping the context to the Gaussian
        parameters; the first half of the last dimension of its output is the
        mean, the second half the log of the standard deviation
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape, self.n_dim, self.d = shape, len(shape), np.prod(shape)
    self.context_encoder = context_encoder

DiagGaussian

Bases: BaseDistribution

Multivariate Gaussian distribution with diagonal covariance matrix

Source code in normflows/distributions/base.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class DiagGaussian(BaseDistribution):
    """
    Multivariate Gaussian distribution with diagonal covariance matrix
    """

    def __init__(self, shape, trainable=True):
        """Constructor

        Args:
          shape: Tuple with shape of data, if int shape has one dimension
          trainable: Flag whether to use trainable or fixed parameters
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        elif isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        self.d = np.prod(shape)
        init_loc = torch.zeros(1, *self.shape)
        init_log_scale = torch.zeros(1, *self.shape)
        if trainable:
            self.loc = nn.Parameter(init_loc)
            self.log_scale = nn.Parameter(init_log_scale)
        else:
            # Fixed parameters still follow the module across devices/dtypes
            self.register_buffer("loc", init_loc)
            self.register_buffer("log_scale", init_log_scale)
        self.temperature = None  # Temperature parameter for annealed sampling

    def _effective_log_scale(self):
        """Log scale shifted by ``log(temperature)`` when a temperature is set."""
        if self.temperature is None:
            return self.log_scale
        return self.log_scale + np.log(self.temperature)

    def forward(self, num_samples=1, context=None):
        """Draw samples and return them with their log probability.

        ``context`` is accepted for interface compatibility and ignored.
        """
        eps = torch.randn(
            (num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
        )
        log_scale = self._effective_log_scale()
        z = self.loc + torch.exp(log_scale) * eps
        reduce_dims = list(range(1, self.n_dim + 1))
        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(eps, 2), reduce_dims
        )
        return z, log_p

    def log_prob(self, z, context=None):
        """Log probability of each element of ``z``; ``context`` is ignored."""
        log_scale = self._effective_log_scale()
        standardized = (z - self.loc) / torch.exp(log_scale)
        reduce_dims = list(range(1, self.n_dim + 1))
        return -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(standardized, 2), reduce_dims
        )
__init__(shape, trainable=True)

Constructor

Parameters:

Name Type Description Default
shape

Tuple with shape of data, if int shape has one dimension

required
trainable

Flag whether to use trainable or fixed parameters

True
Source code in normflows/distributions/base.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def __init__(self, shape, trainable=True):
    """Build a diagonal Gaussian with zero mean and unit scale.

    Args:
      shape: Shape of one data point; an int is treated as a one-dimensional
        shape and a list is converted to a tuple
      trainable: If True, mean and log scale are ``nn.Parameter``s, otherwise
        they are registered as buffers
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    self.n_dim = len(shape)
    self.d = np.prod(shape)
    init_loc = torch.zeros(1, *self.shape)
    init_log_scale = torch.zeros(1, *self.shape)
    if trainable:
        self.loc = nn.Parameter(init_loc)
        self.log_scale = nn.Parameter(init_log_scale)
    else:
        self.register_buffer("loc", init_loc)
        self.register_buffer("log_scale", init_log_scale)
    # Temperature parameter for annealed sampling
    self.temperature = None

GaussianMixture

Bases: BaseDistribution

Mixture of Gaussians with diagonal covariance matrix

Source code in normflows/distributions/base.py
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
class GaussianMixture(BaseDistribution):
    """
    Mixture of Gaussians with diagonal covariance matrix
    """

    def __init__(
        self, n_modes, dim, loc=None, scale=None, weights=None, trainable=True
    ):
        """Constructor

        Args:
          n_modes: Number of modes of the mixture model
          dim: Number of dimensions of each Gaussian
          loc: List of mean values
          scale: List of diagonals of the covariance matrices
          weights: List of mode probabilities; they are normalized to sum to
            one, integer values are accepted
          trainable: Flag, if true parameters will be optimized during training
        """
        super().__init__()

        self.n_modes = n_modes
        self.dim = dim

        if loc is None:
            loc = np.random.randn(self.n_modes, self.dim)
        loc = np.array(loc)[None, ...]
        if scale is None:
            scale = np.ones((self.n_modes, self.dim))
        scale = np.array(scale)[None, ...]
        if weights is None:
            weights = np.ones(self.n_modes)
        weights = np.array(weights)[None, ...]
        # In-place division of an integer array by its (float) sum raises a
        # NumPy casting error, so promote non-floating weights first
        if not np.issubdtype(weights.dtype, np.floating):
            weights = weights.astype(float)
        weights /= weights.sum(1)

        if trainable:
            self.loc = nn.Parameter(torch.tensor(1.0 * loc))
            self.log_scale = nn.Parameter(torch.tensor(np.log(1.0 * scale)))
            self.weight_scores = nn.Parameter(torch.tensor(np.log(1.0 * weights)))
        else:
            self.register_buffer("loc", torch.tensor(1.0 * loc))
            self.register_buffer("log_scale", torch.tensor(np.log(1.0 * scale)))
            self.register_buffer("weight_scores", torch.tensor(np.log(1.0 * weights)))

    def _mixture_log_prob(self, z):
        """Log density of the mixture at each row of ``z`` (shape (batch, dim))."""
        # log_softmax avoids log(0) when a softmax weight underflows
        log_weights = torch.log_softmax(self.weight_scores, 1)
        # (batch, 1, dim) against (1, n_modes, dim) -> (batch, n_modes, dim)
        eps = (z[:, None, :] - self.loc) / torch.exp(self.log_scale)
        log_p = (
            -0.5 * self.dim * np.log(2 * np.pi)
            + log_weights
            - 0.5 * torch.sum(torch.pow(eps, 2), 2)
            - torch.sum(self.log_scale, 2)
        )
        # Marginalize over the modes
        return torch.logsumexp(log_p, 1)

    def forward(self, num_samples=1):
        """Draw samples from the mixture and return them with their log probability."""
        # Get weights
        weights = torch.softmax(self.weight_scores, 1)

        # Sample mode indices
        mode = torch.multinomial(weights[0, :], num_samples, replacement=True)
        mode_1h = nn.functional.one_hot(mode, self.n_modes)[..., None]

        # Get samples from the selected modes
        eps_ = torch.randn(
            num_samples, self.dim, dtype=self.loc.dtype, device=self.loc.device
        )
        scale_sample = torch.sum(torch.exp(self.log_scale) * mode_1h, 1)
        loc_sample = torch.sum(self.loc * mode_1h, 1)
        z = eps_ * scale_sample + loc_sample

        return z, self._mixture_log_prob(z)

    def log_prob(self, z):
        """Log probability of each row of ``z`` under the mixture."""
        return self._mixture_log_prob(z)
__init__(n_modes, dim, loc=None, scale=None, weights=None, trainable=True)

Constructor

Parameters:

Name Type Description Default
n_modes

Number of modes of the mixture model

required
dim

Number of dimensions of each Gaussian

required
loc

List of mean values

None
scale

List of diagonals of the covariance matrices

None
weights

List of mode probabilities

None
trainable

Flag, if true parameters will be optimized during training

True
Source code in normflows/distributions/base.py
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
def __init__(
    self, n_modes, dim, loc=None, scale=None, weights=None, trainable=True
):
    """Constructor

    Args:
      n_modes: Number of modes of the mixture model
      dim: Number of dimensions of each Gaussian
      loc: List of mean values
      scale: List of diagonals of the covariance matrices
      weights: List of mode probabilities; they are normalized to sum to
        one, integer values are accepted
      trainable: Flag, if true parameters will be optimized during training
    """
    super().__init__()

    self.n_modes = n_modes
    self.dim = dim

    if loc is None:
        loc = np.random.randn(self.n_modes, self.dim)
    loc = np.array(loc)[None, ...]
    if scale is None:
        scale = np.ones((self.n_modes, self.dim))
    scale = np.array(scale)[None, ...]
    if weights is None:
        weights = np.ones(self.n_modes)
    weights = np.array(weights)[None, ...]
    # In-place division of an integer array by its (float) sum raises a
    # NumPy casting error, so promote non-floating weights first
    if not np.issubdtype(weights.dtype, np.floating):
        weights = weights.astype(float)
    weights /= weights.sum(1)

    if trainable:
        self.loc = nn.Parameter(torch.tensor(1.0 * loc))
        self.log_scale = nn.Parameter(torch.tensor(np.log(1.0 * scale)))
        self.weight_scores = nn.Parameter(torch.tensor(np.log(1.0 * weights)))
    else:
        self.register_buffer("loc", torch.tensor(1.0 * loc))
        self.register_buffer("log_scale", torch.tensor(np.log(1.0 * scale)))
        self.register_buffer("weight_scores", torch.tensor(np.log(1.0 * weights)))

GaussianPCA

Bases: BaseDistribution

Gaussian distribution resulting from linearly mapping a normal distributed latent variable describing the "content of the target"

Source code in normflows/distributions/base.py
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
class GaussianPCA(BaseDistribution):
    """
    Gaussian distribution resulting from linearly mapping a normal distributed latent
    variable describing the "content of the target"

    Samples are generated as ``z = eps @ W + loc + sigma * noise`` with standard
    normal ``eps`` (latent_dim entries) and ``noise`` (dim entries), so the
    distribution is ``N(loc, W^T W + sigma^2 I)``.
    """

    def __init__(self, dim, latent_dim=None, sigma=0.1):
        """Constructor

        Args:
          dim: Number of dimensions of the flow variables
          latent_dim: Number of dimensions of the latent "content" variable;
                           if None it is set equal to dim
          sigma: Noise level
        """
        super().__init__()

        self.dim = dim
        self.latent_dim = dim if latent_dim is None else latent_dim

        self.loc = nn.Parameter(torch.zeros(1, dim))
        # Use the resolved latent dimension; torch.randn(None, dim) would fail
        self.W = nn.Parameter(torch.randn(self.latent_dim, dim))
        self.log_sigma = nn.Parameter(torch.tensor(np.log(sigma)))

    def _covariance(self):
        """Covariance ``W^T W + sigma^2 I`` of the distribution, shape (dim, dim)."""
        eye = torch.eye(self.dim, dtype=self.loc.dtype, device=self.loc.device)
        return torch.matmul(self.W.T, self.W) + torch.exp(self.log_sigma * 2) * eye

    def _centered_log_prob(self, z_):
        """Log density of ``N(0, Sig)`` evaluated at the centered batch ``z_``."""
        Sig = self._covariance()
        return (
            -self.dim / 2 * np.log(2 * np.pi)
            - 0.5 * torch.logdet(Sig)
            - 0.5 * torch.sum(z_ * torch.matmul(z_, torch.inverse(Sig)), 1)
        )

    def forward(self, num_samples=1):
        """Draw samples and return them with their log probability."""
        eps = torch.randn(
            num_samples, self.latent_dim, dtype=self.loc.dtype, device=self.loc.device
        )
        noise = torch.randn(
            num_samples, self.dim, dtype=self.loc.dtype, device=self.loc.device
        )
        # Isotropic noise makes the samples follow the density used below
        z_ = torch.matmul(eps, self.W) + torch.exp(self.log_sigma) * noise
        z = z_ + self.loc
        return z, self._centered_log_prob(z_)

    def log_prob(self, z):
        """Log probability of each row of ``z``."""
        return self._centered_log_prob(z - self.loc)
__init__(dim, latent_dim=None, sigma=0.1)

Constructor

Parameters:

Name Type Description Default
dim

Number of dimensions of the flow variables

required
latent_dim

Number of dimensions of the latent "content" variable; if None it is set equal to dim

None
sigma

Noise level

0.1
Source code in normflows/distributions/base.py
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
def __init__(self, dim, latent_dim=None, sigma=0.1):
    """Constructor

    Args:
      dim: Number of dimensions of the flow variables
      latent_dim: Number of dimensions of the latent "content" variable;
                       if None it is set equal to dim
      sigma: Noise level
    """
    super().__init__()

    self.dim = dim
    self.latent_dim = dim if latent_dim is None else latent_dim

    self.loc = nn.Parameter(torch.zeros(1, dim))
    # Use the resolved latent dimension; torch.randn(None, dim) would fail
    self.W = nn.Parameter(torch.randn(self.latent_dim, dim))
    self.log_sigma = nn.Parameter(torch.tensor(np.log(sigma)))

GlowBase

Bases: BaseDistribution

Base distribution of the Glow model, i.e. a diagonal Gaussian with one mean and one log scale per channel

Source code in normflows/distributions/base.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
class GlowBase(BaseDistribution):
    """
    Base distribution of the Glow model, i.e. Diagonal Gaussian with one mean and
    log scale for each channel
    """

    def __init__(self, shape, num_classes=None, logscale_factor=3.0):
        """Constructor

        Args:
          shape: Shape of the variables
          num_classes: Number of classes if the base is class conditional, None otherwise
          logscale_factor: Scaling factor for mean and log variance
        """
        super().__init__()
        # Normalize the shape and cache derived sizes
        if isinstance(shape, int):
            shape = (shape,)
        elif isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        self.num_pix = np.prod(shape[1:])
        self.d = np.prod(shape)
        self.sum_dim = list(range(1, self.n_dim + 1))
        self.num_classes = num_classes
        self.class_cond = num_classes is not None
        self.logscale_factor = logscale_factor
        # One value per channel, broadcast over batch and remaining dimensions
        channel_shape = (1, self.shape[0]) + (1,) * (self.n_dim - 1)
        self.loc = nn.Parameter(torch.zeros(channel_shape))
        self.loc_logs = nn.Parameter(torch.zeros(channel_shape))
        self.log_scale = nn.Parameter(torch.zeros(channel_shape))
        self.log_scale_logs = nn.Parameter(torch.zeros(channel_shape))
        # Per-class, per-channel offsets, only for the class conditional case
        if self.class_cond:
            self.loc_cc = nn.Parameter(torch.zeros(self.num_classes, self.shape[0]))
            self.log_scale_cc = nn.Parameter(
                torch.zeros(self.num_classes, self.shape[0])
            )
        # Temperature parameter for annealed sampling
        self.temperature = None

    def _channel_params(self, y):
        """Compute the effective mean and log scale.

        Args:
          y: Class labels (1D) or class weight rows (2D); only read when the
            distribution is class conditional

        Returns:
          Tuple ``(loc, log_scale)`` broadcastable against ``(batch, *shape)``;
          the log scale includes the temperature shift if one is set
        """
        loc = self.loc * torch.exp(self.loc_logs * self.logscale_factor)
        log_scale = self.log_scale * torch.exp(
            self.log_scale_logs * self.logscale_factor
        )
        if self.class_cond:
            if y.dim() == 1:
                onehot = torch.zeros(
                    (len(y), self.num_classes),
                    dtype=self.loc.dtype,
                    device=self.loc.device,
                )
                onehot.scatter_(1, y[:, None], 1)
                y = onehot
            per_sample_shape = (y.size(0), self.shape[0]) + (1,) * (self.n_dim - 1)
            loc = loc + (y @ self.loc_cc).view(per_sample_shape)
            log_scale = log_scale + (y @ self.log_scale_cc).view(per_sample_shape)
        if self.temperature is not None:
            log_scale = log_scale + np.log(self.temperature)
        return loc, log_scale

    def forward(self, num_samples=1, y=None):
        """Draw samples and return them with their log probability.

        Args:
          num_samples: Number of samples; replaced by ``len(y)`` when the
            distribution is class conditional and ``y`` is given
          y: Classes to sample from (class conditional case only); drawn
            uniformly at random if None

        Returns:
          Samples and their log probability
        """
        if self.class_cond:
            if y is None:
                y = torch.randint(
                    self.num_classes, (num_samples,), device=self.loc.device
                )
            else:
                num_samples = len(y)
        loc, log_scale = self._channel_params(y)
        eps = torch.randn(
            (num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
        )
        z = loc + torch.exp(log_scale) * eps
        # Each channel's log scale is shared by num_pix entries
        log_p = (
            -0.5 * self.d * np.log(2 * np.pi)
            - self.num_pix * torch.sum(log_scale, dim=self.sum_dim)
            - 0.5 * torch.sum(torch.pow(eps, 2), dim=self.sum_dim)
        )
        return z, log_p

    def log_prob(self, z, y=None):
        """Log probability of each element of ``z`` (given classes ``y`` if class conditional)."""
        loc, log_scale = self._channel_params(y)
        standardized = (z - loc) / torch.exp(log_scale)
        return (
            -0.5 * self.d * np.log(2 * np.pi)
            - self.num_pix * torch.sum(log_scale, dim=self.sum_dim)
            - 0.5 * torch.sum(torch.pow(standardized, 2), dim=self.sum_dim)
        )
__init__(shape, num_classes=None, logscale_factor=3.0)

Constructor

Parameters:

Name Type Description Default
shape

Shape of the variables

required
num_classes

Number of classes if the base is class conditional, None otherwise

None
logscale_factor

Scaling factor for mean and log variance

3.0
Source code in normflows/distributions/base.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def __init__(self, shape, num_classes=None, logscale_factor=3.0):
    """Build the Glow base distribution.

    Args:
      shape: Shape of the variables, channels first; an int is treated as a
        one-dimensional shape and a list is converted to a tuple
      num_classes: Number of classes if the base is class conditional,
        None otherwise
      logscale_factor: Scaling factor for mean and log variance
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    self.n_dim = len(shape)
    self.num_pix = np.prod(shape[1:])
    self.d = np.prod(shape)
    self.sum_dim = list(range(1, self.n_dim + 1))
    self.num_classes = num_classes
    self.class_cond = num_classes is not None
    self.logscale_factor = logscale_factor
    # One value per channel, broadcast over batch and remaining dimensions
    channel_shape = (1, self.shape[0]) + (1,) * (self.n_dim - 1)
    self.loc = nn.Parameter(torch.zeros(channel_shape))
    self.loc_logs = nn.Parameter(torch.zeros(channel_shape))
    self.log_scale = nn.Parameter(torch.zeros(channel_shape))
    self.log_scale_logs = nn.Parameter(torch.zeros(channel_shape))
    # Per-class, per-channel offsets, only for the class conditional case
    if self.class_cond:
        self.loc_cc = nn.Parameter(torch.zeros(self.num_classes, self.shape[0]))
        self.log_scale_cc = nn.Parameter(
            torch.zeros(self.num_classes, self.shape[0])
        )
    # Temperature parameter for annealed sampling
    self.temperature = None

Uniform

Bases: BaseDistribution

Multivariate uniform distribution

Source code in normflows/distributions/base.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class Uniform(BaseDistribution):
    """
    Multivariate uniform distribution
    """

    def __init__(self, shape, low=-1.0, high=1.0):
        """Constructor

        Args:
          shape: Tuple with shape of data, if int shape has one dimension
          low: Lower bound of uniform distribution
          high: Upper bound of uniform distribution
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        if isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.d = np.prod(shape)
        low = torch.tensor(low)
        high = torch.tensor(high)
        # torch.rand only supports floating dtypes, so integer bounds such as
        # low=0, high=1 are promoted to the default floating point dtype
        if not low.is_floating_point():
            low = low.to(torch.get_default_dtype())
        if not high.is_floating_point():
            high = high.to(torch.get_default_dtype())
        # Buffers follow ``.to(device)`` / dtype casts of the module, so samples
        # are created on the module's device; non-persistent keeps the
        # state dict unchanged
        self.register_buffer("low", low, persistent=False)
        self.register_buffer("high", high, persistent=False)
        self.log_prob_val = -self.d * np.log(self.high - self.low)

    def forward(self, num_samples=1, context=None):
        """Draw samples and their (constant) log probability; ``context`` is ignored."""
        eps = torch.rand(
            (num_samples,) + self.shape, dtype=self.low.dtype, device=self.low.device
        )
        z = self.low + (self.high - self.low) * eps
        log_p = self.log_prob_val * torch.ones(num_samples, device=self.low.device)
        return z, log_p

    def log_prob(self, z, context=None):
        """Log probability of each element of ``z``, ``-inf`` outside the support."""
        log_p = self.log_prob_val * torch.ones(z.shape[0], device=z.device)
        out_range = torch.logical_or(z < self.low, z > self.high)
        ind_inf = torch.any(torch.reshape(out_range, (z.shape[0], -1)), dim=-1)
        log_p[ind_inf] = -np.inf
        return log_p
__init__(shape, low=-1.0, high=1.0)

Constructor

Parameters:

Name Type Description Default
shape

Tuple with shape of data, if int shape has one dimension

required
low

Lower bound of uniform distribution

-1.0
high

Upper bound of uniform distribution

1.0
Source code in normflows/distributions/base.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def __init__(self, shape, low=-1.0, high=1.0):
    """Constructor

    Args:
      shape: Tuple with shape of data, if int shape has one dimension
      low: Lower bound of uniform distribution
      high: Upper bound of uniform distribution
    """
    super().__init__()
    if isinstance(shape, int):
        shape = (shape,)
    if isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    self.d = np.prod(shape)
    low = torch.tensor(low)
    high = torch.tensor(high)
    # torch.rand only supports floating dtypes, so integer bounds such as
    # low=0, high=1 are promoted to the default floating point dtype
    if not low.is_floating_point():
        low = low.to(torch.get_default_dtype())
    if not high.is_floating_point():
        high = high.to(torch.get_default_dtype())
    # Buffers follow ``.to(device)`` / dtype casts of the module;
    # non-persistent keeps the state dict unchanged
    self.register_buffer("low", low, persistent=False)
    self.register_buffer("high", high, persistent=False)
    self.log_prob_val = -self.d * np.log(self.high - self.low)

UniformGaussian

Bases: BaseDistribution

Distribution of a 1D random variable with some entries having a uniform and others a Gaussian distribution

Source code in normflows/distributions/base.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class UniformGaussian(BaseDistribution):
    """
    Distribution of a 1D random variable with some entries having a uniform and
    others a Gaussian distribution
    """

    def __init__(self, ndim, ind, scale=None):
        """Constructor

        Args:
          ndim: Int, number of dimensions
          ind: Int or iterable, indices of uniformly distributed entries
          scale: Iterable, standard deviation of Gaussian or width of uniform
            distribution; defaults to ones
        """
        super().__init__()
        self.ndim = ndim
        if isinstance(ind, int):
            ind = [ind]

        # Indices of the uniformly distributed entries (use the public
        # Tensor.to rather than the private torch._cast_Long)
        if torch.is_tensor(ind):
            ind = ind.to(torch.long)
        else:
            ind = torch.tensor(ind, dtype=torch.long)
        self.register_buffer("ind", ind)

        # Indices of the Gaussian entries: every dimension not listed in ind
        uniform_ind = set(ind.tolist())
        gauss_ind = [i for i in range(self.ndim) if i not in uniform_ind]
        self.register_buffer("ind_", torch.tensor(gauss_ind, dtype=torch.long))

        # sample() concatenates uniform entries first, then Gaussian ones;
        # inv_perm maps that order back to the original dimension order.
        # For a permutation of range(ndim), argsort yields its inverse.
        perm = torch.cat((self.ind, self.ind_))
        self.register_buffer("inv_perm", torch.argsort(perm))

        if scale is None:
            self.register_buffer("scale", torch.ones(self.ndim))
        else:
            # register_buffer only accepts tensors, so convert other iterables;
            # integer scales would break sampling with dtype=self.scale.dtype
            scale = torch.as_tensor(scale)
            if not scale.is_floating_point():
                scale = scale.to(torch.get_default_dtype())
            self.register_buffer("scale", scale)

    def forward(self, num_samples=1, context=None):
        """Draw samples and evaluate their log probability

        Args:
          num_samples: Number of samples to draw
          context: Unused

        Returns:
          Samples, log probability
        """
        z = self.sample(num_samples)
        return z, self.log_prob(z)

    def sample(self, num_samples=1, context=None):
        """Draw samples

        Args:
          num_samples: Number of samples to draw
          context: Unused

        Returns:
          Samples of shape (num_samples, ndim)
        """
        # Uniform noise centred at zero, i.e. in [-0.5, 0.5) before scaling
        eps_u = (
            torch.rand(
                (num_samples, len(self.ind)),
                dtype=self.scale.dtype,
                device=self.scale.device,
            )
            - 0.5
        )
        eps_g = torch.randn(
            (num_samples, len(self.ind_)),
            dtype=self.scale.dtype,
            device=self.scale.device,
        )
        z = torch.cat((eps_u, eps_g), -1)
        z = z[..., self.inv_perm]
        return self.scale * z

    def log_prob(self, z, context=None):
        """Log probability of z

        Args:
          z: Batch of points, first dimension is batch size
          context: Unused

        Returns:
          Log probability per batch element
        """
        # Uniform entries contribute log(1 / scale). NOTE(review): whether they
        # lie inside the support [-scale / 2, scale / 2] is not checked here.
        log_p_u = torch.broadcast_to(-torch.log(self.scale[self.ind]), (len(z), -1))
        log_p_g = (
            -0.5 * np.log(2 * np.pi)
            - torch.log(self.scale[self.ind_])
            - 0.5 * torch.pow(z[..., self.ind_] / self.scale[self.ind_], 2)
        )
        return torch.sum(log_p_u, -1) + torch.sum(log_p_g, -1)
__init__(ndim, ind, scale=None)

Constructor

Parameters:

Name Type Description Default
ndim

Int, number of dimensions

required
ind

Iterable, indices of uniformly distributed entries

required
scale

Iterable, standard deviation of Gaussian or width of uniform distribution

None
Source code in normflows/distributions/base.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def __init__(self, ndim, ind, scale=None):
    """Constructor

    Args:
      ndim: Int, number of dimensions
      ind: Int or iterable, indices of uniformly distributed entries
      scale: Iterable, standard deviation of Gaussian or width of uniform
        distribution; defaults to ones
    """
    super().__init__()
    self.ndim = ndim
    if isinstance(ind, int):
        ind = [ind]

    # Indices of the uniformly distributed entries (use the public
    # Tensor.to rather than the private torch._cast_Long)
    if torch.is_tensor(ind):
        ind = ind.to(torch.long)
    else:
        ind = torch.tensor(ind, dtype=torch.long)
    self.register_buffer("ind", ind)

    # Indices of the Gaussian entries: every dimension not listed in ind
    uniform_ind = set(ind.tolist())
    gauss_ind = [i for i in range(self.ndim) if i not in uniform_ind]
    self.register_buffer("ind_", torch.tensor(gauss_ind, dtype=torch.long))

    # Samples are built as [uniform entries, Gaussian entries]; inv_perm maps
    # that order back to the original dimension order. For a permutation of
    # range(ndim), argsort yields its inverse.
    perm = torch.cat((self.ind, self.ind_))
    self.register_buffer("inv_perm", torch.argsort(perm))

    if scale is None:
        self.register_buffer("scale", torch.ones(self.ndim))
    else:
        # register_buffer only accepts tensors, so convert other iterables;
        # integer scales are promoted to the default floating point dtype
        scale = torch.as_tensor(scale)
        if not scale.is_floating_point():
            scale = scale.to(torch.get_default_dtype())
        self.register_buffer("scale", scale)

decoder

BaseDecoder

Bases: Module

Source code in normflows/distributions/decoder.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class BaseDecoder(nn.Module):
    """Interface for decoders mapping a latent variable z to a distribution over x.

    Subclasses must override both ``forward`` and ``log_prob``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z):
        """Decodes z to x

        Args:
          z: latent variable

        Returns:
          x, std of x
        """
        raise NotImplementedError()

    def log_prob(self, x, z):
        """Log probability

        Args:
          x: observable
          z: latent variable

        Returns:
          log(p) of x given z
        """
        raise NotImplementedError()
forward(z)

Decodes z to x

Parameters:

Name Type Description Default
z

latent variable

required

Returns:

Type Description

x, std of x

Source code in normflows/distributions/decoder.py
10
11
12
13
14
15
16
17
18
19
def forward(self, z):
    """Decodes z to x (abstract; to be implemented by subclasses)

    Args:
      z: latent variable

    Returns:
      x, std of x
    """
    raise NotImplementedError()
log_prob(x, z)

Log probability

Parameters:

Name Type Description Default
x

observable

required
z

latent variable

required

Returns:

Type Description

log(p) of x given z

Source code in normflows/distributions/decoder.py
21
22
23
24
25
26
27
28
29
30
31
def log_prob(self, x, z):
    """Log probability of x given z (abstract; to be implemented by subclasses)

    Args:
      x: observable
      z: latent variable

    Returns:
      log(p) of x given z
    """
    raise NotImplementedError()

NNBernoulliDecoder

Bases: BaseDecoder

BaseDecoder representing a Bernoulli distribution with mean parametrized by a NN

Source code in normflows/distributions/decoder.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class NNBernoulliDecoder(BaseDecoder):
    """
    BaseDecoder representing a Bernoulli distribution with mean parametrized by a NN
    """

    def __init__(self, net):
        """Constructor

        Args:
          net: neural network parametrizing the Bernoulli mean (mean = sigmoid(nn_out))
        """
        super().__init__()
        self.net = net

    def forward(self, z):
        """Decodes z to the Bernoulli mean

        Args:
          z: latent variable

        Returns:
          Mean of the Bernoulli distribution, i.e. sigmoid of the network output
        """
        mean = torch.sigmoid(self.net(z))
        return mean

    def log_prob(self, x, z):
        """Log probability of x given z

        Args:
          x: observable, first dimension is batch size
          z: latent variable; its batch size may be a multiple of that of x

        Returns:
          log(p) of x given z, summed over all non-batch dimensions
        """
        score = self.net(z)
        if len(z) > len(x):
            # Several latent samples per data point: repeat each x consecutively
            # so it lines up with its block of samples in z
            x = x.repeat_interleave(len(z) // len(x), dim=0)
        # logsigmoid is the numerically stable form of log(sigmoid(.))
        log_sig = torch.nn.functional.logsigmoid
        log_p = torch.sum(
            x * log_sig(score) + (1 - x) * log_sig(-score), list(range(1, x.dim()))
        )
        return log_p
__init__(net)

Constructor

Parameters:

Name Type Description Default
net

neural network parametrizing the mean of the Bernoulli distribution (mean = sigmoid(nn_out))

required
Source code in normflows/distributions/decoder.py
78
79
80
81
82
83
84
85
def __init__(self, net):
    """Constructor

    Args:
      net: neural network whose output, passed through a sigmoid, is the
        Bernoulli mean (mean = sigmoid(nn_out))
    """
    super().__init__()
    # Network mapping latent variables to Bernoulli logits
    self.net = net

NNDiagGaussianDecoder

Bases: BaseDecoder

BaseDecoder representing a diagonal Gaussian distribution with mean and std parametrized by a NN

Source code in normflows/distributions/decoder.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class NNDiagGaussianDecoder(BaseDecoder):
    """
    BaseDecoder representing a diagonal Gaussian distribution with mean and std parametrized by a NN
    """

    def __init__(self, net):
        """Constructor

        Args:
          net: neural network parametrizing mean and standard deviation of diagonal
            Gaussian; along dimension 1, the first half of its output is the mean
            and the second half the log-variance
        """
        super().__init__()
        self.net = net

    def forward(self, z):
        """Decodes z to the mean and standard deviation of x

        Args:
          z: latent variable

        Returns:
          mean of x, std of x
        """
        mean_std = self.net(z)
        n_hidden = mean_std.size()[1] // 2
        mean = mean_std[:, :n_hidden, ...]
        # Second half of the network output is the log-variance
        std = torch.exp(0.5 * mean_std[:, n_hidden:, ...])
        return mean, std

    def log_prob(self, x, z):
        """Log probability of x given z

        Args:
          x: observable, first dimension is batch size
          z: latent variable; its batch size may be a multiple of that of x

        Returns:
          log(p) of x given z, summed over all non-batch dimensions of x
        """
        mean_std = self.net(z)
        n_hidden = mean_std.size()[1] // 2
        mean = mean_std[:, :n_hidden, ...]
        var = torch.exp(mean_std[:, n_hidden:, ...])
        if len(z) > len(x):
            # Several latent samples per data point: repeat each x consecutively
            # so it lines up with its block of samples in z
            x = x.repeat_interleave(len(z) // len(x), dim=0)
        # The density is over x, so normalize and reduce over the event
        # dimensions of x rather than those of z, whose shape may differ
        event_dims = list(range(1, x.dim()))
        log_p = -0.5 * np.prod(x.size()[1:]) * np.log(2 * np.pi) - 0.5 * torch.sum(
            torch.log(var) + (x - mean) ** 2 / var, event_dims
        )
        return log_p
__init__(net)

Constructor

Parameters:

Name Type Description Default
net

neural network parametrizing mean and standard deviation of diagonal Gaussian

required
Source code in normflows/distributions/decoder.py
39
40
41
42
43
44
45
46
def __init__(self, net):
    """Constructor

    Args:
      net: neural network producing the mean and (log-)variance of the
        diagonal Gaussian over x
    """
    super().__init__()
    # Network mapping latent variables to Gaussian parameters
    self.net = net

distribution_test

DistributionTest

Bases: TestCase

Generic test case for distribution modules

Source code in normflows/distributions/distribution_test.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class DistributionTest(unittest.TestCase):
    """
    Generic test case for distribution modules
    """

    def assertClose(self, actual, expected, atol=None, rtol=None):
        """Assert that two tensors are close, see torch.testing.assert_close"""
        assert_close(actual, expected, atol=atol, rtol=rtol)

    def checkForward(self, distribution, num_samples=1, **kwargs):
        """Run a forward pass and check the dtype and shape of its outputs

        Returns:
          Samples, log probability
        """
        outputs, log_p = distribution(num_samples, **kwargs)
        # TestCase assertion methods are used instead of bare asserts: they are
        # not stripped under ``python -O`` and produce informative messages
        self.assertEqual(outputs.dtype, log_p.dtype)
        self.assertEqual(log_p.shape[0], num_samples)
        self.assertEqual(outputs.shape[0], num_samples)
        # Samples carry event dimensions beyond those of the log probability
        self.assertGreater(outputs.dim(), log_p.dim())
        return outputs, log_p

    def checkLogProb(self, distribution, inputs, **kwargs):
        """Evaluate log_prob and check the dtype and batch size of the result

        Returns:
          Log probability of inputs
        """
        log_p = distribution.log_prob(inputs, **kwargs)
        self.assertEqual(log_p.dtype, inputs.dtype)
        self.assertEqual(log_p.shape[0], inputs.shape[0])
        return log_p

    def checkSample(self, distribution, num_samples=1, **kwargs):
        """Draw samples and check their batch size and rank

        Returns:
          Samples
        """
        outputs = distribution.sample(num_samples, **kwargs)
        self.assertEqual(outputs.shape[0], num_samples)
        self.assertGreater(outputs.dim(), 1)
        return outputs

    def checkForwardLogProb(self, distribution, num_samples=1, atol=None, rtol=None, **kwargs):
        """Check that forward and log_prob agree on the log probability of samples"""
        outputs, log_p = self.checkForward(distribution, num_samples, **kwargs)
        log_p_ = self.checkLogProb(distribution, outputs, **kwargs)
        self.assertClose(log_p_, log_p, atol, rtol)

encoder

BaseEncoder

Bases: Module

Base distribution of a flow-based variational autoencoder. The parameters of the distribution depend on the target variable x.

Source code in normflows/distributions/encoder.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class BaseEncoder(nn.Module):
    """
    Base distribution of a flow-based variational autoencoder.
    The parameters of the distribution depend on the target variable x.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, num_samples=1):
        """Draw samples of z conditioned on x (abstract)

        Args:
          x: Variable to condition on, first dimension is batch size
          num_samples: number of samples to draw per element of mini-batch

        Returns:
          sample of z for x, log probability for sample
        """
        raise NotImplementedError()

    def log_prob(self, z, x):
        """Log probability of z given x (abstract)

        Args:
          z: Primary random variable, first dimension is batch size
          x: Variable to condition on, first dimension is batch size

        Returns:
          log probability of z given x
        """
        raise NotImplementedError()
forward(x, num_samples=1)

Parameters:

Name Type Description Default
x

Variable to condition on, first dimension is batch size

required
num_samples

number of samples to draw per element of mini-batch

1

Returns sample of z for x, log probability for sample

Source code in normflows/distributions/encoder.py
15
16
17
18
19
20
21
22
23
24
def forward(self, x, num_samples=1):
    """Draw samples of z conditioned on x (abstract; implemented by subclasses)

    Args:
      x: Variable to condition on, first dimension is batch size
      num_samples: number of samples to draw per element of mini-batch

    Returns:
      sample of z for x, log probability for sample
    """
    raise NotImplementedError()
log_prob(z, x)

Parameters:

Name Type Description Default
z

Primary random variable, first dimension is batch size

required
x

Variable to condition on, first dimension is batch size

required

Returns:

Type Description

log probability of z given x

Source code in normflows/distributions/encoder.py
26
27
28
29
30
31
32
33
34
35
36
def log_prob(self, z, x):
    """Log probability of z given x (abstract; implemented by subclasses)

    Args:
      z: Primary random variable, first dimension is batch size
      x: Variable to condition on, first dimension is batch size

    Returns:
      log probability of z given x
    """
    raise NotImplementedError()

ConstDiagGaussian

Bases: BaseEncoder

Source code in normflows/distributions/encoder.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class ConstDiagGaussian(BaseEncoder):
    def __init__(self, loc, scale):
        """Multivariate Gaussian distribution with diagonal covariance and parameters being constant wrt x

        Args:
          loc: mean vector of the distribution
          scale: vector of the standard deviations on the diagonal of the covariance matrix
        """
        super().__init__()
        self.d = len(loc)
        if not torch.is_tensor(loc):
            loc = torch.tensor(loc)
        if not torch.is_tensor(scale):
            scale = torch.tensor(scale)
        # Leading singleton dims let loc broadcast against (batch, samples, d)
        self.loc = nn.Parameter(loc.reshape((1, 1, self.d)))
        self.scale = nn.Parameter(scale)

    def forward(self, x=None, num_samples=1):
        """
        Args:
          x: Variable to condition on, will only be used to determine the batch size;
            if None, a batch size of 1 is used
          num_samples: number of samples to draw per element of mini-batch

        Returns:
          sample of z for x, log probability for sample
        """
        batch_size = 1 if x is None else len(x)
        # Take device and dtype from the parameters: x is optional and may be
        # None (the default), in which case x.device would raise
        eps = torch.randn(
            (batch_size, num_samples, self.d),
            dtype=self.loc.dtype,
            device=self.loc.device,
        )
        z = self.loc + self.scale * eps
        log_q = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            torch.log(self.scale) + 0.5 * torch.pow(eps, 2), 2
        )
        return z, log_q

    def log_prob(self, z, x):
        """
        Args:
          z: Primary random variable, first dimension is batch dimension
          x: Variable to condition on, first dimension is batch dimension

        Returns:
          log probability of z given x
        """
        # Promote z to shape (batch, samples, d)
        if z.dim() == 1:
            z = z.unsqueeze(0)
        if z.dim() == 2:
            z = z.unsqueeze(0)
        log_q = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            torch.log(self.scale) + 0.5 * ((z - self.loc) / self.scale) ** 2, 2
        )
        return log_q
__init__(loc, scale)

Multivariate Gaussian distribution with diagonal covariance and parameters being constant wrt x

Parameters:

Name Type Description Default
loc

mean vector of the distribution

required
scale

vector of the standard deviations on the diagonal of the covariance matrix

required
Source code in normflows/distributions/encoder.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(self, loc, scale):
    """Diagonal-covariance Gaussian whose parameters do not depend on x

    Args:
      loc: mean vector of the distribution
      scale: vector of the standard deviations on the diagonal of the covariance matrix
    """
    super().__init__()
    self.d = len(loc)
    loc = loc if torch.is_tensor(loc) else torch.tensor(loc)
    scale = scale if torch.is_tensor(scale) else torch.tensor(scale)
    # Mean gets two leading singleton dims so it broadcasts against
    # samples of shape (batch_size, num_samples, d)
    self.loc = nn.Parameter(loc.reshape((1, 1, self.d)))
    self.scale = nn.Parameter(scale)
forward(x=None, num_samples=1)

Parameters:

Name Type Description Default
x

Variable to condition on, will only be used to determine the batch size

None
num_samples

number of samples to draw per element of mini-batch

1

Returns:

Type Description

sample of z for x, log probability for sample

Source code in normflows/distributions/encoder.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def forward(self, x=None, num_samples=1):
    """
    Args:
      x: Variable to condition on, will only be used to determine the batch size;
        if None, a batch size of 1 is used
      num_samples: number of samples to draw per element of mini-batch

    Returns:
      sample of z for x, log probability for sample
    """
    batch_size = 1 if x is None else len(x)
    # Take device and dtype from the parameters: x is optional and may be
    # None (the default), in which case x.device would raise
    eps = torch.randn(
        (batch_size, num_samples, self.d),
        dtype=self.loc.dtype,
        device=self.loc.device,
    )
    z = self.loc + self.scale * eps
    log_q = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
        torch.log(self.scale) + 0.5 * torch.pow(eps, 2), 2
    )
    return z, log_q
log_prob(z, x)

Parameters:

Name Type Description Default
z

Primary random variable, first dimension is batch dimension

required
x

Variable to condition on, first dimension is batch dimension

required

Returns:

Type Description

log probability of z given x

Source code in normflows/distributions/encoder.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def log_prob(self, z, x):
    """Log density of z; x is ignored since the parameters are constant

    Args:
      z: Primary random variable, first dimension is batch dimension
      x: Variable to condition on, first dimension is batch dimension

    Returns:
      log probability of z given x
    """
    # Lift 1-D and 2-D inputs to (batch, samples, d) by prepending axes
    while z.dim() < 3:
        z = z.unsqueeze(0)
    standardized = (z - self.loc) / self.scale
    per_dim = torch.log(self.scale) + 0.5 * standardized ** 2
    log_q = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(per_dim, 2)
    return log_q

NNDiagGaussian

Bases: BaseEncoder

Diagonal Gaussian distribution with mean and variance determined by a neural network

Source code in normflows/distributions/encoder.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
class NNDiagGaussian(BaseEncoder):
    """
    Diagonal Gaussian distribution with mean and variance determined by a neural network
    """

    def __init__(self, net):
        """Constructor

        Args:
          net: net computing mean (first n / 2 outputs) and log-variance
            (second n / 2 outputs); the standard deviation is exp(0.5 * log-variance)
        """
        super().__init__()
        self.net = net

    def forward(self, x, num_samples=1):
        """
        Args:
          x: Variable to condition on
          num_samples: number of samples to draw per element of mini-batch

        Returns:
          sample of z for x, log probability for sample
        """
        batch_size = len(x)
        mean_std = self.net(x)
        n_hidden = mean_std.size()[1] // 2
        mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
        std = torch.exp(0.5 * mean_std[:, n_hidden : (2 * n_hidden), ...].unsqueeze(1))
        # Draw noise with the same dtype/device as the predicted mean
        eps = torch.randn(
            (batch_size, num_samples) + tuple(mean.size()[2:]),
            dtype=mean.dtype,
            device=mean.device,
        )
        z = mean + std * eps
        log_q = -0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(
            2 * np.pi
        ) - torch.sum(torch.log(std) + 0.5 * torch.pow(eps, 2), list(range(2, z.dim())))
        return z, log_q

    def log_prob(self, z, x):
        """

        Args:
          z: Primary random variable, first dimension is batch dimension
          x: Variable to condition on, first dimension is batch dimension

        Returns:
          log probability of z given x
        """
        # Promote z to shape (batch, samples, ...)
        if z.dim() == 1:
            z = z.unsqueeze(0)
        if z.dim() == 2:
            z = z.unsqueeze(0)
        mean_std = self.net(x)
        n_hidden = mean_std.size()[1] // 2
        mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
        var = torch.exp(mean_std[:, n_hidden : (2 * n_hidden), ...].unsqueeze(1))
        # Reduce over every event dimension (as forward does and as the
        # normalization constant assumes), not just dimension 2
        event_dims = list(range(2, z.dim()))
        log_q = -0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(
            2 * np.pi
        ) - 0.5 * torch.sum(torch.log(var) + (z - mean) ** 2 / var, event_dims)
        return log_q
__init__(net)

Constructor

Parameters:

Name Type Description Default
net

net computing the mean (first n / 2 outputs) and the log-variance (second n / 2 outputs); the standard deviation is exp(0.5 * log-variance)

required
Source code in normflows/distributions/encoder.py
135
136
137
138
139
140
141
142
def __init__(self, net):
    """Constructor

    Args:
      net: net computing the mean (first n / 2 outputs) and the log-variance
        (second n / 2 outputs) of the diagonal Gaussian
    """
    super().__init__()
    # Network mapping x to the Gaussian parameters of z
    self.net = net
forward(x, num_samples=1)

Parameters:

Name Type Description Default
x

Variable to condition on

required
num_samples

number of samples to draw per element of mini-batch

1

Returns:

Type Description

sample of z for x, log probability for sample

Source code in normflows/distributions/encoder.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def forward(self, x, num_samples=1):
    """
    Args:
      x: Variable to condition on
      num_samples: number of samples to draw per element of mini-batch

    Returns:
      sample of z for x, log probability for sample
    """
    batch_size = len(x)
    mean_std = self.net(x)
    n_hidden = mean_std.size()[1] // 2
    mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
    # Second half of the network output is the log-variance
    std = torch.exp(0.5 * mean_std[:, n_hidden : (2 * n_hidden), ...].unsqueeze(1))
    # Draw noise with the same dtype/device as the predicted mean
    eps = torch.randn(
        (batch_size, num_samples) + tuple(mean.size()[2:]),
        dtype=mean.dtype,
        device=mean.device,
    )
    z = mean + std * eps
    log_q = -0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(
        2 * np.pi
    ) - torch.sum(torch.log(std) + 0.5 * torch.pow(eps, 2), list(range(2, z.dim())))
    return z, log_q
log_prob(z, x)

Parameters:

Name Type Description Default
z

Primary random variable, first dimension is batch dimension

required
x

Variable to condition on, first dimension is batch dimension

required

Returns:

Type Description

log probability of z given x

Source code in normflows/distributions/encoder.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def log_prob(self, z, x):
    """

    Args:
      z: Primary random variable, first dimension is batch dimension
      x: Variable to condition on, first dimension is batch dimension

    Returns:
      log probability of z given x
    """
    # Promote z to shape (batch, samples, ...)
    if z.dim() == 1:
        z = z.unsqueeze(0)
    if z.dim() == 2:
        z = z.unsqueeze(0)
    mean_std = self.net(x)
    n_hidden = mean_std.size()[1] // 2
    mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
    var = torch.exp(mean_std[:, n_hidden : (2 * n_hidden), ...].unsqueeze(1))
    # Reduce over every event dimension (as forward does and as the
    # normalization constant assumes), not just dimension 2
    event_dims = list(range(2, z.dim()))
    log_q = -0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(
        2 * np.pi
    ) - 0.5 * torch.sum(torch.log(var) + (z - mean) ** 2 / var, event_dims)
    return log_q

linear_interpolation

LinearInterpolation

Linear interpolation of two distributions in the log space

Source code in normflows/distributions/linear_interpolation.py
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class LinearInterpolation:
    """
    Linear interpolation of two distributions in the log space
    """

    def __init__(self, dist1, dist2, alpha):
        """Constructor

        Interpolation parameter alpha:

        ```
        log_p = alpha * log_p_1 + (1 - alpha) * log_p_2
        ```

        Args:
          dist1: First distribution, weighted by alpha
          dist2: Second distribution, weighted by 1 - alpha
          alpha: Interpolation parameter
        """
        self.dist1, self.dist2 = dist1, dist2
        self.alpha = alpha

    def log_prob(self, z):
        """Interpolated log probability of z

        Args:
          z: Points to evaluate

        Returns:
          alpha * dist1.log_prob(z) + (1 - alpha) * dist2.log_prob(z)
        """
        log_p_1 = self.dist1.log_prob(z)
        log_p_2 = self.dist2.log_prob(z)
        return self.alpha * log_p_1 + (1 - self.alpha) * log_p_2
__init__(dist1, dist2, alpha)

Constructor

Interpolation parameter alpha:

log_p = alpha * log_p_1 + (1 - alpha) * log_p_2

Parameters:

Name Type Description Default
dist1

First distribution

required
dist2

Second distribution

required
alpha

Interpolation parameter

required
Source code in normflows/distributions/linear_interpolation.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
def __init__(self, dist1, dist2, alpha):
    """Constructor

    Interpolation parameter alpha:

    ```
    log_p = alpha * log_p_1 + (1 - alpha) * log_p_2
    ```

    Args:
      dist1: First distribution, weighted by alpha
      dist2: Second distribution, weighted by 1 - alpha
      alpha: Interpolation parameter
    """
    self.dist1, self.dist2 = dist1, dist2
    self.alpha = alpha

mh_proposal

DiagGaussianProposal

Bases: MHProposal

Diagonal Gaussian distribution with previous value as mean as a proposal for Metropolis Hastings algorithm

Source code in normflows/distributions/mh_proposal.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class DiagGaussianProposal(MHProposal):
    """
    Diagonal Gaussian distribution with previous value as mean
    as a proposal for Metropolis Hastings algorithm
    """

    def __init__(self, shape, scale):
        """Constructor

        Args:
          shape: Shape of variables to sample; an int denotes a 1D shape
          scale: Standard deviation of distribution
        """
        super().__init__()
        # Normalize shape to a tuple so (num_samples,) + self.shape works
        if isinstance(shape, int):
            shape = (shape,)
        elif isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        # Copy tensors explicitly; torch.tensor(tensor) emits a UserWarning
        if torch.is_tensor(scale):
            self.scale_cpu = scale.detach().clone()
        else:
            self.scale_cpu = torch.tensor(scale)
        self.register_buffer("scale", self.scale_cpu.unsqueeze(0))

    def sample(self, z):
        """Propose new values by adding Gaussian noise to z

        Args:
          z: Previous samples, first dimension is batch size

        Returns:
          Proposed samples
        """
        num_samples = len(z)
        eps = torch.randn((num_samples,) + self.shape, dtype=z.dtype, device=z.device)
        z_ = eps * self.scale + z
        return z_

    def log_prob(self, z_, z):
        """Log probability of proposing z_ given z

        Args:
          z_: Potential new sample
          z: Previous sample

        Returns:
          Log probability of proposal distribution
        """
        log_p = -0.5 * np.prod(self.shape) * np.log(2 * np.pi) - torch.sum(
            torch.log(self.scale) + 0.5 * torch.pow((z_ - z) / self.scale, 2),
            list(range(1, z.dim())),
        )
        return log_p

    def forward(self, z):
        """Draw proposals and compute the log probability ratio

        Args:
          z: Previous samples

        Returns:
          Proposal, difference of log probability ratio
        """
        z_ = self.sample(z)
        # The proposal is symmetric (Gaussian centred at the previous value),
        # so log(p(z | z_)) - log(p(z_ | z)) is zero
        log_p_diff = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        return z_, log_p_diff
__init__(shape, scale)

Constructor

Parameters:

Name Type Description Default
shape

Shape of variables to sample

required
scale

Standard deviation of distribution

required
Source code in normflows/distributions/mh_proposal.py
53
54
55
56
57
58
59
60
61
62
63
def __init__(self, shape, scale):
    """Constructor

    Args:
      shape: Shape of variables to sample; an int denotes a 1D shape
      scale: Standard deviation of distribution
    """
    super().__init__()
    # Normalize shape to a tuple so (num_samples,) + self.shape works
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, list):
        shape = tuple(shape)
    self.shape = shape
    # Copy tensors explicitly; torch.tensor(tensor) emits a UserWarning
    if torch.is_tensor(scale):
        self.scale_cpu = scale.detach().clone()
    else:
        self.scale_cpu = torch.tensor(scale)
    self.register_buffer("scale", self.scale_cpu.unsqueeze(0))

MHProposal

Bases: Module

Proposal distribution for the Metropolis Hastings algorithm

Source code in normflows/distributions/mh_proposal.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class MHProposal(nn.Module):
    """
    Proposal distribution for the Metropolis Hastings algorithm.

    Subclasses implement ``sample``, ``log_prob`` and ``forward``.
    """

    def __init__(self):
        super().__init__()

    def sample(self, z):
        """
        Sample new value based on previous z
        """
        raise NotImplementedError()

    def log_prob(self, z_, z):
        """Log probability of proposing z_ from z

        Args:
          z_: Potential new sample
          z: Previous sample

        Returns:
          Log probability of proposal distribution
        """
        raise NotImplementedError()

    def forward(self, z):
        """Draw samples given z and compute log probability difference

        ```
        log(p(z | z_new)) - log(p(z_new | z))
        ```

        Args:
          z: Previous samples

        Returns:
          Proposal, difference of log probability ratio
        """
        raise NotImplementedError()
forward(z)

Draw samples given z and compute log probability difference

log(p(z | z_new)) - log(p(z_new | z))

Parameters:

Name Type Description Default
z

Previous samples

required

Returns:

Type Description

Proposal, difference of log probability ratio

Source code in normflows/distributions/mh_proposal.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def forward(self, z):
    """Draw samples given z and compute log probability difference (abstract)

    ```
    log(p(z | z_new)) - log(p(z_new | z))
    ```

    Args:
      z: Previous samples

    Returns:
      Proposal, difference of log probability ratio
    """
    raise NotImplementedError()
log_prob(z_, z)

Parameters:

Name Type Description Default
z_

Potential new sample

required
z

Previous sample

required

Returns:

Type Description

Log probability of proposal distribution

Source code in normflows/distributions/mh_proposal.py
20
21
22
23
24
25
26
27
28
29
def log_prob(self, z_, z):
    """Log probability of proposing z_ from z (abstract)

    Args:
      z_: Potential new sample
      z: Previous sample

    Returns:
      Log probability of proposal distribution
    """
    raise NotImplementedError()
sample(z)

Sample new value based on previous z

Source code in normflows/distributions/mh_proposal.py
14
15
16
17
18
def sample(self, z):
    """Sample new value based on previous z (abstract; implemented by subclasses)"""
    raise NotImplementedError()

prior

ImagePrior

Bases: Module

Intensities of an image determine probability density of prior

Source code in normflows/distributions/prior.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class ImagePrior(nn.Module):
    """
    Intensities of an image determine probability density of prior

    The image is stretched over the rectangle ``x_range`` x ``y_range`` so
    that every pixel covers an equally sized cell of it.
    """

    def __init__(self, image, x_range=(-3, 3), y_range=(-3, 3), eps=1.0e-10):
        """Constructor

        Args:
          image: image as np matrix
          x_range: x range to position image at, as ``(min, max)``
          y_range: y range to position image at, as ``(min, max)``
          eps: small value to add to image to avoid log(0) problems
        """
        super().__init__()
        # Flip the rows and transpose: image_[i, j] is input pixel
        # [H - 1 - j, i], so the first index follows x and the second y
        # (with the usual top-row-first layout, y increases upwards).
        image_ = np.flip(image, 0).transpose() + eps
        # Intensities scaled to a maximum of 1 so rejection_sampling can
        # compare them directly against uniform random numbers.
        self.image_cpu = torch.tensor(image_ / np.max(image_))
        self.image_size_cpu = self.image_cpu.size()
        self.x_range = torch.tensor(x_range)
        self.y_range = torch.tensor(y_range)

        self.register_buffer("image", self.image_cpu)
        self.register_buffer(
            "image_size", torch.tensor(self.image_size_cpu).unsqueeze(0)
        )
        # Log of the per-pixel probability mass
        self.register_buffer(
            "density", torch.log(self.image_cpu / torch.sum(self.image_cpu))
        )
        self.register_buffer(
            "scale",
            torch.tensor(
                [[self.x_range[1] - self.x_range[0], self.y_range[1] - self.y_range[0]]]
            ),
        )
        self.register_buffer(
            "shift", torch.tensor([[self.x_range[0], self.y_range[0]]])
        )

    def log_prob(self, z):
        """
        Args:
          z: batch of latent variables, shape ``(num_points, 2)``

        Returns:
          log probability mass of the pixel each point falls into; points
          outside the rectangle are clamped onto its border
        """
        z_ = torch.clamp((z - self.shift) / self.scale, max=1, min=0)
        # Pixel k covers [k / N, (k + 1) / N) in unit coordinates; the upper
        # border u == 1 belongs to the last pixel.
        ind = torch.minimum((z_ * self.image_size).long(), self.image_size - 1)
        return self.density[ind[:, 0], ind[:, 1]]

    def rejection_sampling(self, num_steps=1):
        """Perform rejection sampling on image distribution

        Args:
         num_steps: Number of candidate points to propose; only the
           accepted subset is returned

        Returns:
          Accepted samples
        """
        z_ = torch.rand(
            (num_steps, 2), dtype=self.image.dtype, device=self.image.device
        )
        prob = torch.rand(num_steps, dtype=self.image.dtype, device=self.image.device)
        # Same pixel mapping as log_prob, so every pixel (including the last
        # row and column) can be proposed.
        ind = torch.minimum((z_ * self.image_size).long(), self.image_size - 1)
        intensity = self.image[ind[:, 0], ind[:, 1]]
        accept = intensity > prob
        z = z_[accept, :] * self.scale + self.shift
        return z

    def sample(self, num_samples=1):
        """Sample from image distribution through rejection sampling

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples
        """
        z = torch.ones((0, 2), dtype=self.image.dtype, device=self.image.device)
        while len(z) < num_samples:
            z_ = self.rejection_sampling(num_samples)
            ind = min(len(z_), num_samples - len(z))
            z = torch.cat([z, z_[:ind, :]], 0)
        return z
__init__(image, x_range=[-3, 3], y_range=[-3, 3], eps=1e-10)

Constructor

Parameters:

Name Type Description Default
image

image as np matrix

required
x_range

x range to position image at

[-3, 3]
y_range

y range to position image at

[-3, 3]
eps

small value to add to image to avoid log(0) problems

1e-10
Source code in normflows/distributions/prior.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __init__(self, image, x_range=(-3, 3), y_range=(-3, 3), eps=1.0e-10):
    """Constructor

    Args:
      image: image as np matrix
      x_range: x range to position image at, as ``(min, max)``
      y_range: y range to position image at, as ``(min, max)``
      eps: small value to add to image to avoid log(0) problems
    """
    super().__init__()
    # Flip the rows and transpose: image_[i, j] is input pixel
    # [H - 1 - j, i], so the first index follows x and the second y
    # (with the usual top-row-first layout, y increases upwards).
    image_ = np.flip(image, 0).transpose() + eps
    # Intensities scaled to a maximum of 1 so they can be compared directly
    # against uniform random numbers during rejection sampling.
    self.image_cpu = torch.tensor(image_ / np.max(image_))
    self.image_size_cpu = self.image_cpu.size()
    self.x_range = torch.tensor(x_range)
    self.y_range = torch.tensor(y_range)

    self.register_buffer("image", self.image_cpu)
    self.register_buffer(
        "image_size", torch.tensor(self.image_size_cpu).unsqueeze(0)
    )
    # Log of the per-pixel probability mass
    self.register_buffer(
        "density", torch.log(self.image_cpu / torch.sum(self.image_cpu))
    )
    self.register_buffer(
        "scale",
        torch.tensor(
            [[self.x_range[1] - self.x_range[0], self.y_range[1] - self.y_range[0]]]
        ),
    )
    self.register_buffer(
        "shift", torch.tensor([[self.x_range[0], self.y_range[0]]])
    )
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
59
60
61
62
63
64
65
66
67
68
69
def log_prob(self, z):
    """
    Args:
      z: batch of latent variables, shape ``(num_points, 2)``

    Returns:
      log probability mass of the pixel each point falls into; points
      outside the rectangle are clamped onto its border
    """
    z_ = torch.clamp((z - self.shift) / self.scale, max=1, min=0)
    # Pixel k covers [k / N, (k + 1) / N) in unit coordinates; the upper
    # border u == 1 belongs to the last pixel.
    ind = torch.minimum((z_ * self.image_size).long(), self.image_size - 1)
    return self.density[ind[:, 0], ind[:, 1]]
rejection_sampling(num_steps=1)

Perform rejection sampling on image distribution

Parameters:

Name Type Description Default
num_steps

Number of rejection sampling steps to perform

1

Returns:

Type Description

Accepted samples

Source code in normflows/distributions/prior.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def rejection_sampling(self, num_steps=1):
    """Perform rejection sampling on image distribution

    Args:
     num_steps: Number of candidate points to propose; only the accepted
       subset is returned

    Returns:
      Accepted samples
    """
    z_ = torch.rand(
        (num_steps, 2), dtype=self.image.dtype, device=self.image.device
    )
    prob = torch.rand(num_steps, dtype=self.image.dtype, device=self.image.device)
    # Pixel k covers [k / N, (k + 1) / N) in unit coordinates, so every pixel
    # (including the last row and column) can be proposed.
    ind = torch.minimum((z_ * self.image_size).long(), self.image_size - 1)
    intensity = self.image[ind[:, 0], ind[:, 1]]
    accept = intensity > prob
    z = z_[accept, :] * self.scale + self.shift
    return z
sample(num_samples=1)

Sample from image distribution through rejection sampling

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1

Returns:

Type Description

Samples

Source code in normflows/distributions/prior.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def sample(self, num_samples=1):
    """Collect ``num_samples`` points by repeated rejection sampling.

    Args:
      num_samples: Number of samples to draw

    Returns:
      Samples
    """
    pieces = [torch.ones((0, 2), dtype=self.image.dtype, device=self.image.device)]
    collected = 0
    while collected < num_samples:
        # Each round proposes num_samples candidates; keep only what is still missing.
        accepted = self.rejection_sampling(num_samples)
        take = min(len(accepted), num_samples - collected)
        pieces.append(accepted[:take, :])
        collected += take
    return torch.cat(pieces, 0)

PriorDistribution

Source code in normflows/distributions/prior.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
class PriorDistribution:
    """Abstract interface for priors that expose a ``log_prob`` method."""

    def __init__(self):
        # The base class cannot be instantiated; concrete priors define
        # their own constructor.
        raise NotImplementedError

    def log_prob(self, z):
        """Evaluate the log probability of ``z``.

        Args:
         z: value or batch of latent variable

        Returns:
          log probability of the distribution for z

        Raises:
          NotImplementedError: always; subclasses override this.
        """
        raise NotImplementedError
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
10
11
12
13
14
15
16
17
18
def log_prob(self, z):
    """Evaluate the log probability of ``z``.

    Args:
     z: value or batch of latent variable

    Returns:
      log probability of the distribution for z

    Raises:
      NotImplementedError: always; subclasses override this.
    """
    raise NotImplementedError

Sinusoidal

Bases: PriorDistribution

Source code in normflows/distributions/prior.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class Sinusoidal(PriorDistribution):
    def __init__(self, scale, period):
        """Distribution 2d with sinusoidal density
        given by

        ```
        w_1(z) = sin(2*pi / period * z[0])
        log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2
                 - 1/2 * (norm_4(z) / (20 * scale)) ** 4
        ```

        where ``norm_4`` is the 4-norm of ``z``. The second term is a wide
        envelope that makes the density integrable; the log probability is
        not normalized.

        Args:
          scale: scale of the distribution, see formula
          period: period of the sinusoid
        """
        self.scale = scale
        self.period = period

    def log_prob(self, z):
        """
        Unnormalized log probability

        ```
        log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2
                 - 1/2 * (norm_4(z) / (20 * scale)) ** 4
        w_1(z) = sin(2*pi / period * z[0])
        ```

        Args:
          z: value or batch of latent variable; the two coordinates are
            taken from the last dimension

        Returns:
          log probability of the distribution for z
        """
        # Move the coordinate axis (last) to the front so z_[0] / z_[1]
        # select the two coordinates.
        if z.dim() > 1:
            z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
        else:
            z_ = z

        w_1 = torch.sin(2 * np.pi / self.period * z_[0])
        log_prob = (
            -0.5 * ((z_[1] - w_1) / self.scale) ** 2
            - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
        )  # envelope term keeps p(z) integrable

        return log_prob
__init__(scale, period)

Distribution 2d with sinusoidal density given by

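w_1(z) = sin(2*pi / period * z[0])
log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2 - 1/2 * (norm_4(z) / (20 * scale)) ** 4

where norm_4 is the 4-norm of z; the second term is an envelope that makes the density integrable.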

Parameters:

Name Type Description Default
scale

scale of the distribution, see formula

required
period

period of the sinusoid

required
Source code in normflows/distributions/prior.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def __init__(self, scale, period):
    """Distribution 2d with sinusoidal density
    given by

    ```
    w_1(z) = sin(2*pi / period * z[0])
    log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2
             - 1/2 * (norm_4(z) / (20 * scale)) ** 4
    ```

    where ``norm_4`` is the 4-norm of ``z``. The second term is a wide
    envelope that makes the density integrable; the log probability is
    not normalized.

    Args:
      scale: scale of the distribution, see formula
      period: period of the sinusoid
    """
    self.scale = scale
    self.period = period
log_prob(z)
log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2 - 1/2 * (norm_4(z) / (20 * scale)) ** 4
w_1(z) = sin(2*pi / period * z[0])

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def log_prob(self, z):
    """
    Unnormalized log probability

    ```
    log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2
             - 1/2 * (norm_4(z) / (20 * scale)) ** 4
    w_1(z) = sin(2*pi / period * z[0])
    ```

    Args:
      z: value or batch of latent variable; the two coordinates are taken
        from the last dimension

    Returns:
      log probability of the distribution for z
    """
    # Move the coordinate axis (last) to the front so z_[0] / z_[1] select
    # the two coordinates.
    if z.dim() > 1:
        z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
    else:
        z_ = z

    w_1 = torch.sin(2 * np.pi / self.period * z_[0])
    log_prob = (
        -0.5 * ((z_[1] - w_1) / self.scale) ** 2
        - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
    )  # envelope term keeps p(z) integrable

    return log_prob

Sinusoidal_gap

Bases: PriorDistribution

Source code in normflows/distributions/prior.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
class Sinusoidal_gap(PriorDistribution):
    def __init__(self, scale, period):
        """Distribution 2d with sinusoidal density with gap
        given by

        ```
        w_1(z) = sin(2*pi / period * z[0])
        w_2(z) = 3 * exp(-0.5 * ((z[0] - 1) / 0.6) ** 2)
        log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_2(z)) / scale) ** 2))
                 - 0.5 * (norm_4(z) / (20 * scale)) ** 4
        ```

        where ``norm_4`` is the 4-norm of ``z``; the log probability is not
        normalized.

        Args:
          scale: width of the two sinusoidal branches, see formula
          period: period of the sinusoid w_1
        """
        self.scale = scale
        self.period = period
        # Width, amplitude and centre of the Gaussian-shaped bump w_2 that
        # separates the two branches around z[0] = w2_mu
        self.w2_scale = 0.6
        self.w2_amp = 3.0
        self.w2_mu = 1.0

    def log_prob(self, z):
        """
        Args:
          z: value or batch of latent variable; the two coordinates are
            taken from the last dimension

        Returns:
          log probability of the distribution for z
        """
        # Move the coordinate axis (last) to the front.
        if z.dim() > 1:
            z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
        else:
            z_ = z

        w_1 = torch.sin(2 * np.pi / self.period * z_[0])
        w_2 = self.w2_amp * torch.exp(
            -0.5 * ((z_[0] - self.w2_mu) / self.w2_scale) ** 2
        )

        # Log of the sum of two Gaussians centred at z[1] = w_1 and
        # z[1] = w_1 - w_2, written as
        #   -0.5*((a - eps)/scale)**2 + log(1 + exp(-2*a*eps/scale**2))
        # with a = |z[1] - w_1 + w_2/2| and eps = |w_2/2|, so the exp argument
        # is never positive.
        eps = torch.abs(w_2 / 2)
        a = torch.abs(z_[1] - w_1 + w_2 / 2)

        log_prob = (
            -0.5 * ((a - eps) / self.scale) ** 2
            + torch.log1p(torch.exp(-2 * (eps * a) / self.scale**2))
            - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
        )

        return log_prob
__init__(scale, period)

Distribution 2d with sinusoidal density with gap given by

w_1(z) = sin(2*pi / period * z[0])
w_2(z) = 3 * exp(-0.5 * ((z[0] - 1) / 0.6) ** 2)
log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_2(z)) / scale) ** 2)) - 0.5 * (norm_4(z) / (20 * scale)) ** 4

Parameters:

Name Type Description Default
scale

width of the two sinusoidal branches, see formula

required
period

period of the sinusoid w_1

required
Source code in normflows/distributions/prior.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def __init__(self, scale, period):
    """Distribution 2d with sinusoidal density with gap
    given by

    ```
    w_1(z) = sin(2*pi / period * z[0])
    w_2(z) = 3 * exp(-0.5 * ((z[0] - 1) / 0.6) ** 2)
    log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_2(z)) / scale) ** 2))
             - 0.5 * (norm_4(z) / (20 * scale)) ** 4
    ```

    where ``norm_4`` is the 4-norm of ``z``; the log probability is not
    normalized.

    Args:
      scale: width of the two sinusoidal branches, see formula
      period: period of the sinusoid w_1
    """
    self.scale = scale
    self.period = period
    # Width, amplitude and centre of the Gaussian-shaped bump w_2 that
    # separates the two branches around z[0] = w2_mu
    self.w2_scale = 0.6
    self.w2_amp = 3.0
    self.w2_mu = 1.0
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def log_prob(self, z):
    """
    Args:
      z: value or batch of latent variable; the two coordinates are taken
        from the last dimension

    Returns:
      log probability of the distribution for z
    """
    # Move the coordinate axis (last) to the front.
    if z.dim() > 1:
        z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
    else:
        z_ = z

    w_1 = torch.sin(2 * np.pi / self.period * z_[0])
    w_2 = self.w2_amp * torch.exp(
        -0.5 * ((z_[0] - self.w2_mu) / self.w2_scale) ** 2
    )

    # Log of the sum of two Gaussians centred at z[1] = w_1 and
    # z[1] = w_1 - w_2, written as
    #   -0.5*((a - eps)/scale)**2 + log(1 + exp(-2*a*eps/scale**2))
    # with a = |z[1] - w_1 + w_2/2| and eps = |w_2/2|, so the exp argument
    # is never positive.
    eps = torch.abs(w_2 / 2)
    a = torch.abs(z_[1] - w_1 + w_2 / 2)

    log_prob = (
        -0.5 * ((a - eps) / self.scale) ** 2
        + torch.log1p(torch.exp(-2 * (eps * a) / self.scale**2))
        - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
    )

    return log_prob

Sinusoidal_split

Bases: PriorDistribution

Source code in normflows/distributions/prior.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
class Sinusoidal_split(PriorDistribution):
    def __init__(self, scale, period):
        """Distribution 2d with sinusoidal density with split
        given by

        ```
        w_1(z) = sin(2*pi / period * z[0])
        w_3(z) = 3 * sigmoid((z[0] - 1) / 0.3)
        log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_3(z)) / scale) ** 2))
                 - 0.5 * (norm_4(z) / (20 * scale)) ** 4
        ```

        where ``norm_4`` is the 4-norm of ``z``; the log probability is not
        normalized.

        Args:
          scale: width of the two sinusoidal branches, see formula
          period: period of the sinusoid w_1
        """
        self.scale = scale
        self.period = period
        # Steepness, amplitude and centre of the sigmoid step w_3 that pulls
        # the two branches apart as z[0] grows past w3_mu
        self.w3_scale = 0.3
        self.w3_amp = 3.0
        self.w3_mu = 1.0

    def log_prob(self, z):
        """
        Args:
          z: value or batch of latent variable; the two coordinates are
            taken from the last dimension

        Returns:
          log probability of the distribution for z
        """
        # Move the coordinate axis (last) to the front.
        if z.dim() > 1:
            z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
        else:
            z_ = z

        w_1 = torch.sin(2 * np.pi / self.period * z_[0])
        w_3 = self.w3_amp * torch.sigmoid((z_[0] - self.w3_mu) / self.w3_scale)

        # Log of the sum of two Gaussians centred at z[1] = w_1 and
        # z[1] = w_1 - w_3, written as
        #   -0.5*((a - eps)/scale)**2 + log(1 + exp(-2*a*eps/scale**2))
        # with a = |z[1] - w_1 + w_3/2| and eps = |w_3/2|, so the exp argument
        # is never positive.
        eps = torch.abs(w_3 / 2)
        a = torch.abs(z_[1] - w_1 + w_3 / 2)

        log_prob = (
            -0.5 * ((a - eps) / (self.scale)) ** 2
            + torch.log1p(torch.exp(-2 * (eps * a) / self.scale**2))
            - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
        )

        return log_prob
__init__(scale, period)

Distribution 2d with sinusoidal density with split given by

w_1(z) = sin(2*pi / period * z[0])
w_3(z) = 3 * sigmoid((z[0] - 1) / 0.3)
log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_3(z)) / scale) ** 2)) - 0.5 * (norm_4(z) / (20 * scale)) ** 4

Parameters:

Name Type Description Default
scale

width of the two sinusoidal branches, see formula

required
period

period of the sinusoid w_1

required
Source code in normflows/distributions/prior.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def __init__(self, scale, period):
    """Distribution 2d with sinusoidal density with split
    given by

    ```
    w_1(z) = sin(2*pi / period * z[0])
    w_3(z) = 3 * sigmoid((z[0] - 1) / 0.3)
    log(p) = log(exp(-0.5 * ((z[1] - w_1(z)) / scale) ** 2) + exp(-0.5 * ((z[1] - w_1(z) + w_3(z)) / scale) ** 2))
             - 0.5 * (norm_4(z) / (20 * scale)) ** 4
    ```

    where ``norm_4`` is the 4-norm of ``z``; the log probability is not
    normalized.

    Args:
      scale: width of the two sinusoidal branches, see formula
      period: period of the sinusoid w_1
    """
    self.scale = scale
    self.period = period
    # Steepness, amplitude and centre of the sigmoid step w_3 that pulls the
    # two branches apart as z[0] grows past w3_mu
    self.w3_scale = 0.3
    self.w3_amp = 3.0
    self.w3_mu = 1.0
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def log_prob(self, z):
    """
    Args:
      z: value or batch of latent variable; the two coordinates are taken
        from the last dimension

    Returns:
      log probability of the distribution for z
    """
    # Move the coordinate axis (last) to the front.
    if z.dim() > 1:
        z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
    else:
        z_ = z

    w_1 = torch.sin(2 * np.pi / self.period * z_[0])
    w_3 = self.w3_amp * torch.sigmoid((z_[0] - self.w3_mu) / self.w3_scale)

    # Log of the sum of two Gaussians centred at z[1] = w_1 and
    # z[1] = w_1 - w_3, written as
    #   -0.5*((a - eps)/scale)**2 + log(1 + exp(-2*a*eps/scale**2))
    # with a = |z[1] - w_1 + w_3/2| and eps = |w_3/2|, so the exp argument
    # is never positive.
    eps = torch.abs(w_3 / 2)
    a = torch.abs(z_[1] - w_1 + w_3 / 2)

    log_prob = (
        -0.5 * ((a - eps) / (self.scale)) ** 2
        + torch.log1p(torch.exp(-2 * (eps * a) / self.scale**2))
        - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
    )

    return log_prob

Smiley

Bases: PriorDistribution

Source code in normflows/distributions/prior.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
class Smiley(PriorDistribution):
    def __init__(self, scale):
        """Two-dimensional smiley-shaped density :)

        Args:
          scale: scale of the smiley
        """
        self.scale = scale
        # Radius of the ring term in log_prob
        self.loc = 2.0

    def log_prob(self, z):
        """Unnormalized log density of the smiley.

        Args:
          z: value or batch of latent variable; the two coordinates are
            taken from the last dimension

        Returns:
          log probability of the distribution for z
        """
        # Bring the coordinate axis (last) to the front.
        z_ = z.permute((z.dim() - 1,) + tuple(range(z.dim() - 1))) if z.dim() > 1 else z

        width = 2 * self.scale
        # Distance from the ring of radius self.loc
        ring = (torch.norm(z_, dim=0) - self.loc) / width
        # Distance from the two horizontal bands z[1] = 0.4 and z[1] = -2.0
        band = (torch.abs(z_[1] + 0.8) - 1.2) / width
        return -0.5 * ring ** 2 - 0.5 * band ** 2
__init__(scale)

Distribution 2d of a smiley :)

Parameters:

Name Type Description Default
scale

scale of the smiley

required
Source code in normflows/distributions/prior.py
300
301
302
303
304
305
306
307
def __init__(self, scale):
    """Two-dimensional smiley-shaped density :)

    Args:
      scale: scale of the smiley
    """
    self.scale = scale
    # Radius of the ring term in log_prob
    self.loc = 2.0
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def log_prob(self, z):
    """Unnormalized log density of the smiley.

    Args:
      z: value or batch of latent variable; the two coordinates are taken
        from the last dimension

    Returns:
      log probability of the distribution for z
    """
    # Bring the coordinate axis (last) to the front.
    z_ = z.permute((z.dim() - 1,) + tuple(range(z.dim() - 1))) if z.dim() > 1 else z

    width = 2 * self.scale
    # Distance from the ring of radius self.loc
    ring = (torch.norm(z_, dim=0) - self.loc) / width
    # Distance from the two horizontal bands z[1] = 0.4 and z[1] = -2.0
    band = (torch.abs(z_[1] + 0.8) - 1.2) / width
    return -0.5 * ring ** 2 - 0.5 * band ** 2

TwoModes

Bases: PriorDistribution

Source code in normflows/distributions/prior.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
class TwoModes(PriorDistribution):
    def __init__(self, loc, scale):
        """Distribution 2d with two modes

        Distribution 2d with two modes at
        ```z[0] = -loc```  and ```z[0] = loc```
        following the (unnormalized) density
        ```
        log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
                 + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))
        ```

        Args:
          loc: distance of modes from the origin
          scale: scale of modes
        """
        self.loc = loc
        self.scale = scale

    def log_prob(self, z):
        """

        ```
        log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
                 + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))
        ```

        Args:
          z: value or batch of latent variable; the two coordinates are
            taken from the last dimension

        Returns:
          log probability of the distribution for z
        """
        a = torch.abs(z[..., 0])
        # Plain Python scalar: avoids routing loc through a default-dtype CPU
        # tensor, which loses precision for float64 inputs.
        eps = abs(self.loc)

        # The two-Gaussian log-sum-exp, rewritten as
        #   -0.5*((a - eps)/s)**2 + log(1 + exp(-2*a*eps/s**2)),  s = 3 * scale,
        # so the exp argument is never positive.
        log_prob = (
            -0.5 * ((torch.norm(z, dim=-1) - self.loc) / (2 * self.scale)) ** 2
            - 0.5 * ((a - eps) / (3 * self.scale)) ** 2
            + torch.log1p(torch.exp(-2 * (a * eps) / (3 * self.scale) ** 2))
        )

        return log_prob
__init__(loc, scale)

Distribution 2d with two modes

Distribution 2d with two modes at z[0] = -loc and z[0] = loc following the density

log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
        + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))

Args:
  loc: distance of modes from the origin
  scale: scale of modes

Source code in normflows/distributions/prior.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def __init__(self, loc, scale):
    """Distribution 2d with two modes

    Distribution 2d with two modes at
    ```z[0] = -loc```  and ```z[0] = loc```
    following the (unnormalized) density
    ```
    log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
             + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))
    ```

    Args:
      loc: distance of modes from the origin
      scale: scale of modes
    """
    self.loc = loc
    self.scale = scale
log_prob(z)
log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
        + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/prior.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def log_prob(self, z):
    """

    ```
    log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
             + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))
    ```

    Args:
      z: value or batch of latent variable; the two coordinates are taken
        from the last dimension

    Returns:
      log probability of the distribution for z
    """
    a = torch.abs(z[..., 0])
    # Plain Python scalar: avoids routing loc through a default-dtype CPU
    # tensor, which loses precision for float64 inputs.
    eps = abs(self.loc)

    # The two-Gaussian log-sum-exp, rewritten as
    #   -0.5*((a - eps)/s)**2 + log(1 + exp(-2*a*eps/s**2)),  s = 3 * scale,
    # so the exp argument is never positive.
    log_prob = (
        -0.5 * ((torch.norm(z, dim=-1) - self.loc) / (2 * self.scale)) ** 2
        - 0.5 * ((a - eps) / (3 * self.scale)) ** 2
        + torch.log1p(torch.exp(-2 * (a * eps) / (3 * self.scale) ** 2))
    )

    return log_prob

target

CircularGaussianMixture

Bases: Module

Two-dimensional Gaussian mixture arranged in a circle

Source code in normflows/distributions/target.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
class CircularGaussianMixture(nn.Module):
    """
    Two-dimensional Gaussian mixture arranged in a circle
    """

    def __init__(self, n_modes=8):
        """Constructor

        Args:
          n_modes: Number of modes
        """
        super().__init__()
        self.n_modes = n_modes
        # Standard deviation shared by all modes; it shrinks with the angular
        # spacing pi / n_modes so neighbouring modes stay separated.
        self.register_buffer(
            "scale", torch.tensor(2 / 3 * np.sin(np.pi / self.n_modes)).float()
        )

    def log_prob(self, z):
        """Log density of the mixture

        Equally weighted isotropic Gaussians with standard deviation
        ``scale``, centred at ``2 * (sin(phi_i), cos(phi_i))`` with
        ``phi_i = 2 * pi * i / n_modes``.

        Args:
          z: Batch of points, shape ``(num_points, 2)``

        Returns:
          Log probability of each point
        """
        # Scaled squared distance to every mode, one column per mode
        d = torch.stack(
            [
                (
                    (z[:, 0] - 2 * np.sin(2 * np.pi / self.n_modes * i)) ** 2
                    + (z[:, 1] - 2 * np.cos(2 * np.pi / self.n_modes * i)) ** 2
                )
                / (2 * self.scale**2)
                for i in range(self.n_modes)
            ],
            1,
        )
        # Gaussian normalizer 1 / (2 pi scale^2) and uniform weight 1 / n_modes
        log_p = -torch.log(
            2 * np.pi * self.scale**2 * self.n_modes
        ) + torch.logsumexp(-d, 1)
        return log_p

    def sample(self, num_samples=1):
        """Draw samples from the mixture

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples of shape ``(num_samples, 2)`` with the dtype and device
          of ``scale``
        """
        eps = torch.randn(
            (num_samples, 2), dtype=self.scale.dtype, device=self.scale.device
        )
        # Pick a mode uniformly at random for every sample
        phi = (
            2
            * np.pi
            / self.n_modes
            * torch.randint(0, self.n_modes, (num_samples,), device=self.scale.device)
        )
        loc = torch.stack((2 * torch.sin(phi), 2 * torch.cos(phi)), 1).type(eps.dtype)
        return eps * self.scale + loc
__init__(n_modes=8)

Constructor

Parameters:

Name Type Description Default
n_modes

Number of modes

8
Source code in normflows/distributions/target.py
137
138
139
140
141
142
143
144
145
146
147
def __init__(self, n_modes=8):
    """Build a circular mixture with ``n_modes`` equally spaced modes.

    Args:
      n_modes: Number of modes
    """
    super(CircularGaussianMixture, self).__init__()
    self.n_modes = n_modes
    # Shared standard deviation, tied to the angular spacing of the modes
    mode_std = torch.tensor(2 / 3 * np.sin(np.pi / n_modes)).float()
    self.register_buffer("scale", mode_std)

ConditionalDiagGaussian

Bases: Target

Gaussian distribution conditioned on its mean and standard deviation

The first half of the entries of the condition, also called context, are the mean, while the second half are the standard deviation.

Source code in normflows/distributions/target.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
class ConditionalDiagGaussian(Target):
    """
    Gaussian distribution conditioned on its mean and standard
    deviation

    The first half of the entries of the condition, also called context,
    are the mean, while the second half are the standard deviation.
    """

    def log_prob(self, z, context=None):
        """Log density of a diagonal Gaussian parameterized by ``context``

        Args:
          z: Points to evaluate; the last dimension has size ``d``
          context: Tensor whose last dimension holds the ``d`` means
            followed by the ``d`` standard deviations; its leading
            dimensions must broadcast against those of ``z``

        Returns:
          Log probability of each point, summed over the last dimension
        """
        d = z.shape[-1]
        loc = context[..., :d]
        scale = context[..., d:]
        log_p = -0.5 * d * np.log(2 * np.pi) - torch.sum(
            torch.log(scale) + 0.5 * torch.pow((z - loc) / scale, 2),
            dim=-1
        )
        return log_p

    def sample(self, num_samples=1, context=None):
        """Draw samples ``loc + scale * eps`` with standard normal ``eps``

        Args:
          num_samples: Number of samples to draw
          context: Tensor whose last dimension holds the means followed by
            the standard deviations; its leading dimension must broadcast
            against ``num_samples``

        Returns:
          Samples; shape ``(num_samples, d)`` for a 1-D or 2-D context,
          where ``d`` is half the last dimension of ``context``
        """
        d = context.shape[-1] // 2
        loc = context[..., :d]
        scale = context[..., d:]
        eps = torch.randn(
            (num_samples, d), dtype=context.dtype, device=context.device
        )
        z = loc + scale * eps
        return z

RingMixture

Bases: Target

Mixture of ring distributions in two dimensions

Source code in normflows/distributions/target.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class RingMixture(Target):
    """
    Mixture of ring distributions in two dimensions
    """

    def __init__(self, n_rings=2):
        """Constructor

        Args:
          n_rings: Number of concentric rings; ring ``i`` (0-based) has
            radius ``2 * (i + 1) / n_rings``
        """
        super().__init__()
        self.n_dims = 2
        # Log-probability the rejection sampler normalizes by. NOTE: the
        # logsumexp in log_prob can slightly exceed 0 where other rings still
        # contribute; negligible for well-separated rings.
        self.max_log_prob = 0.0
        self.n_rings = n_rings
        # Radial width of every ring
        self.scale = 1 / 4 / self.n_rings

    def log_prob(self, z):
        """Unnormalized log density

        Args:
          z: Batch of points, shape ``(num_points, 2)``

        Returns:
          Log of the sum of the ring profiles for each point (not normalized)
        """
        radius = torch.norm(z, dim=1)
        # Scaled squared distance to every ring, one column per ring
        d = torch.stack(
            [
                ((radius - 2 / self.n_rings * (i + 1)) ** 2) / (2 * self.scale**2)
                for i in range(self.n_rings)
            ],
            1,
        )
        return torch.logsumexp(-d, 1)

Target

Bases: Module

Sample target distributions to test models

Source code in normflows/distributions/target.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class Target(nn.Module):
    """
    Sample target distributions to test models

    Subclasses implement ``log_prob``. To use the rejection sampler they
    also set ``n_dims`` (dimension of a sample) and ``max_log_prob`` (the
    log-probability the acceptance ratio is normalized by).
    """

    def __init__(self, prop_scale=torch.tensor(6.0), prop_shift=torch.tensor(-3.0)):
        """Constructor

        Args:
          prop_scale: Scale for the uniform proposal; its dtype and device
            are used for all sampled tensors
          prop_shift: Shift for the uniform proposal
        """
        super().__init__()
        # Register copies: the default tensors are created once when the
        # function is defined. Registering them directly would make every
        # instance share the same buffer objects, and in-place updates such
        # as load_state_dict would then leak across instances.
        self.register_buffer("prop_scale", torch.as_tensor(prop_scale).clone())
        self.register_buffer("prop_shift", torch.as_tensor(prop_shift).clone())

    def log_prob(self, z):
        """
        Args:
          z: value or batch of latent variable

        Returns:
          log probability of the distribution for z
        """
        raise NotImplementedError("The log probability is not implemented yet.")

    def rejection_sampling(self, num_steps=1):
        """Perform rejection sampling on the target distribution using a
        uniform proposal on ``[prop_shift, prop_shift + prop_scale)``

        Args:
          num_steps: Number of candidate points to propose; only the
            accepted subset is returned

        Returns:
          Accepted samples
        """
        eps = torch.rand(
            (num_steps, self.n_dims),
            dtype=self.prop_scale.dtype,
            device=self.prop_scale.device,
        )
        z_ = self.prop_scale * eps + self.prop_shift
        prob = torch.rand(
            num_steps, dtype=self.prop_scale.dtype, device=self.prop_scale.device
        )
        prob_ = torch.exp(self.log_prob(z_) - self.max_log_prob)
        accept = prob_ > prob
        z = z_[accept, :]
        return z

    def sample(self, num_samples=1):
        """Sample from the target distribution through rejection sampling

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples
        """
        z = torch.zeros(
            (0, self.n_dims), dtype=self.prop_scale.dtype, device=self.prop_scale.device
        )
        while len(z) < num_samples:
            z_ = self.rejection_sampling(num_samples)
            ind = min(len(z_), num_samples - len(z))
            z = torch.cat([z, z_[:ind, :]], 0)
        return z
__init__(prop_scale=torch.tensor(6.0), prop_shift=torch.tensor(-3.0))

Constructor

Parameters:

Name Type Description Default
prop_scale

Scale for the uniform proposal

tensor(6.0)
prop_shift

Shift for the uniform proposal

tensor(-3.0)
Source code in normflows/distributions/target.py
13
14
15
16
17
18
19
20
21
22
def __init__(self, prop_scale=torch.tensor(6.0), prop_shift=torch.tensor(-3.0)):
    """Constructor

    Args:
      prop_scale: Scale for the uniform proposal; its dtype and device are
        used for all sampled tensors
      prop_shift: Shift for the uniform proposal
    """
    super().__init__()
    # Register copies: the default tensors are created once when the
    # function is defined. Registering them directly would make every
    # instance share the same buffer objects, and in-place updates such as
    # load_state_dict would then leak across instances.
    self.register_buffer("prop_scale", torch.as_tensor(prop_scale).clone())
    self.register_buffer("prop_shift", torch.as_tensor(prop_shift).clone())
log_prob(z)

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/target.py
24
25
26
27
28
29
30
31
32
def log_prob(self, z):
    """Evaluate the log density of the target at ``z``

    The base class has no density; subclasses override this method.

    Args:
      z: value or batch of latent variable

    Returns:
      log probability of the distribution for z

    Raises:
      NotImplementedError: always, in this base implementation
    """
    message = "The log probability is not implemented yet."
    raise NotImplementedError(message)
rejection_sampling(num_steps=1)

Perform rejection sampling on image distribution

Parameters:

Name Type Description Default
num_steps

Number of rejection sampling steps to perform

1

Returns:

Type Description

Accepted samples

Source code in normflows/distributions/target.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def rejection_sampling(self, num_steps=1):
    """Draw ``num_steps`` uniform proposals and keep those passing the acceptance test

    Each proposal is ``prop_scale * eps + prop_shift`` with ``eps ~ U[0, 1)`` in
    each of the ``self.n_dims`` dimensions. A proposal ``x`` is accepted when an
    independent draw ``u ~ U[0, 1)`` satisfies
    ``u < exp(log_prob(x) - max_log_prob)``. ``n_dims`` and ``max_log_prob`` are
    expected to be set by the subclass (``TwoMoons`` sets both).

    Args:
      num_steps: Number of rejection sampling steps to perform

    Returns:
      Accepted samples (between 0 and ``num_steps`` rows)
    """
    factory = dict(dtype=self.prop_scale.dtype, device=self.prop_scale.device)
    # Proposal noise is drawn before acceptance noise. This order fixes which
    # random numbers each role receives for a given seed.
    noise = torch.rand((num_steps, self.n_dims), **factory)
    proposals = self.prop_scale * noise + self.prop_shift
    thresholds = torch.rand(num_steps, **factory)
    ratio = torch.exp(self.log_prob(proposals) - self.max_log_prob)
    return proposals[ratio > thresholds, :]
sample(num_samples=1)

Sample from image distribution through rejection sampling

Parameters:

Name Type Description Default
num_samples

Number of samples to draw

1

Returns:

Type Description

Samples

Source code in normflows/distributions/target.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def sample(self, num_samples=1):
    """Draw samples from the target through repeated rejection sampling

    Each round calls ``rejection_sampling`` with ``num_samples`` proposals and
    keeps only as many of the accepted points as are still missing.

    Note: the loop only ends once enough proposals have been accepted, so a
    ``log_prob`` that rejects every proposal makes this call loop forever.

    Args:
      num_samples: Number of samples to draw

    Returns:
      Tensor with ``num_samples`` rows (no rows if ``num_samples <= 0``)
    """
    collected = torch.zeros(
        (0, self.n_dims), dtype=self.prop_scale.dtype, device=self.prop_scale.device
    )
    while len(collected) < num_samples:
        still_needed = num_samples - len(collected)
        accepted = self.rejection_sampling(num_samples)
        # Slicing clamps at the end, so this keeps min(len(accepted), still_needed) rows
        collected = torch.cat([collected, accepted[:still_needed, :]], 0)
    return collected

TwoIndependent

Bases: Target

Target distribution that combines two independent distributions of equal size into one distribution. This is needed for Augmented Normalizing Flows, see https://arxiv.org/abs/2002.07101

Source code in normflows/distributions/target.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class TwoIndependent(Target):
    """
    Target distribution built from two independent distributions of equal size,
    one per part of the channel split. This is needed for Augmented Normalizing
    Flows, see https://arxiv.org/abs/2002.07101
    """

    def __init__(self, target1, target2):
        """Constructor

        Args:
          target1: Distribution of the first part produced by ``Split(mode='channel')``
          target2: Distribution of the second part produced by ``Split(mode='channel')``
        """
        super().__init__()
        self.target1 = target1
        self.target2 = target2
        self.split = Split(mode='channel')

    def log_prob(self, z):
        """Sum of the two component log densities, evaluated on the split parts of ``z``"""
        first, second = self.split(z)[0]
        return self.target1.log_prob(first) + self.target2.log_prob(second)

    def sample(self, num_samples=1):
        """Sample both components (``target1`` first) and merge them along the split axis"""
        parts = [self.target1.sample(num_samples), self.target2.sample(num_samples)]
        return self.split.inverse(parts)[0]

TwoMoons

Bases: Target

Bimodal two-dimensional distribution

Source code in normflows/distributions/target.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class TwoMoons(Target):
    """
    Bimodal two-dimensional distribution
    """

    def __init__(self):
        """Constructor

        Sets ``n_dims = 2`` and ``max_log_prob = 0.0``; the latter is subtracted
        from ``log_prob`` during rejection sampling.
        """
        super().__init__()
        self.n_dims = 2
        self.max_log_prob = 0.0

    def log_prob(self, z):
        """
        ```
        log(p) = - 1/2 * ((norm(z) - 2) / 0.2) ** 2
                 + log(  exp(-1/2 * ((z[0] - 2) / 0.3) ** 2)
                       + exp(-1/2 * ((z[0] + 2) / 0.3) ** 2))
        ```

        Args:
          z: value or batch of latent variable

        Returns:
          log probability of the distribution for z
        """
        x_abs = torch.abs(z[:, 0])
        radial = -0.5 * ((torch.norm(z, dim=1) - 2) / 0.2) ** 2
        # Writing the two-bump log-sum-exp in terms of |z[0]| factors out the
        # nearer bump; the remaining exponent -4|z[0]|/0.09 is never positive,
        # so exp() cannot overflow.
        mode_mixture = torch.log(1 + torch.exp(-4 * x_abs / 0.09))
        return radial - 0.5 * ((x_abs - 2) / 0.3) ** 2 + mode_mixture
log_prob(z)
log(p) = - 1/2 * ((norm(z) - 2) / 0.2) ** 2
         + log(  exp(-1/2 * ((z[0] - 2) / 0.3) ** 2)
               + exp(-1/2 * ((z[0] + 2) / 0.3) ** 2))

Parameters:

Name Type Description Default
z

value or batch of latent variable

required

Returns:

Type Description

log probability of the distribution for z

Source code in normflows/distributions/target.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def log_prob(self, z):
    """
    ```
    log(p) = - 1/2 * ((norm(z) - 2) / 0.2) ** 2
             + log(  exp(-1/2 * ((z[0] - 2) / 0.3) ** 2)
                   + exp(-1/2 * ((z[0] + 2) / 0.3) ** 2))
    ```

    Args:
      z: value or batch of latent variable

    Returns:
      log probability of the distribution for z
    """
    x_abs = torch.abs(z[:, 0])
    radial = -0.5 * ((torch.norm(z, dim=1) - 2) / 0.2) ** 2
    # Writing the two-bump log-sum-exp in terms of |z[0]| factors out the nearer
    # bump; the remaining exponent -4|z[0]|/0.09 is never positive, so exp()
    # cannot overflow.
    mode_mixture = torch.log(1 + torch.exp(-4 * x_abs / 0.09))
    return radial - 0.5 * ((x_abs - 2) / 0.3) ** 2 + mode_mixture

flows

affine

autoregressive

Autoregressive

Bases: Flow

Transforms each input variable with an invertible elementwise transformation.

The parameters of each invertible elementwise transformation can be functions of previous input variables, but they must not depend on the current or any following input variables.

NOTE Calculating the inverse transform is D times slower than calculating the forward transform, where D is the dimensionality of the input to the transform.

Source code in normflows/flows/affine/autoregressive.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Autoregressive(Flow):
    """Transforms each input variable with an invertible elementwise transformation.

    The parameters of each invertible elementwise transformation can be functions of previous input
    variables, but they must not depend on the current or any following input variables.

    **NOTE** Calculating the inverse transform is D times slower than calculating the
    forward transform, where D is the dimensionality of the input to the transform.

    Subclasses implement ``_output_dim_multiplier``, ``_elementwise_forward`` and
    ``_elementwise_inverse``.
    """

    def __init__(self, autoregressive_net):
        """Constructor

        Args:
          autoregressive_net: Callable invoked as ``autoregressive_net(inputs, context)``
            that returns the parameters of the elementwise transformation
        """
        super().__init__()
        self.autoregressive_net = autoregressive_net

    def forward(self, inputs, context=None):
        """Apply the transformation in one pass; the parameters are computed from ``inputs``

        Returns:
          Tuple of transformed inputs and log absolute determinant of the Jacobian
        """
        params = self.autoregressive_net(inputs, context)
        outputs, logabsdet = self._elementwise_forward(inputs, params)
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Invert the transformation iteratively, one pass per input dimension

        Starting from zeros, each pass recomputes the parameters from the current
        estimate and reapplies the elementwise inverse. The parameters of dimension
        i depend only on earlier dimensions, so every pass makes at least one more
        leading dimension exact.

        Returns:
          Tuple of reconstructed inputs and the log absolute determinant computed
          in the final pass
        """
        dims_per_example = np.prod(inputs.shape[1:])
        estimate = torch.zeros_like(inputs)
        logabsdet = None
        for _ in range(dims_per_example):
            params = self.autoregressive_net(estimate, context)
            estimate, logabsdet = self._elementwise_inverse(inputs, params)
        return estimate, logabsdet

    def _output_dim_multiplier(self):
        """Number of network outputs per input variable (subclass hook)"""
        raise NotImplementedError()

    def _elementwise_forward(self, inputs, autoregressive_params):
        """Return ``(outputs, logabsdet)`` of the forward map (subclass hook)"""
        raise NotImplementedError()

    def _elementwise_inverse(self, inputs, autoregressive_params):
        """Return ``(outputs, logabsdet)`` of the inverse map (subclass hook)"""
        raise NotImplementedError()
MaskedAffineAutoregressive

Bases: Autoregressive

Masked affine autoregressive flow, mostly referred to as Masked Autoregressive Flow (MAF), see arXiv 1705.07057.

Source code in normflows/flows/affine/autoregressive.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class MaskedAffineAutoregressive(Autoregressive):
    """ Masked affine autoregressive flow, mostly referred to as
    Masked Autoregressive Flow (MAF), see
    [arXiv 1705.07057](https://arxiv.org/abs/1705.07057).

    Each feature is mapped as ``scale * x + shift`` with
    ``scale = sigmoid(raw + 2) + 1e-3``, so the scale always lies in (1e-3, 1.001).
    """
    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        """Constructor

        Args:
          features: Number of features/input dimensions
          hidden_features: Number of hidden units in the MADE network
          context_features: Number of context/conditional features
          num_blocks: Number of blocks in the MADE network
          use_residual_blocks: Flag whether residual blocks should be used
          random_mask: Flag whether to use random masks
          activation: Activation function to be used in the MADE network
          dropout_probability: Dropout probability in the MADE network
          use_batch_norm: Flag whether batch normalization should be used
        """
        # Used by _unconstrained_scale_and_shift to reshape the MADE output
        self.features = features
        made_kwargs = dict(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        super().__init__(made_module.MADE(**made_kwargs))

    def _output_dim_multiplier(self):
        """MADE emits two values per feature: an unconstrained scale and a shift"""
        return 2

    def _scale_and_shift(self, autoregressive_params):
        """Return ``(scale, log(scale), shift)`` with the scale squashed into (1e-3, 1.001)"""
        unconstrained_scale, shift = self._unconstrained_scale_and_shift(
            autoregressive_params
        )
        scale = torch.sigmoid(unconstrained_scale + 2.0) + 1e-3
        return scale, torch.log(scale), shift

    def _elementwise_forward(self, inputs, autoregressive_params):
        scale, log_scale, shift = self._scale_and_shift(autoregressive_params)
        outputs = scale * inputs + shift
        logabsdet = utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet

    def _elementwise_inverse(self, inputs, autoregressive_params):
        scale, log_scale, shift = self._scale_and_shift(autoregressive_params)
        outputs = (inputs - shift) / scale
        logabsdet = -utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet

    def _unconstrained_scale_and_shift(self, autoregressive_params):
        """Split the MADE output, viewed as ``(-1, features, 2)``, into raw scale and shift"""
        per_feature = autoregressive_params.view(
            -1, self.features, self._output_dim_multiplier()
        )
        return per_feature[..., 0], per_feature[..., 1]
__init__(features, hidden_features, context_features=None, num_blocks=2, use_residual_blocks=True, random_mask=False, activation=F.relu, dropout_probability=0.0, use_batch_norm=False)

Constructor

Parameters:

Name Type Description Default
features

Number of features/input dimensions

required
hidden_features

Number of hidden units in the MADE network

required
context_features

Number of context/conditional features

None
num_blocks

Number of blocks in the MADE network

2
use_residual_blocks

Flag whether residual blocks should be used

True
random_mask

Flag whether to use random masks

False
activation

Activation function to be used in the MADE network

relu
dropout_probability

Dropout probability in the MADE network

0.0
use_batch_norm

Flag whether batch normalization should be used

False
Source code in normflows/flows/affine/autoregressive.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def __init__(
    self,
    features,
    hidden_features,
    context_features=None,
    num_blocks=2,
    use_residual_blocks=True,
    random_mask=False,
    activation=F.relu,
    dropout_probability=0.0,
    use_batch_norm=False,
):
    """Constructor

    Builds a MADE network with ``output_multiplier=self._output_dim_multiplier()``
    and passes it to the ``Autoregressive`` base class.

    Args:
      features: Number of features/input dimensions
      hidden_features: Number of hidden units in the MADE network
      context_features: Number of context/conditional features
      num_blocks: Number of blocks in the MADE network
      use_residual_blocks: Flag whether residual blocks should be used
      random_mask: Flag whether to use random masks
      activation: Activation function to be used in the MADE network
      dropout_probability: Dropout probability in the MADE network
      use_batch_norm: Flag whether batch normalization should be used
    """
    # Feature count, later used to reshape the MADE output
    self.features = features
    made_kwargs = dict(
        features=features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        output_multiplier=self._output_dim_multiplier(),
        use_residual_blocks=use_residual_blocks,
        random_mask=random_mask,
        activation=activation,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )
    made = made_module.MADE(**made_kwargs)
    super(MaskedAffineAutoregressive, self).__init__(made)

coupling

AffineConstFlow

Bases: Flow

Scales and shifts with learned constants per dimension. The scaling layer of the NICE paper is a special case of this in which t is None.

Source code in normflows/flows/affine/coupling.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class AffineConstFlow(Flow):
    """
    Scales and shifts with learned constants per dimension. The scaling layer of
    the NICE paper is a special case of this in which t is None
    """

    def __init__(self, shape, scale=True, shift=True):
        """Constructor

        Args:
          shape: Shape of the coupling layer
          scale: Flag whether to apply scaling; if False, ``s`` is a fixed zero buffer
          shift: Flag whether to apply shift; if False, ``t`` is a fixed zero buffer
        """
        super().__init__()
        for name, learnable in (("s", scale), ("t", shift)):
            # A fresh tensor per attribute, so s and t never share storage
            initial = torch.zeros(shape)[None]
            if learnable:
                setattr(self, name, nn.Parameter(initial))
            else:
                self.register_buffer(name, initial)
        self.n_dim = self.s.dim()
        # Singleton axes of s (always including the leading batch axis); s is
        # broadcast along them, so the log-det scales with z's size there.
        self.batch_dims = torch.nonzero(
            torch.tensor(self.s.shape) == 1, as_tuple=False
        )[:, 0].tolist()

    def _broadcast_count(self, z):
        """Product of z's sizes along the non-batch singleton axes of s (1 if none)"""
        if len(self.batch_dims) > 1:
            return np.prod([z.size(i) for i in self.batch_dims[1:]])
        return 1

    def forward(self, z):
        z_ = z * torch.exp(self.s) + self.t
        log_det = self._broadcast_count(z) * torch.sum(self.s)
        return z_, log_det

    def inverse(self, z):
        z_ = (z - self.t) * torch.exp(-self.s)
        log_det = -self._broadcast_count(z) * torch.sum(self.s)
        return z_, log_det
__init__(shape, scale=True, shift=True)

Constructor

Parameters:

Name Type Description Default
shape

Shape of the coupling layer

required
scale

Flag whether to apply scaling

True
shift

Flag whether to apply shift

True
logscale_factor

Listed in the docstring, but not part of the constructor signature above; passing it raises a TypeError.

n/a
Source code in normflows/flows/affine/coupling.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(self, shape, scale=True, shift=True):
    """Constructor

    Args:
      shape: Shape of the coupling layer
      scale: Flag whether to apply scaling; if False, ``s`` is a fixed zero buffer
      shift: Flag whether to apply shift; if False, ``t`` is a fixed zero buffer
    """
    super().__init__()
    for name, learnable in (("s", scale), ("t", shift)):
        # A fresh tensor per attribute, so s and t never share storage
        initial = torch.zeros(shape)[None]
        if learnable:
            setattr(self, name, nn.Parameter(initial))
        else:
            self.register_buffer(name, initial)
    self.n_dim = self.s.dim()
    # Singleton axes of s, always including the leading batch axis added by [None]
    self.batch_dims = torch.nonzero(
        torch.tensor(self.s.shape) == 1, as_tuple=False
    )[:, 0].tolist()
AffineCoupling

Bases: Flow

Affine coupling layer as introduced in the RealNVP paper, see arXiv: 1605.08803

Source code in normflows/flows/affine/coupling.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class AffineCoupling(Flow):
    """
    Affine coupling layer as introduced in the RealNVP paper, see arXiv: 1605.08803
    """

    def __init__(self, param_map, scale=True, scale_map="exp"):
        """Constructor

        Args:
          param_map: Maps features to shift and scale parameter (if applicable)
          scale: Flag whether scale shall be applied
          scale_map: Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow, 'sigmoid_inv' uses multiplicative sigmoid scale when sampling from the model
        """
        super().__init__()
        self.add_module("param_map", param_map)
        self.scale, self.scale_map = scale, scale_map

    def forward(self, z):
        """
        z is a list of z1 and z2; ```z = [z1, z2]```
        z1 is left constant and affine map is applied to z2 with parameters depending
        on z1. When scaling is enabled, the even channels of ``param_map(z1)`` are the
        shift and the odd channels the raw scale.

        Args:
          z: list ``[z1, z2]``

        Returns:
          ``[z1, z2_transformed]`` and the per-sample log-determinant

        Raises:
          NotImplementedError: for an unknown ``scale_map``
        """
        z1, z2 = z
        param = self.param_map(z1)
        if not self.scale:
            z2 = z2 + param
            return [z1, z2], zero_log_det_like_z(z2)
        shift, raw_scale = param[:, 0::2, ...], param[:, 1::2, ...]
        sum_dims = list(range(1, shift.dim()))
        if self.scale_map == "exp":
            z2 = z2 * torch.exp(raw_scale) + shift
            log_det = torch.sum(raw_scale, dim=sum_dims)
        elif self.scale_map == "sigmoid":
            scale = torch.sigmoid(raw_scale + 2)
            z2 = z2 / scale + shift
            log_det = -torch.sum(torch.log(scale), dim=sum_dims)
        elif self.scale_map == "sigmoid_inv":
            scale = torch.sigmoid(raw_scale + 2)
            z2 = z2 * scale + shift
            log_det = torch.sum(torch.log(scale), dim=sum_dims)
        else:
            raise NotImplementedError("This scale map is not implemented.")
        return [z1, z2], log_det

    def inverse(self, z):
        """Undo ``forward``: recover ``[z1, z2]`` and return the negated log-determinant"""
        z1, z2 = z
        param = self.param_map(z1)
        if not self.scale:
            z2 = z2 - param
            return [z1, z2], zero_log_det_like_z(z2)
        shift, raw_scale = param[:, 0::2, ...], param[:, 1::2, ...]
        sum_dims = list(range(1, shift.dim()))
        if self.scale_map == "exp":
            z2 = (z2 - shift) * torch.exp(-raw_scale)
            log_det = -torch.sum(raw_scale, dim=sum_dims)
        elif self.scale_map == "sigmoid":
            scale = torch.sigmoid(raw_scale + 2)
            z2 = (z2 - shift) * scale
            log_det = torch.sum(torch.log(scale), dim=sum_dims)
        elif self.scale_map == "sigmoid_inv":
            scale = torch.sigmoid(raw_scale + 2)
            z2 = (z2 - shift) / scale
            log_det = -torch.sum(torch.log(scale), dim=sum_dims)
        else:
            raise NotImplementedError("This scale map is not implemented.")
        return [z1, z2], log_det
__init__(param_map, scale=True, scale_map='exp')

Constructor

Parameters:

Name Type Description Default
param_map

Maps features to shift and scale parameter (if applicable)

required
scale

Flag whether scale shall be applied

True
scale_map

Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow, 'sigmoid_inv' uses multiplicative sigmoid scale when sampling from the model

'exp'
Source code in normflows/flows/affine/coupling.py
104
105
106
107
108
109
110
111
112
113
114
115
def __init__(self, param_map, scale=True, scale_map="exp"):
    """Constructor

    Args:
      param_map: Maps features to shift and scale parameter (if applicable); registered as the submodule ``param_map``
      scale: Flag whether scale shall be applied
      scale_map: Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow, 'sigmoid_inv' uses multiplicative sigmoid scale when sampling from the model
    """
    super().__init__()
    self.add_module("param_map", param_map)
    self.scale, self.scale_map = scale, scale_map
forward(z)

z is a list of z1 and z2, i.e. z = [z1, z2]. z1 is left unchanged, and an affine map is applied to z2 with parameters depending on z1.

Source code in normflows/flows/affine/coupling.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def forward(self, z):
    """
    z is a list of z1 and z2; ```z = [z1, z2]```
    z1 is left constant and affine map is applied to z2 with parameters depending
    on z1. When scaling is enabled, the even channels of ``param_map(z1)`` are the
    shift and the odd channels the raw scale.

    Args:
      z: list ``[z1, z2]``

    Returns:
      ``[z1, z2_transformed]`` and the per-sample log-determinant

    Raises:
      NotImplementedError: for an unknown ``scale_map``
    """
    z1, z2 = z
    param = self.param_map(z1)
    if not self.scale:
        z2 = z2 + param
        return [z1, z2], zero_log_det_like_z(z2)
    shift, raw_scale = param[:, 0::2, ...], param[:, 1::2, ...]
    sum_dims = list(range(1, shift.dim()))
    if self.scale_map == "exp":
        z2 = z2 * torch.exp(raw_scale) + shift
        log_det = torch.sum(raw_scale, dim=sum_dims)
    elif self.scale_map == "sigmoid":
        scale = torch.sigmoid(raw_scale + 2)
        z2 = z2 / scale + shift
        log_det = -torch.sum(torch.log(scale), dim=sum_dims)
    elif self.scale_map == "sigmoid_inv":
        scale = torch.sigmoid(raw_scale + 2)
        z2 = z2 * scale + shift
        log_det = torch.sum(torch.log(scale), dim=sum_dims)
    else:
        raise NotImplementedError("This scale map is not implemented.")
    return [z1, z2], log_det
AffineCouplingBlock

Bases: Flow

Affine Coupling layer including split and merge operation

Source code in normflows/flows/affine/coupling.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
class AffineCouplingBlock(Flow):
    """
    Affine Coupling layer including split and merge operation
    """

    def __init__(self, param_map, scale=True, scale_map="exp", split_mode="channel"):
        """Constructor

        Args:
          param_map: Maps features to shift and scale parameter (if applicable)
          scale: Flag whether scale shall be applied
          scale_map: Map to be applied to the scale parameter, passed on to AffineCoupling ('exp', 'sigmoid' or 'sigmoid_inv')
          split_mode: Splitting mode, for possible values see Split class
        """
        super().__init__()
        # Applied in this order by forward() and in reverse by inverse()
        self.flows = nn.ModuleList(
            [
                Split(split_mode),
                AffineCoupling(param_map, scale, scale_map),
                Merge(split_mode),
            ]
        )

    def forward(self, z):
        """Split, couple and merge ``z``; returns the result and the summed log-determinant"""
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_tot += log_det
        return z, log_det_tot

    def inverse(self, z):
        """Run the inverses of the sub-flows in reverse order"""
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_det_tot += log_det
        return z, log_det_tot
__init__(param_map, scale=True, scale_map='exp', split_mode='channel')

Constructor

Parameters:

Name Type Description Default
param_map

Maps features to shift and scale parameter (if applicable)

required
scale

Flag whether scale shall be applied

True
scale_map

Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow

'exp'
split_mode

Splitting mode, for possible values see Split class

'channel'
Source code in normflows/flows/affine/coupling.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def __init__(self, param_map, scale=True, scale_map="exp", split_mode="channel"):
    """Constructor

    Args:
      param_map: Maps features to shift and scale parameter (if applicable)
      scale: Flag whether scale shall be applied
      scale_map: Map to be applied to the scale parameter, passed on to AffineCoupling ('exp', 'sigmoid' or 'sigmoid_inv')
      split_mode: Splitting mode, for possible values see Split class
    """
    super().__init__()
    # Split -> affine coupling -> merge, in application order
    self.flows = nn.ModuleList(
        [
            Split(split_mode),
            AffineCoupling(param_map, scale, scale_map),
            Merge(split_mode),
        ]
    )
CCAffineConst

Bases: Flow

Affine constant flow layer with class-conditional parameters

Source code in normflows/flows/affine/coupling.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class CCAffineConst(Flow):
    """
    Affine constant flow layer with class-conditional parameters
    """

    def __init__(self, shape, num_classes):
        """Constructor

        Args:
          shape: Shape of one sample (int or tuple), excluding the batch axis
          num_classes: Number of rows of the per-class offset tables ``s_cc``/``t_cc``
        """
        super().__init__()
        self.shape = (shape,) if isinstance(shape, int) else shape
        num_features = np.prod(self.shape)
        self.s = nn.Parameter(torch.zeros(self.shape)[None])
        self.t = nn.Parameter(torch.zeros(self.shape)[None])
        self.s_cc = nn.Parameter(torch.zeros(num_classes, num_features))
        self.t_cc = nn.Parameter(torch.zeros(num_classes, num_features))
        self.n_dim = self.s.dim()
        # Singleton axes of s (always including the leading batch axis)
        self.batch_dims = torch.nonzero(
            torch.tensor(self.s.shape) == 1, as_tuple=False
        )[:, 0].tolist()

    def _conditional_params(self, y):
        """Per-example log-scale and shift: shared constants plus ``y``-weighted class offsets.

        ``y`` is matrix-multiplied with the ``(num_classes, prod(shape))`` tables;
        presumably a one-hot class matrix of shape (batch, num_classes) -- TODO confirm.
        """
        s = self.s + (y @ self.s_cc).view(-1, *self.shape)
        t = self.t + (y @ self.t_cc).view(-1, *self.shape)
        return s, t

    def _broadcast_count(self, z):
        """Product of z's sizes along the non-batch singleton axes of s (1 if none)"""
        if len(self.batch_dims) > 1:
            return np.prod([z.size(i) for i in self.batch_dims[1:]])
        return 1

    def forward(self, z, y):
        s, t = self._conditional_params(y)
        z_ = z * torch.exp(s) + t
        log_det = self._broadcast_count(z) * torch.sum(s, dim=list(range(1, self.n_dim)))
        return z_, log_det

    def inverse(self, z, y):
        s, t = self._conditional_params(y)
        z_ = (z - t) * torch.exp(-s)
        log_det = -self._broadcast_count(z) * torch.sum(s, dim=list(range(1, self.n_dim)))
        return z_, log_det
MaskedAffineFlow

Bases: Flow

RealNVP as introduced in arXiv: 1605.08803

Masked affine flow:

f(z) = b * z + (1 - b) * (z * exp(s(b * z)) + t)
  • AffineHalfFlow is a MaskedAffineFlow with an alternating bit mask
  • NICE is an AffineFlow with only shifts (volume preserving)
Source code in normflows/flows/affine/coupling.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class MaskedAffineFlow(Flow):
    """RealNVP as introduced in [arXiv: 1605.08803](https://arxiv.org/abs/1605.08803)

    Masked affine flow:

    ```
    f(z) = b * z + (1 - b) * (z * exp(s(b * z)) + t)
    ```

    - AffineHalfFlow is a MaskedAffineFlow with an alternating bit mask
    - NICE is an AffineFlow with only shifts (volume preserving)
    """

    def __init__(self, b, t=None, s=None):
        """Constructor

        Args:
          b: mask for features, i.e. tensor of same size as latent data point filled with 0s and 1s
          t: translation mapping, i.e. neural network, where first input dimension is batch dim, if None no translation is applied
          s: scale mapping, i.e. neural network, where first input dimension is batch dim, if None no scale is applied
        """
        super().__init__()
        # Mask with a leading singleton batch axis; b_cpu keeps a plain reference
        # to the tensor originally registered as the buffer ``b``.
        self.b_cpu = b.view(1, *b.size())
        self.register_buffer("b", self.b_cpu)

        # A missing network becomes torch.zeros_like: zero log-scale / zero shift
        for name, net in (("s", s), ("t", t)):
            if net is None:
                setattr(self, name, torch.zeros_like)
            else:
                self.add_module(name, net)

    def _masked_scale_and_trans(self, z):
        """Return ``(b * z, s(b * z), t(b * z))`` with non-finite entries of s and t set to nan"""
        z_masked = self.b * z
        nan = torch.tensor(np.nan, dtype=z.dtype, device=z.device)
        scale = self.s(z_masked)
        scale = torch.where(torch.isfinite(scale), scale, nan)
        trans = self.t(z_masked)
        trans = torch.where(torch.isfinite(trans), trans, nan)
        return z_masked, scale, trans

    def forward(self, z):
        z_masked, scale, trans = self._masked_scale_and_trans(z)
        z_ = z_masked + (1 - self.b) * (z * torch.exp(scale) + trans)
        log_det = torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim())))
        return z_, log_det

    def inverse(self, z):
        z_masked, scale, trans = self._masked_scale_and_trans(z)
        z_ = z_masked + (1 - self.b) * (z - trans) * torch.exp(-scale)
        log_det = -torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim())))
        return z_, log_det
__init__(b, t=None, s=None)

Constructor

Parameters:

Name Type Description Default
b

mask for features, i.e. tensor of same size as latent data point filled with 0s and 1s

required
t

translation mapping, i.e. neural network, where first input dimension is batch dim, if None no translation is applied

None
s

scale mapping, i.e. neural network, where first input dimension is batch dim, if None no scale is applied

None
Source code in normflows/flows/affine/coupling.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def __init__(self, b, t=None, s=None):
    """Constructor

    Args:
      b: mask for features, i.e. tensor of same size as latent data point filled with 0s and 1s
      t: translation mapping, i.e. neural network, where first input dimension is batch dim, if None no translation is applied
      s: scale mapping, i.e. neural network, where first input dimension is batch dim, if None no scale is applied
    """
    super().__init__()
    # Mask with a leading singleton batch axis; b_cpu keeps a plain reference to
    # the tensor originally registered as the buffer ``b``.
    self.b_cpu = b.view(1, *b.size())
    self.register_buffer("b", self.b_cpu)

    # A missing network becomes torch.zeros_like: zero log-scale / zero shift
    for name, net in (("s", s), ("t", t)):
        if net is None:
            setattr(self, name, torch.zeros_like)
        else:
            self.add_module(name, net)

glow

GlowBlock

Bases: Flow

Glow: Generative Flow with Invertible 1×1 Convolutions, arXiv: 1807.03039

One Block of the Glow model, comprised of

  • MaskedAffineFlow (affine coupling layer)
  • Invertible1x1Conv (dropped if there is only one channel)
  • ActNorm (first batch used for initialization)
Source code in normflows/flows/affine/glow.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class GlowBlock(Flow):
    """Glow: Generative Flow with Invertible 1×1 Convolutions, [arXiv: 1807.03039](https://arxiv.org/abs/1807.03039)

    One Block of the Glow model, comprised of

    - MaskedAffineFlow (affine coupling layer)
    - Invertible1x1Conv (dropped if there is only one channel)
    - ActNorm (first batch used for initialization)
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        scale=True,
        scale_map="sigmoid",
        split_mode="channel",
        leaky=0.0,
        init_zeros=True,
        use_lu=True,
        net_actnorm=False,
    ):
        """Constructor

        Args:
          channels: Number of channels of the data
          hidden_channels: number of channels in the hidden layer of the ConvNet
          scale: Flag, whether to include scale in affine coupling layer
          scale_map: Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow
          split_mode: Splitting mode, for possible values see Split class
          leaky: Leaky parameter of LeakyReLUs of ConvNet2d
          init_zeros: Flag whether to initialize last conv layer with zeros
          use_lu: Flag whether to parametrize weights through the LU decomposition in invertible 1x1 convolution layers
          net_actnorm: Flag whether the ConvNet2d parameter network uses activation normalization (passed to it as `actnorm`)

        Raises:
          NotImplementedError: If split_mode is neither 'channel', 'channel_inv' nor contains 'checkerboard'
        """
        super().__init__()
        self.flows = nn.ModuleList([])
        # Coupling layer
        # NOTE(review): presumably one kernel size per conv layer of ConvNet2d
        # (channels_ below has four entries, i.e. three layers) - confirm.
        kernel_size = (3, 1, 3)
        # Output parameters per transformed channel: shift, plus scale if enabled
        num_param = 2 if scale else 1
        # channels_ = (input, hidden, hidden, output) widths of the ConvNet; the
        # input/output sizes depend on how the split divides the channels.
        if "channel" == split_mode:
            channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,)
            channels_ += (num_param * (channels // 2),)
        elif "channel_inv" == split_mode:
            channels_ = (channels // 2,) + 2 * (hidden_channels,)
            channels_ += (num_param * ((channels + 1) // 2),)
        elif "checkerboard" in split_mode:
            channels_ = (channels,) + 2 * (hidden_channels,)
            channels_ += (num_param * channels,)
        else:
            raise NotImplementedError("Mode " + split_mode + " is not implemented.")
        param_map = nets.ConvNet2d(
            channels_, kernel_size, leaky, init_zeros, actnorm=net_actnorm
        )
        self.flows += [AffineCouplingBlock(param_map, scale, scale_map, split_mode)]
        # Invertible 1x1 convolution (a 1x1 mixing of a single channel is pointless)
        if channels > 1:
            self.flows += [Invertible1x1Conv(channels, use_lu)]
        # Activation normalization
        self.flows += [ActNorm((channels,) + (1, 1))]

    def forward(self, z):
        """Apply the sub-flows in order, summing their log-determinants.

        Args:
          z: Input tensor, first dimension is batch dim

        Returns:
          Transformed z and log of absolute determinant per batch element
        """
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_tot += log_det
        return z, log_det_tot

    def inverse(self, z):
        """Invert the sub-flows in reverse order, summing their log-determinants.

        Args:
          z: Input tensor, first dimension is batch dim

        Returns:
          Inverse-transformed z and log of absolute determinant per batch element
        """
        log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_det_tot += log_det
        return z, log_det_tot
__init__(channels, hidden_channels, scale=True, scale_map='sigmoid', split_mode='channel', leaky=0.0, init_zeros=True, use_lu=True, net_actnorm=False)

Constructor

Parameters:

Name Type Description Default
channels

Number of channels of the data

required
hidden_channels

number of channels in the hidden layer of the ConvNet

required
scale

Flag, whether to include scale in affine coupling layer

True
scale_map

Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow

'sigmoid'
split_mode

Splitting mode, for possible values see Split class

'channel'
leaky

Leaky parameter of LeakyReLUs of ConvNet2d

0.0
init_zeros

Flag whether to initialize last conv layer with zeros

True
use_lu

Flag whether to parametrize weights through the LU decomposition in invertible 1x1 convolution layers

True
net_actnorm

Flag whether the ConvNet2d parameter network uses activation normalization (passed to it as actnorm)

False
Source code in normflows/flows/affine/glow.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def __init__(
    self,
    channels,
    hidden_channels,
    scale=True,
    scale_map="sigmoid",
    split_mode="channel",
    leaky=0.0,
    init_zeros=True,
    use_lu=True,
    net_actnorm=False,
):
    """Constructor

    Args:
      channels: Number of channels of the data
      hidden_channels: number of channels in the hidden layer of the ConvNet
      scale: Flag, whether to include scale in affine coupling layer
      scale_map: Map to be applied to the scale parameter, can be 'exp' as in RealNVP or 'sigmoid' as in Glow
      split_mode: Splitting mode, for possible values see Split class
      leaky: Leaky parameter of LeakyReLUs of ConvNet2d
      init_zeros: Flag whether to initialize last conv layer with zeros
      use_lu: Flag whether to parametrize weights through the LU decomposition in invertible 1x1 convolution layers
      net_actnorm: Flag whether the ConvNet2d parameter network uses activation normalization (passed to it as `actnorm`)

    Raises:
      NotImplementedError: If split_mode is neither 'channel', 'channel_inv' nor contains 'checkerboard'
    """
    super().__init__()
    self.flows = nn.ModuleList([])
    # Coupling layer
    # NOTE(review): presumably one kernel size per conv layer of ConvNet2d
    # (channels_ below has four entries, i.e. three layers) - confirm.
    kernel_size = (3, 1, 3)
    # Output parameters per transformed channel: shift, plus scale if enabled
    num_param = 2 if scale else 1
    # channels_ = (input, hidden, hidden, output) widths of the ConvNet; the
    # input/output sizes depend on how the split divides the channels.
    if "channel" == split_mode:
        channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,)
        channels_ += (num_param * (channels // 2),)
    elif "channel_inv" == split_mode:
        channels_ = (channels // 2,) + 2 * (hidden_channels,)
        channels_ += (num_param * ((channels + 1) // 2),)
    elif "checkerboard" in split_mode:
        channels_ = (channels,) + 2 * (hidden_channels,)
        channels_ += (num_param * channels,)
    else:
        raise NotImplementedError("Mode " + split_mode + " is not implemented.")
    param_map = nets.ConvNet2d(
        channels_, kernel_size, leaky, init_zeros, actnorm=net_actnorm
    )
    self.flows += [AffineCouplingBlock(param_map, scale, scale_map, split_mode)]
    # Invertible 1x1 convolution (a 1x1 mixing of a single channel is pointless)
    if channels > 1:
        self.flows += [Invertible1x1Conv(channels, use_lu)]
    # Activation normalization
    self.flows += [ActNorm((channels,) + (1, 1))]

base

Composite

Bases: Flow

Composes several flows into one, in the order they are given.

Source code in normflows/flows/base.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class Composite(Flow):
    """
    Composes several flows into one, in the order they are given.
    """

    def __init__(self, flows):
        """Constructor

        Args:
          flows: Iterable of flows to composite
        """
        super().__init__()
        self._flows = nn.ModuleList(flows)

    @staticmethod
    def _cascade(inputs, funcs):
        """Apply ``funcs`` one after another, summing their log-determinants.

        Args:
          inputs: Input tensor, first dimension is batch dim
          funcs: Iterable of callables mapping a tensor to ``(outputs, logabsdet)``

        Returns:
          Final outputs and total log absolute determinant per batch element
        """
        batch_size = inputs.shape[0]
        outputs = inputs
        # Allocate the accumulator with the inputs' dtype and device: a default
        # CPU/float32 tensor fails for GPU inputs and silently downcasts float64
        # log-determinants.
        total_logabsdet = torch.zeros(
            batch_size, dtype=inputs.dtype, device=inputs.device
        )
        for func in funcs:
            outputs, logabsdet = func(outputs)
            total_logabsdet += logabsdet
        return outputs, total_logabsdet

    def forward(self, inputs):
        """Apply the composed flows in the given order."""
        funcs = self._flows
        return self._cascade(inputs, funcs)

    def inverse(self, inputs):
        """Apply the inverses of the composed flows in reverse order."""
        funcs = (flow.inverse for flow in self._flows[::-1])
        return self._cascade(inputs, funcs)
__init__(flows)

Constructor

Parameters:

Name Type Description Default
flows

Iterable of flows to composite

required
Source code in normflows/flows/base.py
53
54
55
56
57
58
59
60
def __init__(self, flows):
    """Build the composite flow.

    Args:
      flows: Iterable of flows to composite, applied in the given order
    """
    super().__init__()
    # nn.ModuleList registers every sub-flow so its parameters and buffers
    # are tracked by this module.
    flow_list = nn.ModuleList(flows)
    self._flows = flow_list

Flow

Bases: Module

Generic class for flow functions

Source code in normflows/flows/base.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Flow(nn.Module):
    """
    Base class for flow layers; subclasses override forward and, where
    an algebraic inverse exists, inverse
    """

    def __init__(self):
        super().__init__()

    def forward(self, z):
        """
        Args:
          z: input variable, first dimension is batch dim

        Returns:
          transformed z and log of absolute determinant
        """
        message = "Forward pass has not been implemented."
        raise NotImplementedError(message)

    def inverse(self, z):
        """Inverse transform; raises NotImplementedError unless overridden."""
        message = "This flow has no algebraic inverse."
        raise NotImplementedError(message)
forward(z)

Parameters:

Name Type Description Default
z

input variable, first dimension is batch dim

required

Returns:

Type Description

transformed z and log of absolute determinant

Source code in normflows/flows/base.py
13
14
15
16
17
18
19
20
21
def forward(self, z):
    """Placeholder forward transform that concrete flows must override.

    Args:
      z: input variable, first dimension is batch dim

    Returns:
      transformed z and log of absolute determinant

    Raises:
      NotImplementedError: always, as the generic flow defines no transform
    """
    error_message = "Forward pass has not been implemented."
    raise NotImplementedError(error_message)

Reverse

Bases: Flow

Switches the forward transform of a flow layer with its inverse and vice versa

Source code in normflows/flows/base.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Reverse(Flow):
    """
    Wraps a flow layer so that its forward and inverse transforms trade places
    """

    def __init__(self, flow):
        """Constructor

        Args:
          flow: Flow layer whose direction is to be reversed
        """
        super().__init__()
        self.flow = flow

    def inverse(self, z):
        # Inverting the reversed flow means running the wrapped flow forward
        return self.flow.forward(z)

    def forward(self, z):
        # Running the reversed flow forward means inverting the wrapped flow
        return self.flow.inverse(z)
__init__(flow)

Constructor

Parameters:

Name Type Description Default
flow

Flow layer to be reversed

required
Source code in normflows/flows/base.py
32
33
34
35
36
37
38
39
def __init__(self, flow):
    """Wrap ``flow`` so that its forward and inverse transforms are swapped.

    Args:
      flow: Flow layer to be reversed
    """
    super().__init__()
    # When flow is an nn.Module, this assignment also registers it as a submodule
    self.flow = flow

flow_test

FlowTest

Bases: TestCase

Generic test case for flow modules

Source code in normflows/flows/flow_test.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class FlowTest(unittest.TestCase):
    """
    Reusable checks for flow modules: dtype/shape preservation and
    forward/inverse round trips
    """

    def assertClose(self, actual, expected, atol=None, rtol=None):
        # Thin wrapper exposing torch.testing.assert_close as a TestCase method
        assert_close(actual, expected, atol=atol, rtol=rtol)

    def _apply_and_check(self, transform, inputs, context):
        """Call ``transform`` and verify its output keeps the dtype and shape of ``inputs``."""
        args = (inputs,) if context is None else (inputs, context)
        outputs, log_det = transform(*args)
        assert outputs.dtype == inputs.dtype
        assert outputs.shape == inputs.shape
        return outputs, log_det

    def checkForward(self, flow, inputs, context=None):
        # Calling the module dispatches to its forward transform
        return self._apply_and_check(flow, inputs, context)

    def checkInverse(self, flow, inputs, context=None):
        return self._apply_and_check(flow.inverse, inputs, context)

    def checkForwardInverse(self, flow, inputs, context=None, atol=None, rtol=None):
        """Check that inverse(forward(x)) recovers x and the two log-dets cancel."""
        outputs, log_det = self.checkForward(flow, inputs, context)
        recovered, log_det_inv = self.checkInverse(flow, outputs, context)
        self.assertClose(recovered, inputs, atol, rtol)
        log_det_sum = log_det + log_det_inv
        self.assertClose(log_det_sum, torch.zeros_like(log_det_sum), atol, rtol)

mixing

Invertible1x1Conv

Bases: Flow

Invertible 1x1 convolution introduced in the Glow paper. Assumes 4D input/output tensors of the form NCHW.

Source code in normflows/flows/mixing.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class Invertible1x1Conv(Flow):
    """
    Invertible 1x1 convolution introduced in the Glow paper.
    Assumes 4d input/output tensors of the form NCHW
    """

    def __init__(self, num_channels, use_lu=False):
        """Constructor

        Args:
          num_channels: Number of channels of the data
          use_lu: Flag whether to parametrize weights through the LU decomposition
        """
        super().__init__()
        self.num_channels = num_channels
        self.use_lu = use_lu
        # Random orthogonal initialization (Q factor of a Gaussian matrix), so the
        # initial weight matrix is invertible and perfectly conditioned
        Q, _ = torch.linalg.qr(torch.randn(self.num_channels, self.num_channels))
        if use_lu:
            P, L, U = torch.lu_unpack(*Q.lu())
            self.register_buffer("P", P)  # remains fixed during optimization
            self.L = nn.Parameter(L)  # lower triangular portion
            S = U.diag()  # "crop out" the diagonal to its own parameter
            self.register_buffer("sign_S", torch.sign(S))
            self.log_S = nn.Parameter(torch.log(torch.abs(S)))
            self.U = nn.Parameter(
                torch.triu(U, diagonal=1)
            )  # "crop out" diagonal, stored in S
            self.register_buffer("eye", torch.eye(self.num_channels))
        else:
            self.W = nn.Parameter(Q)

    def _assemble_W(self, inverse=False):
        """Assemble W = P @ L @ U from its LU factors.

        L gets a unit diagonal and U the diagonal sign_S * exp(log_S).

        Args:
          inverse: If True, return W^-1 = U^-1 @ L^-1 @ P^T instead

        Returns:
          Weight matrix of shape (num_channels, num_channels)
        """
        L = torch.tril(self.L, diagonal=-1) + self.eye
        U = torch.triu(self.U, diagonal=1) + torch.diag(
            self.sign_S * torch.exp(self.log_S)
        )
        if inverse:
            # Invert in double precision unless already float64, then cast back
            if self.log_S.dtype == torch.float64:
                L_inv = torch.inverse(L)
                U_inv = torch.inverse(U)
            else:
                L_inv = torch.inverse(L.double()).type(self.log_S.dtype)
                U_inv = torch.inverse(U.double()).type(self.log_S.dtype)
            W = U_inv @ L_inv @ self.P.t()
        else:
            W = self.P @ L @ U
        return W

    def forward(self, z):
        """Convolve z with the inverse weight matrix.

        Args:
          z: Input tensor of shape (N, C, H, W)

        Returns:
          Transformed z and log of absolute determinant; the latter is a 0-dim
          tensor shared by all batch elements
        """
        if self.use_lu:
            W = self._assemble_W(inverse=True)
            log_det = -torch.sum(self.log_S)
        else:
            W_dtype = self.W.dtype
            if W_dtype == torch.float64:
                W = torch.inverse(self.W)
            else:
                # Invert in double precision for accuracy, then cast back
                W = torch.inverse(self.W.double()).type(W_dtype)
            log_det = -torch.slogdet(self.W)[1]
        # conv2d kernel layout: (out_channels, in_channels, 1, 1)
        W = W.view(self.num_channels, self.num_channels, 1, 1)
        z_ = torch.nn.functional.conv2d(z, W)
        # The same per-pixel channel mixing is applied at all H * W positions
        log_det = log_det * z.size(2) * z.size(3)
        return z_, log_det

    def inverse(self, z):
        """Convolve z with the weight matrix W.

        Args:
          z: Input tensor of shape (N, C, H, W)

        Returns:
          Transformed z and log of absolute determinant; the latter is a 0-dim
          tensor shared by all batch elements
        """
        if self.use_lu:
            W = self._assemble_W()
            log_det = torch.sum(self.log_S)
        else:
            W = self.W
            log_det = torch.slogdet(self.W)[1]
        W = W.view(self.num_channels, self.num_channels, 1, 1)
        z_ = torch.nn.functional.conv2d(z, W)
        log_det = log_det * z.size(2) * z.size(3)
        return z_, log_det
__init__(num_channels, use_lu=False)

Constructor

Parameters:

Name Type Description Default
num_channels

Number of channels of the data

required
use_lu

Flag whether to parametrize weights through the LU decomposition

False
Source code in normflows/flows/mixing.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def __init__(self, num_channels, use_lu=False):
    """Initialise the channel-mixing weight matrix.

    Args:
      num_channels: Number of channels of the data
      use_lu: Flag whether to parametrize weights through the LU decomposition
    """
    super().__init__()
    self.num_channels = num_channels
    self.use_lu = use_lu
    n = self.num_channels
    # Start from a random orthogonal matrix (Q factor of a Gaussian matrix)
    Q, _ = torch.linalg.qr(torch.randn(n, n))
    if not use_lu:
        self.W = nn.Parameter(Q)
        return
    # LU parametrization: the permutation P stays fixed, while L, the strictly
    # upper part of U, and the log-magnitude of diag(U) are trained; the signs
    # of diag(U) are frozen.
    P, L, U = torch.lu_unpack(*Q.lu())
    self.register_buffer("P", P)
    self.L = nn.Parameter(L)
    diag_U = U.diag()
    self.register_buffer("sign_S", torch.sign(diag_U))
    self.log_S = nn.Parameter(torch.log(torch.abs(diag_U)))
    self.U = nn.Parameter(torch.triu(U, diagonal=1))
    self.register_buffer("eye", torch.eye(n))

InvertibleAffine

Bases: Flow

Invertible affine transformation without shift, i.e. one-dimensional version of the invertible 1x1 convolutions

Source code in normflows/flows/mixing.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class InvertibleAffine(Flow):
    """
    Shift-free invertible linear map z -> z @ W acting on the last dimension,
    i.e. the one-dimensional counterpart of the invertible 1x1 convolution
    """

    def __init__(self, num_channels, use_lu=True):
        """Constructor

        Args:
          num_channels: Number of channels of the data
          use_lu: Flag whether to parametrize weights through the LU decomposition
        """
        super().__init__()
        self.num_channels = num_channels
        self.use_lu = use_lu
        n = self.num_channels
        # Start from a random orthogonal matrix (Q factor of a Gaussian matrix)
        Q, _ = torch.linalg.qr(torch.randn(n, n))
        if not use_lu:
            self.W = nn.Parameter(Q)
            return
        # Fixed permutation P; trainable L, strictly-upper U and log|diag(U)|;
        # the signs of diag(U) are frozen.
        P, L, U = torch.lu_unpack(*Q.lu())
        self.register_buffer("P", P)
        self.L = nn.Parameter(L)
        diag_U = U.diag()
        self.register_buffer("sign_S", torch.sign(diag_U))
        self.log_S = nn.Parameter(torch.log(torch.abs(diag_U)))
        self.U = nn.Parameter(torch.triu(U, diagonal=1))
        self.register_buffer("eye", torch.eye(n))

    def _assemble_W(self, inverse=False):
        """Build W = P @ L @ U, or W^-1 = U^-1 @ L^-1 @ P^T when ``inverse`` is set."""
        lower = torch.tril(self.L, diagonal=-1) + self.eye
        diagonal = torch.diag(self.sign_S * torch.exp(self.log_S))
        upper = torch.triu(self.U, diagonal=1) + diagonal
        if not inverse:
            return self.P @ lower @ upper
        dtype = self.log_S.dtype
        if dtype == torch.float64:
            lower_inv = torch.inverse(lower)
            upper_inv = torch.inverse(upper)
        else:
            # Invert in double precision, then cast back to the parameter dtype
            lower_inv = torch.inverse(lower.double()).type(dtype)
            upper_inv = torch.inverse(upper.double()).type(dtype)
        return upper_inv @ lower_inv @ self.P.t()

    def forward(self, z, context=None):
        """Multiply z by W^-1; returns the result and a 0-dim log-determinant."""
        if self.use_lu:
            W = self._assemble_W(inverse=True)
            log_det = -torch.sum(self.log_S)
        else:
            dtype = self.W.dtype
            W = (
                torch.inverse(self.W)
                if dtype == torch.float64
                else torch.inverse(self.W.double()).type(dtype)
            )
            log_det = -torch.slogdet(self.W)[1]
        return z @ W, log_det

    def inverse(self, z, context=None):
        """Multiply z by W; returns the result and a 0-dim log-determinant."""
        if self.use_lu:
            W = self._assemble_W()
            log_det = torch.sum(self.log_S)
        else:
            W = self.W
            log_det = torch.slogdet(self.W)[1]
        return z @ W, log_det
__init__(num_channels, use_lu=True)

Constructor

Parameters:

Name Type Description Default
num_channels

Number of channels of the data

required
use_lu

Flag whether to parametrize weights through the LU decomposition

True
Source code in normflows/flows/mixing.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def __init__(self, num_channels, use_lu=True):
    """Initialise the weight matrix of the invertible linear map.

    Args:
      num_channels: Number of channels of the data
      use_lu: Flag whether to parametrize weights through the LU decomposition
    """
    super().__init__()
    self.num_channels = num_channels
    self.use_lu = use_lu
    n = self.num_channels
    # Start from a random orthogonal matrix (Q factor of a Gaussian matrix)
    Q, _ = torch.linalg.qr(torch.randn(n, n))
    if not use_lu:
        self.W = nn.Parameter(Q)
        return
    # Fixed permutation P; trainable L, strictly-upper U and log|diag(U)|;
    # the signs of diag(U) are frozen.
    P, L, U = torch.lu_unpack(*Q.lu())
    self.register_buffer("P", P)
    self.L = nn.Parameter(L)
    diag_U = U.diag()
    self.register_buffer("sign_S", torch.sign(diag_U))
    self.log_S = nn.Parameter(torch.log(torch.abs(diag_U)))
    self.U = nn.Parameter(torch.triu(U, diagonal=1))
    self.register_buffer("eye", torch.eye(n))

LULinearPermute

Bases: Flow

Fixed permutation combined with a linear transformation parametrized using the LU decomposition, used in https://arxiv.org/abs/1906.04032

Source code in normflows/flows/mixing.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class LULinearPermute(Flow):
    """
    Linear transformation parametrized through its LU decomposition, paired
    with a fixed permutation, as used in https://arxiv.org/abs/1906.04032
    """

    def __init__(self, num_channels, identity_init=True):
        """Constructor

        Args:
          num_channels: Number of dimensions of the data
          identity_init: Flag, whether to initialize linear transform as identity matrix
        """
        super().__init__()
        self.permutation = _RandomPermutation(num_channels)
        self.linear = _LULinear(num_channels, identity_init=identity_init)

    def forward(self, z, context=None):
        # Undo the linear map, then undo the permutation (mirror of inverse())
        x, log_det = self.linear.inverse(z, context=context)
        x, _ = self.permutation.inverse(x, context=context)
        return x, log_det.view(-1)

    def inverse(self, z, context=None):
        # Permute first, then apply the linear map; the permutation's
        # log-determinant is discarded
        x, _ = self.permutation(z, context=context)
        x, log_det = self.linear(x, context=context)
        return x, log_det.view(-1)
__init__(num_channels, identity_init=True)

Constructor

Parameters:

Name Type Description Default
num_channels

Number of dimensions of the data

required
identity_init

Flag, whether to initialize linear transform as identity matrix

True
Source code in normflows/flows/mixing.py
541
542
543
544
545
546
547
548
549
550
551
552
553
def __init__(self, num_channels, identity_init=True):
    """Create the permutation and the LU-parametrized linear layer.

    Args:
      num_channels: Number of dimensions of the data
      identity_init: Flag, whether to initialize linear transform as identity matrix
    """
    super().__init__()
    # Both submodules act on num_channels features; the permutation is built first
    self.permutation = _RandomPermutation(num_channels)
    self.linear = _LULinear(num_channels, identity_init=identity_init)

Permute

Bases: Flow

Permutes features along the channel dimension

Source code in normflows/flows/mixing.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Permute(Flow):
    """
    Permutes features along the channel dimension
    """

    def __init__(self, num_channels, mode="shuffle"):
        """Constructor

        Args:
          num_channels: Number of channels
          mode: Mode of permuting features, can be shuffle for random permutation or swap for interchanging upper and lower part
        """
        super().__init__()
        self.mode = mode
        self.num_channels = num_channels
        if self.mode == "shuffle":
            perm = torch.randperm(self.num_channels)
            # inv_perm[perm[i]] = i, so indexing with inv_perm undoes perm
            inv_perm = torch.empty_like(perm).scatter_(
                dim=0, index=perm, src=torch.arange(self.num_channels)
            )
            self.register_buffer("perm", perm)
            self.register_buffer("inv_perm", inv_perm)

    def forward(self, z, context=None):
        """Permute the channels (dim 1) of z.

        Returns:
          Permuted z and a zero log-determinant per batch element

        Raises:
          NotImplementedError: If the mode is neither 'shuffle' nor 'swap'
        """
        if self.mode == "shuffle":
            z = z[:, self.perm, ...]
        elif self.mode == "swap":
            # Move the last ceil(C/2) channels in front of the first floor(C/2)
            z1 = z[:, : self.num_channels // 2, ...]
            z2 = z[:, self.num_channels // 2 :, ...]
            z = torch.cat([z2, z1], dim=1)
        else:
            raise NotImplementedError("The mode " + self.mode + " is not implemented.")
        # Permutations are volume preserving. Match z's dtype and device so the
        # zero log-det accumulates cleanly with other flows' log-dets.
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        return z, log_det

    def inverse(self, z, context=None):
        """Undo the channel permutation applied by forward.

        Returns:
          Un-permuted z and a zero log-determinant per batch element

        Raises:
          NotImplementedError: If the mode is neither 'shuffle' nor 'swap'
        """
        if self.mode == "shuffle":
            z = z[:, self.inv_perm, ...]
        elif self.mode == "swap":
            # forward put ceil(C/2) channels first, so split there to undo it
            z1 = z[:, : (self.num_channels + 1) // 2, ...]
            z2 = z[:, (self.num_channels + 1) // 2 :, ...]
            z = torch.cat([z2, z1], dim=1)
        else:
            raise NotImplementedError("The mode " + self.mode + " is not implemented.")
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        return z, log_det
__init__(num_channels, mode='shuffle')

Constructor

Parameters:

Name Type Description Default
num_channels

Number of channels

required
mode

Mode of permuting features, can be shuffle for random permutation or swap for interchanging upper and lower part

'shuffle'
Source code in normflows/flows/mixing.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(self, num_channels, mode="shuffle"):
    """Constructor

    Args:
      num_channels: Number of channels
      mode: Mode of permuting features, can be shuffle for random permutation or swap for interchanging upper and lower part
    """
    super().__init__()
    self.mode = mode
    self.num_channels = num_channels
    if self.mode == "shuffle":
        perm = torch.randperm(self.num_channels)
        # inv_perm[perm[i]] = i, so indexing with inv_perm undoes perm
        inv_perm = torch.empty_like(perm).scatter_(
            dim=0, index=perm, src=torch.arange(self.num_channels)
        )
        self.register_buffer("perm", perm)
        self.register_buffer("inv_perm", inv_perm)

neural_spline

autoregressive

Implementations of autoregressive transforms. Code taken from https://github.com/bayesiains/nsf

autoregressive_test

Tests for the autoregressive transforms. Code partially taken from https://github.com/bayesiains/nsf

coupling

Implementations of various coupling layers. Code taken from https://github.com/bayesiains/nsf

Coupling

Bases: Flow

A base class for coupling layers. Supports 2D inputs (NxD), as well as 4D inputs for images (NxCxHxW). For images the splitting is done on the channel dimension, using the provided 1D mask.

Source code in normflows/flows/neural_spline/coupling.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class Coupling(Flow):
    """A base class for coupling layers. Supports 2D inputs (NxD), as well as 4D inputs for
    images (NxCxHxW). For images the splitting is done on the channel dimension, using the
    provided 1D mask.

    Subclasses implement `_transform_dim_multiplier`, `_coupling_transform_forward`
    and `_coupling_transform_inverse`."""

    def __init__(self, mask, transform_net_create_fn, unconditional_transform=None):
        """Constructor.

        mask: a 1-dim tensor, tuple or list. It indexes inputs as follows:

        - if `mask[i] > 0`, `input[i]` will be transformed.
        - if `mask[i] <= 0`, `input[i]` will be passed unchanged.

        Args:
          mask: 1-dim tensor, tuple or list selecting the transformed features (see above)
          transform_net_create_fn: Callable invoked as
            `transform_net_create_fn(num_identity_features, num_transform_features * self._transform_dim_multiplier())`;
            the returned network is called as `net(identity_split, context)` to produce
            the parameters of the coupling transform
          unconditional_transform: Optional callable invoked as
            `unconditional_transform(features=num_identity_features)`; the resulting
            flow is applied to the features that the coupling passes through, and its
            log-determinant is added to the total

        Raises:
          ValueError: If `mask` is not 1-dimensional or is empty
        """
        mask = torch.as_tensor(mask)
        if mask.dim() != 1:
            raise ValueError("Mask must be a 1-dim tensor.")
        if mask.numel() <= 0:
            raise ValueError("Mask can't be empty.")

        super().__init__()
        self.features = len(mask)
        features_vector = torch.arange(self.features)

        # Index buffers (they follow the module across devices)
        self.register_buffer(
            "identity_features", features_vector.masked_select(mask <= 0)
        )
        self.register_buffer(
            "transform_features", features_vector.masked_select(mask > 0)
        )

        assert self.num_identity_features + self.num_transform_features == self.features

        self.transform_net = transform_net_create_fn(
            self.num_identity_features,
            self.num_transform_features * self._transform_dim_multiplier(),
        )

        if unconditional_transform is None:
            self.unconditional_transform = None
        else:
            self.unconditional_transform = unconditional_transform(
                features=self.num_identity_features
            )

    @property
    def num_identity_features(self):
        return len(self.identity_features)

    @property
    def num_transform_features(self):
        return len(self.transform_features)

    def _check_inputs(self, inputs):
        """Raise ValueError unless inputs is 2D/4D with `self.features` entries in dim 1."""
        if inputs.dim() not in [2, 4]:
            raise ValueError("Inputs must be a 2D or a 4D tensor.")

        if inputs.shape[1] != self.features:
            raise ValueError(
                "Expected features = {}, got {}.".format(self.features, inputs.shape[1])
            )

    def forward(self, inputs, context=None):
        """Transform the masked features conditioned on the unchanged ones.

        Returns:
          Outputs of the same shape as inputs and the log absolute determinant

        Raises:
          ValueError: If inputs is not 2D/4D or has the wrong number of features
        """
        self._check_inputs(inputs)

        identity_split = inputs[:, self.identity_features, ...]
        transform_split = inputs[:, self.transform_features, ...]

        # Parameters are computed from the *original* identity features, which
        # the inverse pass recovers before recomputing them
        transform_params = self.transform_net(identity_split, context)
        transform_split, logabsdet = self._coupling_transform_forward(
            inputs=transform_split, transform_params=transform_params
        )

        if self.unconditional_transform is not None:
            identity_split, logabsdet_identity = self.unconditional_transform(
                identity_split, context
            )
            # Out-of-place add so tensors returned by subclasses are not mutated
            logabsdet = logabsdet + logabsdet_identity

        outputs = torch.empty_like(inputs)
        outputs[:, self.identity_features, ...] = identity_split
        outputs[:, self.transform_features, ...] = transform_split

        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Invert `forward`.

        Returns:
          Outputs of the same shape as inputs and the log absolute determinant

        Raises:
          ValueError: If inputs is not 2D/4D or has the wrong number of features
        """
        self._check_inputs(inputs)

        identity_split = inputs[:, self.identity_features, ...]
        transform_split = inputs[:, self.transform_features, ...]

        logabsdet = 0.0
        # Recover the original identity features first, since the transform
        # parameters were computed from them in the forward pass
        if self.unconditional_transform is not None:
            identity_split, logabsdet = self.unconditional_transform.inverse(
                identity_split, context
            )

        transform_params = self.transform_net(identity_split, context)
        transform_split, logabsdet_split = self._coupling_transform_inverse(
            inputs=transform_split, transform_params=transform_params
        )
        # Out-of-place add so tensors returned by subclasses are not mutated
        logabsdet = logabsdet + logabsdet_split

        outputs = torch.empty_like(inputs)
        outputs[:, self.identity_features, ...] = identity_split
        outputs[:, self.transform_features, ...] = transform_split

        return outputs, logabsdet

    def _transform_dim_multiplier(self):
        """Number of features to output for each transform dimension."""
        raise NotImplementedError()

    def _coupling_transform_forward(self, inputs, transform_params):
        """Forward pass of the coupling transform."""
        raise NotImplementedError()

    def _coupling_transform_inverse(self, inputs, transform_params):
        """Inverse of the coupling transform."""
        raise NotImplementedError()
__init__(mask, transform_net_create_fn, unconditional_transform=None)

Constructor.

mask: a 1-dim tensor, tuple or list. It indexes inputs as follows:

  • if mask[i] > 0, input[i] will be transformed.
  • if mask[i] <= 0, input[i] will be passed unchanged.
Source code in normflows/flows/neural_spline/coupling.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(self, mask, transform_net_create_fn, unconditional_transform=None):
    """Constructor.

    mask: a 1-dim tensor, tuple or list. It indexes inputs as follows:

    - if `mask[i] > 0`, `input[i]` will be transformed.
    - if `mask[i] <= 0`, `input[i]` will be passed unchanged.

    Args:
      mask: 1-dim tensor, tuple or list selecting the transformed features
      transform_net_create_fn: Callable ``(in_features, out_features)`` that
        builds the conditioner network; it is called with the number of
        identity features and ``num_transform_features *
        self._transform_dim_multiplier()``
      unconditional_transform: Optional callable invoked as
        ``unconditional_transform(features=num_identity_features)`` to build
        a transform for the identity features, or None

    Raises:
      ValueError: if ``mask`` is not 1-dimensional or is empty
    """
    mask = torch.as_tensor(mask)
    if mask.dim() != 1:
        raise ValueError("Mask must be a 1-dim tensor.")
    if mask.numel() <= 0:
        raise ValueError("Mask can't be empty.")

    super().__init__()
    self.features = len(mask)
    positions = torch.arange(self.features)
    for buffer_name, selector in (
        ("identity_features", mask <= 0),
        ("transform_features", mask > 0),
    ):
        self.register_buffer(buffer_name, positions.masked_select(selector))

    # Sanity check: the two index groups must partition the features.
    assert self.num_identity_features + self.num_transform_features == self.features

    in_dim = self.num_identity_features
    out_dim = self.num_transform_features * self._transform_dim_multiplier()
    self.transform_net = transform_net_create_fn(in_dim, out_dim)

    self.unconditional_transform = (
        None
        if unconditional_transform is None
        else unconditional_transform(features=self.num_identity_features)
    )

coupling_test

Tests for the coupling Transforms. Code partially taken from https://github.com/bayesiains/nsf

wrapper

AutoregressiveRationalQuadraticSpline

Bases: Flow

Neural spline flow autoregressive layer, wrapper for the implementation of Durkan et al., see sources

Source code in normflows/flows/neural_spline/wrapper.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
class AutoregressiveRationalQuadraticSpline(Flow):
    """
    Neural spline flow autoregressive layer, wrapper for the implementation
    of Durkan et al., see [sources](https://github.com/bayesiains/nsf)
    """

    def __init__(
        self,
        num_input_channels,
        num_blocks,
        num_hidden_channels,
        num_context_channels=None,
        num_bins=8,
        tail_bound=3,
        activation=nn.ReLU,
        dropout_probability=0.0,
        permute_mask=False,
        init_identity=True,
    ):
        """Constructor

        Wraps a masked piecewise rational-quadratic autoregressive transform
        with linear tails, residual parameter blocks, no random mask and no
        batch norm.

        Args:
          num_input_channels (int): Flow dimension
          num_blocks (int): Number of residual blocks of the parameter NN
          num_hidden_channels (int): Number of hidden units of the NN
          num_context_channels (int): Number of context/conditional channels
          num_bins (int): Number of bins
          tail_bound (int or float): Bound of the spline tails
          activation (torch.nn.Module class): Activation, instantiated once here
          dropout_probability (float): Dropout probability of the NN
          permute_mask (bool): Flag, permutes the mask of the NN
          init_identity (bool): Flag, initialize transform as identity
        """
        super().__init__()

        spline_config = dict(
            features=num_input_channels,
            hidden_features=num_hidden_channels,
            context_features=num_context_channels,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            num_blocks=num_blocks,
            use_residual_blocks=True,
            random_mask=False,
            permute_mask=permute_mask,
            activation=activation(),
            dropout_probability=dropout_probability,
            use_batch_norm=False,
            init_identity=init_identity,
        )
        self.mprqat = MaskedPiecewiseRationalQuadraticAutoregressive(**spline_config)

    def forward(self, z, context=None):
        """Apply the inverse of the wrapped transform.

        Returns the transformed batch and the log-determinant flattened
        with ``view(-1)``.
        """
        z_out, log_det = self.mprqat.inverse(z, context=context)
        return z_out, log_det.view(-1)

    def inverse(self, z, context=None):
        """Apply the wrapped transform in its own forward direction.

        Returns the transformed batch and the log-determinant flattened
        with ``view(-1)``.
        """
        z_out, log_det = self.mprqat(z, context=context)
        return z_out, log_det.view(-1)
__init__(num_input_channels, num_blocks, num_hidden_channels, num_context_channels=None, num_bins=8, tail_bound=3, activation=nn.ReLU, dropout_probability=0.0, permute_mask=False, init_identity=True)

Constructor

Parameters:

Name Type Description Default
num_input_channels int

Flow dimension

required
num_blocks int

Number of residual blocks of the parameter NN

required
num_hidden_channels int

Number of hidden units of the NN

required
num_context_channels int

Number of context/conditional channels

None
num_bins int

Number of bins

8
tail_bound int

Bound of the spline tails

3
activation Module

Activation function

ReLU
dropout_probability float

Dropout probability of the NN

0.0
permute_mask bool

Flag, permutes the mask of the NN

False
init_identity bool

Flag, initialize transform as identity

True
Source code in normflows/flows/neural_spline/wrapper.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def __init__(
    self,
    num_input_channels,
    num_blocks,
    num_hidden_channels,
    num_context_channels=None,
    num_bins=8,
    tail_bound=3,
    activation=nn.ReLU,
    dropout_probability=0.0,
    permute_mask=False,
    init_identity=True,
):
    """Constructor

    Wraps a masked piecewise rational-quadratic autoregressive transform
    with linear tails, residual parameter blocks, no random mask and no
    batch norm.

    Args:
      num_input_channels (int): Flow dimension
      num_blocks (int): Number of residual blocks of the parameter NN
      num_hidden_channels (int): Number of hidden units of the NN
      num_context_channels (int): Number of context/conditional channels
      num_bins (int): Number of bins
      tail_bound (int or float): Bound of the spline tails
      activation (torch.nn.Module class): Activation, instantiated once here
      dropout_probability (float): Dropout probability of the NN
      permute_mask (bool): Flag, permutes the mask of the NN
      init_identity (bool): Flag, initialize transform as identity
    """
    super().__init__()

    spline_config = dict(
        features=num_input_channels,
        hidden_features=num_hidden_channels,
        context_features=num_context_channels,
        num_bins=num_bins,
        tails="linear",
        tail_bound=tail_bound,
        num_blocks=num_blocks,
        use_residual_blocks=True,
        random_mask=False,
        permute_mask=permute_mask,
        activation=activation(),
        dropout_probability=dropout_probability,
        use_batch_norm=False,
        init_identity=init_identity,
    )
    self.mprqat = MaskedPiecewiseRationalQuadraticAutoregressive(**spline_config)
CircularAutoregressiveRationalQuadraticSpline

Bases: Flow

Neural spline flow autoregressive layer with circular coordinates, wrapper for the implementation of Durkan et al., see sources

Source code in normflows/flows/neural_spline/wrapper.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
class CircularAutoregressiveRationalQuadraticSpline(Flow):
    """
    Neural spline flow autoregressive layer with circular coordinates, wrapper
    for the implementation of Durkan et al., see [sources](https://github.com/bayesiains/nsf)
    """

    def __init__(
        self,
        num_input_channels,
        num_blocks,
        num_hidden_channels,
        ind_circ,
        num_context_channels=None,
        num_bins=8,
        tail_bound=3,
        activation=nn.ReLU,
        dropout_probability=0.0,
        permute_mask=True,
        init_identity=True,
    ):
        """Constructor

        Args:
          num_input_channels (int): Flow dimension
          num_blocks (int): Number of residual blocks of the parameter NN
          num_hidden_channels (int): Number of hidden units of the NN
          ind_circ (Iterable): Indices of the circular coordinates
          num_context_channels (int): Number of context/conditional channels
          num_bins (int): Number of bins
          tail_bound (int or float): Bound of the spline tails
          activation (torch module): Activation class, instantiated once here
          dropout_probability (float): Dropout probability of the NN
          permute_mask (bool): Flag, permutes the mask of the NN
          init_identity (bool): Flag, initialize transform as identity
        """
        super().__init__()

        # One tail type per input dimension: circular where requested,
        # linear everywhere else.
        tails = []
        for dim in range(num_input_channels):
            tails.append("circular" if dim in ind_circ else "linear")

        spline_config = dict(
            features=num_input_channels,
            hidden_features=num_hidden_channels,
            context_features=num_context_channels,
            num_bins=num_bins,
            tails=tails,
            tail_bound=tail_bound,
            num_blocks=num_blocks,
            use_residual_blocks=True,
            random_mask=False,
            permute_mask=permute_mask,
            activation=activation(),
            dropout_probability=dropout_probability,
            use_batch_norm=False,
            init_identity=init_identity,
        )
        self.mprqat = MaskedPiecewiseRationalQuadraticAutoregressive(**spline_config)

    def forward(self, z, context=None):
        """Apply the inverse of the wrapped transform; log-det flattened."""
        z_out, log_det = self.mprqat.inverse(z, context=context)
        return z_out, log_det.view(-1)

    def inverse(self, z, context=None):
        """Apply the wrapped transform directly; log-det flattened."""
        z_out, log_det = self.mprqat(z, context=context)
        return z_out, log_det.view(-1)
__init__(num_input_channels, num_blocks, num_hidden_channels, ind_circ, num_context_channels=None, num_bins=8, tail_bound=3, activation=nn.ReLU, dropout_probability=0.0, permute_mask=True, init_identity=True)

Constructor

Parameters:

Name Type Description Default
num_input_channels int

Flow dimension

required
num_blocks int

Number of residual blocks of the parameter NN

required
num_hidden_channels int

Number of hidden units of the NN

required
ind_circ Iterable

Indices of the circular coordinates

required
num_context_channels int

Number of context/conditional channels

None
num_bins int

Number of bins

8
tail_bound int

Bound of the spline tails

3
activation torch module

Activation function

ReLU
dropout_probability float

Dropout probability of the NN

0.0
permute_mask bool

Flag, permutes the mask of the NN

True
init_identity bool

Flag, initialize transform as identity

True
Source code in normflows/flows/neural_spline/wrapper.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
def __init__(
    self,
    num_input_channels,
    num_blocks,
    num_hidden_channels,
    ind_circ,
    num_context_channels=None,
    num_bins=8,
    tail_bound=3,
    activation=nn.ReLU,
    dropout_probability=0.0,
    permute_mask=True,
    init_identity=True,
):
    """Constructor

    Args:
      num_input_channels (int): Flow dimension
      num_blocks (int): Number of residual blocks of the parameter NN
      num_hidden_channels (int): Number of hidden units of the NN
      ind_circ (Iterable): Indices of the circular coordinates
      num_context_channels (int): Number of context/conditional channels
      num_bins (int): Number of bins
      tail_bound (int or float): Bound of the spline tails
      activation (torch module): Activation class, instantiated once here
      dropout_probability (float): Dropout probability of the NN
      permute_mask (bool): Flag, permutes the mask of the NN
      init_identity (bool): Flag, initialize transform as identity
    """
    super().__init__()

    # One tail type per input dimension: circular where requested,
    # linear everywhere else.
    tails = []
    for dim in range(num_input_channels):
        tails.append("circular" if dim in ind_circ else "linear")

    spline_config = dict(
        features=num_input_channels,
        hidden_features=num_hidden_channels,
        context_features=num_context_channels,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
        num_blocks=num_blocks,
        use_residual_blocks=True,
        random_mask=False,
        permute_mask=permute_mask,
        activation=activation(),
        dropout_probability=dropout_probability,
        use_batch_norm=False,
        init_identity=init_identity,
    )
    self.mprqat = MaskedPiecewiseRationalQuadraticAutoregressive(**spline_config)
CircularCoupledRationalQuadraticSpline

Bases: Flow

Neural spline flow coupling layer with circular coordinates

Source code in normflows/flows/neural_spline/wrapper.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class CircularCoupledRationalQuadraticSpline(Flow):
    """
    Neural spline flow coupling layer with circular coordinates
    """

    def __init__(
        self,
        num_input_channels,
        num_blocks,
        num_hidden_channels,
        ind_circ,
        num_context_channels=None,
        num_bins=8,
        tail_bound=3.0,
        activation=nn.ReLU,
        dropout_probability=0.0,
        reverse_mask=False,
        mask=None,
        init_identity=True,
    ):
        """Constructor

        Args:
          num_input_channels (int): Flow dimension
          num_blocks (int): Number of residual blocks of the parameter NN
          num_hidden_channels (int): Number of hidden units of the NN
          num_context_channels (int): Number of context/conditional channels
          ind_circ (Iterable): Indices (in the full input) of the circular coordinates
          num_bins (int): Number of bins
          tail_bound (float or torch.Tensor): Bound of the spline tails; a tensor
            gives one bound per input feature (broadcast to num_input_channels)
          activation (torch module): Activation function
          dropout_probability (float): Dropout probability of the NN
          reverse_mask (bool): Flag whether the reverse mask should be used
          mask (torch tensor, list or tuple): Mask to be used; an alternating
            mask is generated if None
          init_identity (bool): Flag, initialize transform as identity
        """
        super().__init__()

        if mask is None:
            mask = create_alternating_binary_mask(num_input_channels, even=reverse_mask)
        # Accept list/tuple masks too (they would fail on ``mask <= 0``).
        mask = torch.as_tensor(mask)
        features_vector = torch.arange(num_input_channels)
        # Features with mask <= 0 are passed unchanged and feed the conditioner.
        identity_features = features_vector.masked_select(mask <= 0)
        # as_tensor avoids the copy-construct warning when a tensor is passed.
        ind_circ = torch.as_tensor(ind_circ)
        # Positions *within* identity_features that are circular, i.e. indices
        # into the conditioner's input rather than into the full feature vector.
        ind_circ_id = [
            pos
            for pos, feature in enumerate(identity_features.tolist())
            if feature in ind_circ
        ]

        if torch.is_tensor(tail_bound):
            # Like ``tails``, tail_bound is given per original input feature, so
            # map the identity-local positions back to feature indices before
            # looking up the bounds (indexing with ind_circ_id directly picks
            # the bounds of the wrong features).
            tail_bound_full = torch.broadcast_to(tail_bound, (num_input_channels,))
            circ_features = identity_features[
                torch.as_tensor(ind_circ_id, dtype=torch.long)
            ]
            scale_pf = np.pi / tail_bound_full[circ_features]
        else:
            scale_pf = np.pi / tail_bound

        def transform_net_create_fn(in_features, out_features):
            # Periodic feature embedding for circular conditioner inputs only.
            if len(ind_circ_id) > 0:
                pf = PeriodicFeaturesElementwise(in_features, ind_circ_id, scale_pf)
            else:
                pf = None
            net = ResidualNet(
                in_features=in_features,
                out_features=out_features,
                context_features=num_context_channels,
                hidden_features=num_hidden_channels,
                num_blocks=num_blocks,
                activation=activation(),
                dropout_probability=dropout_probability,
                use_batch_norm=False,
                preprocessing=pf,
            )
            if init_identity:
                # Zero weights make the output input-independent; the bias is
                # the inverse softplus of (1 - DEFAULT_MIN_DERIVATIVE).
                torch.nn.init.constant_(net.final_layer.weight, 0.0)
                torch.nn.init.constant_(
                    net.final_layer.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1)
                )
            return net

        tails = [
            "circular" if i in ind_circ else "linear" for i in range(num_input_channels)
        ]

        self.prqct = PiecewiseRationalQuadraticCoupling(
            mask=mask,
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails=tails,
            tail_bound=tail_bound,
            apply_unconditional_transform=True,
        )

    def forward(self, z, context=None):
        z, log_det = self.prqct.inverse(z, context)
        return z, log_det.view(-1)

    def inverse(self, z, context=None):
        z, log_det = self.prqct(z, context)
        return z, log_det.view(-1)
__init__(num_input_channels, num_blocks, num_hidden_channels, ind_circ, num_context_channels=None, num_bins=8, tail_bound=3.0, activation=nn.ReLU, dropout_probability=0.0, reverse_mask=False, mask=None, init_identity=True)

Constructor

Parameters:

Name Type Description Default
num_input_channels int

Flow dimension

required
num_blocks int

Number of residual blocks of the parameter NN

required
num_hidden_channels int

Number of hidden units of the NN

required
num_context_channels int

Number of context/conditional channels

None
ind_circ Iterable

Indices of the circular coordinates

required
num_bins int

Number of bins

8
tail_bound float or torch.Tensor

Bound of the spline tails

3.0
activation torch module

Activation function

ReLU
dropout_probability float

Dropout probability of the NN

0.0
reverse_mask bool

Flag whether the reverse mask should be used

False
mask torch tensor

Mask to be used; an alternating mask is generated if None

None
init_identity bool

Flag, initialize transform as identity

True
Source code in normflows/flows/neural_spline/wrapper.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def __init__(
    self,
    num_input_channels,
    num_blocks,
    num_hidden_channels,
    ind_circ,
    num_context_channels=None,
    num_bins=8,
    tail_bound=3.0,
    activation=nn.ReLU,
    dropout_probability=0.0,
    reverse_mask=False,
    mask=None,
    init_identity=True,
):
    """Constructor

    Args:
      num_input_channels (int): Flow dimension
      num_blocks (int): Number of residual blocks of the parameter NN
      num_hidden_channels (int): Number of hidden units of the NN
      num_context_channels (int): Number of context/conditional channels
      ind_circ (Iterable): Indices (in the full input) of the circular coordinates
      num_bins (int): Number of bins
      tail_bound (float or torch.Tensor): Bound of the spline tails; a tensor
        gives one bound per input feature (broadcast to num_input_channels)
      activation (torch module): Activation function
      dropout_probability (float): Dropout probability of the NN
      reverse_mask (bool): Flag whether the reverse mask should be used
      mask (torch tensor, list or tuple): Mask to be used; an alternating
        mask is generated if None
      init_identity (bool): Flag, initialize transform as identity
    """
    super().__init__()

    if mask is None:
        mask = create_alternating_binary_mask(num_input_channels, even=reverse_mask)
    # Accept list/tuple masks too (they would fail on ``mask <= 0``).
    mask = torch.as_tensor(mask)
    features_vector = torch.arange(num_input_channels)
    # Features with mask <= 0 are passed unchanged and feed the conditioner.
    identity_features = features_vector.masked_select(mask <= 0)
    # as_tensor avoids the copy-construct warning when a tensor is passed.
    ind_circ = torch.as_tensor(ind_circ)
    # Positions *within* identity_features that are circular, i.e. indices
    # into the conditioner's input rather than into the full feature vector.
    ind_circ_id = [
        pos
        for pos, feature in enumerate(identity_features.tolist())
        if feature in ind_circ
    ]

    if torch.is_tensor(tail_bound):
        # Like ``tails``, tail_bound is given per original input feature, so
        # map the identity-local positions back to feature indices before
        # looking up the bounds (indexing with ind_circ_id directly picks
        # the bounds of the wrong features).
        tail_bound_full = torch.broadcast_to(tail_bound, (num_input_channels,))
        circ_features = identity_features[
            torch.as_tensor(ind_circ_id, dtype=torch.long)
        ]
        scale_pf = np.pi / tail_bound_full[circ_features]
    else:
        scale_pf = np.pi / tail_bound

    def transform_net_create_fn(in_features, out_features):
        # Periodic feature embedding for circular conditioner inputs only.
        if len(ind_circ_id) > 0:
            pf = PeriodicFeaturesElementwise(in_features, ind_circ_id, scale_pf)
        else:
            pf = None
        net = ResidualNet(
            in_features=in_features,
            out_features=out_features,
            context_features=num_context_channels,
            hidden_features=num_hidden_channels,
            num_blocks=num_blocks,
            activation=activation(),
            dropout_probability=dropout_probability,
            use_batch_norm=False,
            preprocessing=pf,
        )
        if init_identity:
            # Zero weights make the output input-independent; the bias is
            # the inverse softplus of (1 - DEFAULT_MIN_DERIVATIVE).
            torch.nn.init.constant_(net.final_layer.weight, 0.0)
            torch.nn.init.constant_(
                net.final_layer.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1)
            )
        return net

    tails = [
        "circular" if i in ind_circ else "linear" for i in range(num_input_channels)
    ]

    self.prqct = PiecewiseRationalQuadraticCoupling(
        mask=mask,
        transform_net_create_fn=transform_net_create_fn,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
        apply_unconditional_transform=True,
    )
CoupledRationalQuadraticSpline

Bases: Flow

Neural spline flow coupling layer, wrapper for the implementation of Durkan et al., see source

Source code in normflows/flows/neural_spline/wrapper.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class CoupledRationalQuadraticSpline(Flow):
    """
    Neural spline flow coupling layer, wrapper for the implementation
    of Durkan et al., see [source](https://github.com/bayesiains/nsf)
    """

    def __init__(
        self,
        num_input_channels,
        num_blocks,
        num_hidden_channels,
        num_context_channels=None,
        num_bins=8,
        tails="linear",
        tail_bound=3.0,
        activation=nn.ReLU,
        dropout_probability=0.0,
        reverse_mask=False,
        init_identity=True,
    ):
        """Constructor

        Args:
          num_input_channels (int): Flow dimension
          num_blocks (int): Number of residual blocks of the parameter NN
          num_hidden_channels (int): Number of hidden units of the NN
          num_context_channels (int): Number of context/conditional channels
          num_bins (int): Number of bins
          tails (str): Behaviour of the tails of the distribution, can be linear, circular for periodic distribution, or None for distribution on the compact interval
          tail_bound (float): Bound of the spline tails
          activation (torch module): Activation class, instantiated per network
          dropout_probability (float): Dropout probability of the NN
          reverse_mask (bool): Flag whether the reverse mask should be used
          init_identity (bool): Flag, initialize transform as identity
        """
        super().__init__()

        def transform_net_create_fn(in_features, out_features):
            # Each conditioner network gets its own activation instance.
            net = ResidualNet(
                in_features=in_features,
                out_features=out_features,
                context_features=num_context_channels,
                hidden_features=num_hidden_channels,
                num_blocks=num_blocks,
                activation=activation(),
                dropout_probability=dropout_probability,
                use_batch_norm=False,
            )
            if not init_identity:
                return net
            # Zero weights + constant bias make the spline parameters input
            # independent; the bias is the inverse softplus of
            # (1 - DEFAULT_MIN_DERIVATIVE).
            final = net.final_layer
            torch.nn.init.constant_(final.weight, 0.0)
            torch.nn.init.constant_(
                final.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1)
            )
            return net

        alternating_mask = create_alternating_binary_mask(
            num_input_channels, even=reverse_mask
        )
        self.prqct = PiecewiseRationalQuadraticCoupling(
            mask=alternating_mask,
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails=tails,
            tail_bound=tail_bound,
            # Setting True corresponds to equations (4), (5), (6) in the NSF paper:
            apply_unconditional_transform=True,
        )

    def forward(self, z, context=None):
        """Apply the inverse of the coupling; log-det flattened."""
        z_out, log_det = self.prqct.inverse(z, context)
        return z_out, log_det.view(-1)

    def inverse(self, z, context=None):
        """Apply the coupling directly; log-det flattened."""
        z_out, log_det = self.prqct(z, context)
        return z_out, log_det.view(-1)
__init__(num_input_channels, num_blocks, num_hidden_channels, num_context_channels=None, num_bins=8, tails='linear', tail_bound=3.0, activation=nn.ReLU, dropout_probability=0.0, reverse_mask=False, init_identity=True)

Constructor

Parameters:

Name Type Description Default
num_input_channels int

Flow dimension

required
num_blocks int

Number of residual blocks of the parameter NN

required
num_hidden_channels int

Number of hidden units of the NN

required
num_context_channels int

Number of context/conditional channels

None
num_bins int

Number of bins

8
tails str

Behaviour of the tails of the distribution, can be linear, circular for periodic distribution, or None for distribution on the compact interval

'linear'
tail_bound float

Bound of the spline tails

3.0
activation torch module

Activation function

ReLU
dropout_probability float

Dropout probability of the NN

0.0
reverse_mask bool

Flag whether the reverse mask should be used

False
init_identity bool

Flag, initialize transform as identity

True
Source code in normflows/flows/neural_spline/wrapper.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    num_input_channels,
    num_blocks,
    num_hidden_channels,
    num_context_channels=None,
    num_bins=8,
    tails="linear",
    tail_bound=3.0,
    activation=nn.ReLU,
    dropout_probability=0.0,
    reverse_mask=False,
    init_identity=True,
):
    """Constructor

    Args:
      num_input_channels (int): Flow dimension
      num_blocks (int): Number of residual blocks of the parameter NN
      num_hidden_channels (int): Number of hidden units of the NN
      num_context_channels (int): Number of context/conditional channels
      num_bins (int): Number of bins
      tails (str): Behaviour of the tails of the distribution, can be linear, circular for periodic distribution, or None for distribution on the compact interval
      tail_bound (float): Bound of the spline tails
      activation (torch module): Activation class, instantiated per network
      dropout_probability (float): Dropout probability of the NN
      reverse_mask (bool): Flag whether the reverse mask should be used
      init_identity (bool): Flag, initialize transform as identity
    """
    super().__init__()

    def transform_net_create_fn(in_features, out_features):
        # Each conditioner network gets its own activation instance.
        net = ResidualNet(
            in_features=in_features,
            out_features=out_features,
            context_features=num_context_channels,
            hidden_features=num_hidden_channels,
            num_blocks=num_blocks,
            activation=activation(),
            dropout_probability=dropout_probability,
            use_batch_norm=False,
        )
        if not init_identity:
            return net
        # Zero weights + constant bias make the spline parameters input
        # independent; the bias is the inverse softplus of
        # (1 - DEFAULT_MIN_DERIVATIVE).
        final = net.final_layer
        torch.nn.init.constant_(final.weight, 0.0)
        torch.nn.init.constant_(
            final.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1)
        )
        return net

    alternating_mask = create_alternating_binary_mask(
        num_input_channels, even=reverse_mask
    )
    self.prqct = PiecewiseRationalQuadraticCoupling(
        mask=alternating_mask,
        transform_net_create_fn=transform_net_create_fn,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
        # Setting True corresponds to equations (4), (5), (6) in the NSF paper:
        apply_unconditional_transform=True,
    )

normalization

ActNorm

Bases: AffineConstFlow

An AffineConstFlow with a data-dependent initialization: on the very first batch, s and t are initialized so that the output is unit Gaussian, as described in the Glow paper.

Source code in normflows/flows/normalization.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class ActNorm(AffineConstFlow):
    """
    An AffineConstFlow with a data-dependent initialization: on the very
    first batch, s and t are initialized so that the output is unit Gaussian
    (zero mean, unit std over ``self.batch_dims``), as described in the Glow paper.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data_dep_init_done_cpu = torch.tensor(0.0)
        # Register a copy so that flagging the buffer in place (see
        # _mark_data_dep_init_done) does not also mutate the plain
        # ``data_dep_init_done_cpu`` attribute.
        self.register_buffer("data_dep_init_done", self.data_dep_init_done_cpu.clone())

    def _mark_data_dep_init_done(self):
        """Set the initialization flag in place.

        Filling the registered buffer, instead of rebinding it to a fresh
        ``torch.tensor(1.0)`` (always CPU/float32), keeps the device and dtype
        the module was moved to with ``.to()``/``.cuda()``/``.double()``.
        """
        self.data_dep_init_done.fill_(1.0)

    def forward(self, z):
        """Forward pass; the first batch seen initializes s and t.

        Args:
          z: Input batch

        Returns:
          Result of ``AffineConstFlow.forward`` after (possible) initialization
        """
        # First batch is used for initialization, c.f. batchnorm
        if not self.data_dep_init_done > 0.0:
            assert self.s is not None and self.t is not None
            # Statistics are only used as initial values; no autograd graph needed.
            with torch.no_grad():
                # Assuming AffineConstFlow.forward computes z * exp(s) + t,
                # exp(s) = 1/std and t = -mean/std standardize the batch.
                s_init = -torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
                self.s.data = s_init.data
                self.t.data = (
                    -z.mean(dim=self.batch_dims, keepdim=True) * torch.exp(self.s)
                ).data
            self._mark_data_dep_init_done()
        return super().forward(z)

    def inverse(self, z):
        """Inverse pass; the first batch seen initializes s and t.

        Args:
          z: Input batch

        Returns:
          Result of ``AffineConstFlow.inverse`` after (possible) initialization
        """
        # First batch is used for initialization, c.f. batchnorm
        if not self.data_dep_init_done > 0.0:
            assert self.s is not None and self.t is not None
            with torch.no_grad():
                # Assuming AffineConstFlow.inverse computes (z - t) * exp(-s),
                # s = log(std) and t = mean standardize the batch.
                s_init = torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
                self.s.data = s_init.data
                self.t.data = z.mean(dim=self.batch_dims, keepdim=True).data
            self._mark_data_dep_init_done()
        return super().inverse(z)

BatchNorm

Bases: Flow

Batch Normalization without considering the derivatives of the batch statistics, see arXiv: 1605.08803

Source code in normflows/flows/normalization.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class BatchNorm(Flow):
    """
    Batch Normalization without considering the derivatives of the batch statistics, see [arXiv: 1605.08803](https://arxiv.org/abs/1605.08803)
    """

    def __init__(self, eps=1.0e-10):
        """Constructor

        Args:
          eps: Constant added to the variance for numerical stability
        """
        super().__init__()
        self.eps_cpu = torch.tensor(eps)
        self.register_buffer("eps", self.eps_cpu)

    def forward(self, z):
        """
        Do batch norm over the batch (sample) dimension, i.e. dim 0

        The batch statistics are treated as constants in the log-determinant.

        Args:
          z: Batch of samples, batch dimension first

        Returns:
          Normalized batch and the log-determinant, one (identical) value per sample
        """
        mean = torch.mean(z, dim=0, keepdim=True)
        std = torch.std(z, dim=0, keepdim=True)
        var_eps = std**2 + self.eps
        z_ = (z - mean) / torch.sqrt(var_eps)
        # log|det J| = -sum_i log(sqrt(var_i + eps)). Summing logs instead of
        # taking the log of a product avoids under/overflow when z has many
        # features (the product of per-feature scales easily leaves float range).
        log_det = (-0.5 * torch.sum(torch.log(var_eps))).repeat(z.size(0))
        return z_, log_det
forward(z)

Do batch norm over the batch (sample) dimension, i.e. dim 0

Source code in normflows/flows/normalization.py
52
53
54
55
56
57
58
59
60
61
62
def forward(self, z):
    """
    Do batch norm over the batch (sample) dimension, i.e. dim 0

    The batch statistics are treated as constants in the log-determinant.

    Args:
      z: Batch of samples, batch dimension first

    Returns:
      Normalized batch and the log-determinant, one (identical) value per sample
    """
    mean = torch.mean(z, dim=0, keepdim=True)
    std = torch.std(z, dim=0, keepdim=True)
    var_eps = std**2 + self.eps
    z_ = (z - mean) / torch.sqrt(var_eps)
    # log|det J| = -sum_i log(sqrt(var_i + eps)). Summing logs instead of
    # taking the log of a product avoids under/overflow when z has many
    # features (the product of per-feature scales easily leaves float range).
    log_det = (-0.5 * torch.sum(torch.log(var_eps))).repeat(z.size(0))
    return z_, log_det

periodic

PeriodicShift

Bases: Flow

Shift and wrap periodic coordinates

Source code in normflows/flows/periodic.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class PeriodicShift(Flow):
    """
    Shift and wrap periodic coordinates
    """

    def __init__(self, ind, bound=1.0, shift=0.0):
        """Constructor

        Args:
          ind: Iterable, indices of coordinates to be mapped
          bound: Float or iterable, bound of interval
          shift: Float, iterable or tensor, shift to be applied
        """
        super().__init__()
        self.ind = ind
        # Python lists/tuples cannot be combined with tensors (and ``2 * list``
        # repeats the list), so turn them into tensors registered as buffers
        if isinstance(bound, (list, tuple)):
            bound = torch.as_tensor(bound, dtype=torch.get_default_dtype())
        if isinstance(shift, (list, tuple)):
            shift = torch.as_tensor(shift, dtype=torch.get_default_dtype())
        if torch.is_tensor(bound):
            self.register_buffer("bound", bound)
        else:
            self.bound = bound
        if torch.is_tensor(shift):
            self.register_buffer("shift", shift)
        else:
            self.shift = shift

    def _shift_and_wrap(self, z, shift):
        """Add ``shift`` to the coordinates ``self.ind`` of ``z`` and wrap them into [-bound, bound)

        Returns:
          Transformed copy of ``z`` and a zero log determinant per sample
        """
        z_ = z.clone()
        # torch.remainder takes the sign of the divisor, so the result lies in [0, 2 * bound)
        z_[..., self.ind] = (
            torch.remainder(z_[..., self.ind] + shift + self.bound, 2 * self.bound)
            - self.bound
        )
        return z_, torch.zeros(len(z), dtype=z.dtype, device=z.device)

    def forward(self, z):
        return self._shift_and_wrap(z, self.shift)

    def inverse(self, z):
        return self._shift_and_wrap(z, -self.shift)
__init__(ind, bound=1.0, shift=0.0)

Constructor

Parameters:

Name Type Description Default
ind

Iterable, indices of coordinates to be mapped

required
bound

Float or iterable, bound of interval

1.0
shift

Tensor, shift to be applied

0.0
Source code in normflows/flows/periodic.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __init__(self, ind, bound=1.0, shift=0.0):
    """Constructor

    Args:
      ind: Iterable, indices of coordinates to be mapped
      bound: Float or iterable, bound of interval
      shift: Float, iterable or tensor, shift to be applied
    """
    super().__init__()
    self.ind = ind
    # Python lists/tuples cannot be combined with tensors (and ``2 * list``
    # repeats the list), so turn them into tensors registered as buffers
    if isinstance(bound, (list, tuple)):
        bound = torch.as_tensor(bound, dtype=torch.get_default_dtype())
    if isinstance(shift, (list, tuple)):
        shift = torch.as_tensor(shift, dtype=torch.get_default_dtype())
    if torch.is_tensor(bound):
        self.register_buffer("bound", bound)
    else:
        self.bound = bound
    if torch.is_tensor(shift):
        self.register_buffer("shift", shift)
    else:
        self.shift = shift

PeriodicWrap

Bases: Flow

Map periodic coordinates to fixed interval

Source code in normflows/flows/periodic.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class PeriodicWrap(Flow):
    """
    Map periodic coordinates to fixed interval
    """

    def __init__(self, ind, bound=1.0):
        """Constructor

        Args:
          ind: Iterable, indices of coordinates to be mapped
          bound: Float or iterable, bound of interval
        """
        super().__init__()
        self.ind = ind
        # Python lists/tuples cannot be combined with tensors (and ``2 * list``
        # repeats the list), so turn them into a tensor registered as buffer
        if isinstance(bound, (list, tuple)):
            bound = torch.as_tensor(bound, dtype=torch.get_default_dtype())
        if torch.is_tensor(bound):
            self.register_buffer("bound", bound)
        else:
            self.bound = bound

    def forward(self, z):
        # Identity; the coordinates are only wrapped in the inverse pass
        return z, torch.zeros(len(z), dtype=z.dtype, device=z.device)

    def inverse(self, z):
        z_ = z.clone()
        # torch.remainder takes the sign of the divisor, so the result lies in [-bound, bound)
        z_[..., self.ind] = (
            torch.remainder(z_[..., self.ind] + self.bound, 2 * self.bound) - self.bound
        )
        return z_, torch.zeros(len(z), dtype=z.dtype, device=z.device)
__init__(ind, bound=1.0)

Constructor

ind: Iterable, indices of coordinates to be mapped
bound: Float or iterable, bound of interval

Source code in normflows/flows/periodic.py
11
12
13
14
15
16
17
18
19
20
21
22
def __init__(self, ind, bound=1.0):
    """Constructor

    Args:
      ind: Iterable, indices of coordinates to be mapped
      bound: Float or iterable, bound of interval
    """
    super().__init__()
    self.ind = ind
    # Python lists/tuples cannot be combined with tensors (and ``2 * list``
    # repeats the list), so turn them into a tensor registered as buffer
    if isinstance(bound, (list, tuple)):
        bound = torch.as_tensor(bound, dtype=torch.get_default_dtype())
    if torch.is_tensor(bound):
        self.register_buffer("bound", bound)
    else:
        self.bound = bound

planar

Planar

Bases: Flow

Planar flow as introduced in arXiv: 1505.05770

    f(z) = z + u * h(w * z + b)
Source code in normflows/flows/planar.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Planar(Flow):
    """Planar flow as introduced in [arXiv: 1505.05770](https://arxiv.org/abs/1505.05770)

    ```
        f(z) = z + u * h(w * z + b)
    ```
    """

    def __init__(self, shape, act="tanh", u=None, w=None, b=None):
        """Constructor of the planar flow

        Args:
          shape: shape of the latent variable z
          act: nonlinearity h of the planar flow (see definition of f above), either "tanh" or "leaky_relu"
          u,w,b: optional initialization for parameters; if None, u and w are initialized uniformly at random and b with zero

        Raises:
          NotImplementedError: if ``act`` is not supported
        """
        super().__init__()
        lim_w = np.sqrt(2.0 / np.prod(shape))
        lim_u = np.sqrt(2)

        if u is not None:
            self.u = nn.Parameter(u)
        else:
            self.u = nn.Parameter(torch.empty(shape)[None])
            nn.init.uniform_(self.u, -lim_u, lim_u)
        if w is not None:
            self.w = nn.Parameter(w)
        else:
            self.w = nn.Parameter(torch.empty(shape)[None])
            nn.init.uniform_(self.w, -lim_w, lim_w)
        if b is not None:
            self.b = nn.Parameter(b)
        else:
            self.b = nn.Parameter(torch.zeros(1))

        self.act = act
        if act == "tanh":
            self.h = torch.tanh
        elif act == "leaky_relu":
            self.h = torch.nn.LeakyReLU(negative_slope=0.2)
        else:
            raise NotImplementedError("Nonlinearity is not implemented.")

    def _constrained_u(self):
        """Return u reparameterized such that w^T u > -1, which keeps the flow invertible"""
        inner = torch.sum(self.w * self.u)
        # softplus(x) == log(1 + exp(x)), but it does not overflow for large x
        return self.u + (nn.functional.softplus(inner) - 1 - inner) * self.w / torch.sum(
            self.w**2
        )

    def forward(self, z):
        lin = torch.sum(self.w * z, list(range(1, self.w.dim())),
                        keepdim=True) + self.b
        u = self._constrained_u()
        # Derivative h' of the nonlinearity, needed for the log determinant
        if self.act == "tanh":
            h_ = lambda x: 1 / torch.cosh(x) ** 2
        elif self.act == "leaky_relu":
            h_ = lambda x: (x < 0) * (self.h.negative_slope - 1.0) + 1.0

        z_ = z + u * self.h(lin)
        log_det = torch.log(torch.abs(1 + torch.sum(self.w * u) * h_(lin.reshape(-1))))
        return z_, log_det

    def inverse(self, z):
        if self.act != "leaky_relu":
            raise NotImplementedError("This flow has no algebraic inverse.")
        lin = torch.sum(self.w * z, list(range(1, self.w.dim()))) + self.b
        a = (lin < 0) * (
            self.h.negative_slope - 1.0
        ) + 1.0  # absorb leakyReLU slope into u
        u = self._constrained_u()
        dims = [-1] + (u.dim() - 1) * [1]
        u = a.reshape(*dims) * u
        inner_ = torch.sum(self.w * u, list(range(1, self.w.dim())))
        z_ = z - u * (lin / (1 + inner_)).reshape(*dims)
        log_det = -torch.log(torch.abs(1 + inner_))
        return z_, log_det
__init__(shape, act='tanh', u=None, w=None, b=None)

Constructor of the planar flow

Parameters:

Name Type Description Default
shape

shape of the latent variable z

required
act

nonlinearity h of the planar flow (see definition of f above), either "tanh" or "leaky_relu"

'tanh'
u,w,b

optional initialization for parameters; if None, u and w are initialized uniformly at random and b with zero

None
Source code in normflows/flows/planar.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def __init__(self, shape, act="tanh", u=None, w=None, b=None):
    """Constructor of the planar flow

    Args:
      shape: shape of the latent variable z
      act: nonlinearity h of the planar flow (see definition of f above), either "tanh" or "leaky_relu"
      u,w,b: optional initialization for parameters; if None, u and w are initialized uniformly at random and b with zero

    Raises:
      NotImplementedError: if ``act`` is not supported
    """
    super().__init__()
    lim_w = np.sqrt(2.0 / np.prod(shape))
    lim_u = np.sqrt(2)

    if u is not None:
        self.u = nn.Parameter(u)
    else:
        self.u = nn.Parameter(torch.empty(shape)[None])
        nn.init.uniform_(self.u, -lim_u, lim_u)
    if w is not None:
        self.w = nn.Parameter(w)
    else:
        self.w = nn.Parameter(torch.empty(shape)[None])
        nn.init.uniform_(self.w, -lim_w, lim_w)
    if b is not None:
        self.b = nn.Parameter(b)
    else:
        self.b = nn.Parameter(torch.zeros(1))

    self.act = act
    if act == "tanh":
        self.h = torch.tanh
    elif act == "leaky_relu":
        self.h = torch.nn.LeakyReLU(negative_slope=0.2)
    else:
        raise NotImplementedError("Nonlinearity is not implemented.")

radial

Radial

Bases: Flow

Radial flow as introduced in arXiv: 1505.05770

    f(z) = z + beta * h(alpha, r) * (z - z_0)
Source code in normflows/flows/radial.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Radial(Flow):
    """Radial flow as introduced in [arXiv: 1505.05770](https://arxiv.org/abs/1505.05770)

    ```
        f(z) = z + beta * h(alpha, r) * (z - z_0)
    ```
    """

    def __init__(self, shape, z_0=None):
        """Constructor of the radial flow

        Args:
          shape: shape of the latent variable z
          z_0: optional initial value of the center z_0 of the radial flow; if None, it is drawn from a standard normal distribution with shape (1, *shape)
        """
        super().__init__()
        # Number of dimensions of z, needed for the log determinant
        self.d_cpu = torch.prod(torch.tensor(shape))
        self.register_buffer("d", self.d_cpu)
        self.beta = nn.Parameter(torch.empty(1))
        lim = 1.0 / np.prod(shape)
        nn.init.uniform_(self.beta, -lim - 1.0, lim - 1.0)
        self.alpha = nn.Parameter(torch.empty(1))
        nn.init.uniform_(self.alpha, -lim, lim)

        if z_0 is not None:
            self.z_0 = nn.Parameter(z_0)
        else:
            self.z_0 = nn.Parameter(torch.randn(shape)[None])

    def forward(self, z):
        # Reparameterize beta such that beta >= -|alpha|; softplus equals
        # log(1 + exp(.)) but does not overflow for large self.beta
        beta = nn.functional.softplus(self.beta) - torch.abs(self.alpha)
        dz = z - self.z_0
        r = torch.linalg.vector_norm(dz, dim=list(range(1, self.z_0.dim())), keepdim=True)
        h_arr = beta / (torch.abs(self.alpha) + r)
        h_arr_ = -beta * r / (torch.abs(self.alpha) + r) ** 2
        z_ = z + h_arr * dz
        # log|det J| = (d - 1) * log(1 + beta * h) + log(1 + beta * h + beta * h' * r)
        log_det = (self.d - 1) * torch.log(1 + h_arr) + torch.log(1 + h_arr + h_arr_)
        log_det = log_det.reshape(-1)
        return z_, log_det
__init__(shape, z_0=None)

Constructor of the radial flow

Parameters:

Name Type Description Default
shape

shape of the latent variable z

required
z_0

parameter of the radial flow

None
Source code in normflows/flows/radial.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(self, shape, z_0=None):
    """Constructor of the radial flow

    Args:
      shape: shape of the latent variable z
      z_0: optional initial value of the center z_0 of the radial flow; if None, it is drawn from a standard normal distribution with shape (1, *shape)
    """
    super().__init__()
    # Number of dimensions of z, needed for the log determinant
    self.d_cpu = torch.prod(torch.tensor(shape))
    self.register_buffer("d", self.d_cpu)
    self.beta = nn.Parameter(torch.empty(1))
    lim = 1.0 / np.prod(shape)
    nn.init.uniform_(self.beta, -lim - 1.0, lim - 1.0)
    self.alpha = nn.Parameter(torch.empty(1))
    nn.init.uniform_(self.alpha, -lim, lim)

    if z_0 is not None:
        self.z_0 = nn.Parameter(z_0)
    else:
        self.z_0 = nn.Parameter(torch.randn(shape)[None])

reshape

Merge

Bases: Split

Same as Split but with forward and backward pass interchanged

Source code in normflows/flows/reshape.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class Merge(Split):
    """
    Inverse of Split: combines two sets of features into one, i.e. Split
    with the forward and the inverse pass swapped
    """

    def __init__(self, mode="channel"):
        """Constructor

        Args:
          mode: splitting mode, same options as for Split
        """
        super().__init__(mode)

    def forward(self, z):
        # Merging is the inverse of splitting
        merged = super().inverse(z)
        return merged

    def inverse(self, z):
        # Undoing a merge means splitting again
        split = super().forward(z)
        return split

Split

Bases: Flow

Split features into two sets

Source code in normflows/flows/reshape.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class Split(Flow):
    """
    Split features into two sets
    """

    def __init__(self, mode="channel"):
        """Constructor

        The splitting mode can be:

        - channel: Splits first feature dimension, usually channels, into two halves
        - channel_inv: Same as channel, but with z1 and z2 flipped
        - checkerboard: Splits features using a checkerboard pattern (last feature dimension must be even)
        - checkerboard_inv: Same as checkerboard, but with inverted coloring

        Args:
         mode: splitting mode
        """
        super().__init__()
        self.mode = mode

    @staticmethod
    def _checkerboard(size, inverted, device):
        """Build the integer 0/1 checkerboard mask of shape ``size``

        Entry ``[b, i_1, ..., i_k]`` equals ``(i_1 + ... + i_k) % 2``, or its
        complement if ``inverted``; the pattern is identical for every batch index ``b``.
        """
        n_feat = len(size) - 1
        parity = torch.zeros((), dtype=torch.long, device=device)
        for dim, length in enumerate(size[1:]):
            view_shape = [1] * n_feat
            view_shape[dim] = length
            # Broadcasting accumulates the index sum over all feature dimensions
            parity = parity + torch.arange(length, device=device).view(view_shape)
        cb = parity % 2
        if inverted:
            cb = 1 - cb
        return cb[None].repeat(size[0], *(n_feat * [1]))

    def forward(self, z):
        """Split ``z`` into ``[z1, z2]`` according to ``self.mode``

        Returns:
          List ``[z1, z2]`` and log determinant 0
        """
        if self.mode == "channel":
            z1, z2 = z.chunk(2, dim=1)
        elif self.mode == "channel_inv":
            z2, z1 = z.chunk(2, dim=1)
        elif "checkerboard" in self.mode:
            z_size = z.size()
            cb = self._checkerboard(z_size, "inv" in self.mode, z.device)
            z1 = z.reshape(-1)[torch.nonzero(cb.view(-1), as_tuple=False)].view(
                *z_size[:-1], -1
            )
            z2 = z.reshape(-1)[torch.nonzero((1 - cb).view(-1), as_tuple=False)].view(
                *z_size[:-1], -1
            )
        else:
            raise NotImplementedError("Mode " + self.mode + " is not implemented.")
        log_det = 0
        return [z1, z2], log_det

    def inverse(self, z):
        """Merge ``z = [z1, z2]`` back into a single tensor according to ``self.mode``

        Returns:
          Merged tensor and log determinant 0
        """
        z1, z2 = z
        if self.mode == "channel":
            z = torch.cat([z1, z2], 1)
        elif self.mode == "channel_inv":
            z = torch.cat([z2, z1], 1)
        elif "checkerboard" in self.mode:
            n_dims = z1.dim()
            z_size = list(z1.size())
            z_size[-1] *= 2
            cb = self._checkerboard(z_size, "inv" in self.mode, z1.device)
            # Duplicate every entry along the last dimension, then let the mask
            # pick the copy belonging to each set
            z1 = z1[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1)
            z2 = z2[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1)
            z = cb * z1 + (1 - cb) * z2
        else:
            raise NotImplementedError("Mode " + self.mode + " is not implemented.")
        log_det = 0
        return z, log_det
__init__(mode='channel')

Constructor

The splitting mode can be:

  • channel: Splits first feature dimension, usually channels, into two halves
  • channel_inv: Same as channel, but with z1 and z2 flipped
  • checkerboard: Splits features using a checkerboard pattern (last feature dimension must be even)
  • checkerboard_inv: Same as checkerboard, but with inverted coloring

Parameters:

Name Type Description Default
mode

splitting mode

'channel'
Source code in normflows/flows/reshape.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def __init__(self, mode="channel"):
    """Constructor

    The splitting mode can be:

    - channel: Splits first feature dimension, usually channels, into two halves
    - channel_inv: Same as channel, but with z1 and z2 flipped
    - checkerboard: Splits features using a checkerboard pattern (last feature dimension must be even)
    - checkerboard_inv: Same as checkerboard, but with inverted coloring

    Args:
     mode: splitting mode
    """
    super().__init__()
    self.mode = mode

Squeeze

Bases: Flow

Squeeze operation of the multi-scale architecture from the RealNVP and Glow papers

Source code in normflows/flows/reshape.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class Squeeze(Flow):
    """
    Squeeze operation of the multi-scale architecture from the RealNVP and Glow papers
    """

    def __init__(self):
        """
        Constructor; the operation has no parameters
        """
        super().__init__()

    def forward(self, z):
        dims = z.size()
        batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
        # Move every group of 4 channels into a 2x2 block of pixels
        blocks = z.view(batch, channels // 4, 2, 2, height, width)
        blocks = blocks.permute(0, 1, 4, 2, 5, 3).contiguous()
        return blocks.view(batch, channels // 4, 2 * height, 2 * width), 0

    def inverse(self, z):
        dims = z.size()
        batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
        # Fold every 2x2 block of pixels back into 4 channels
        blocks = z.view(batch, channels, height // 2, 2, width // 2, 2)
        blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
        return blocks.view(batch, 4 * channels, height // 2, width // 2), 0
__init__()

Constructor

Source code in normflows/flows/reshape.py
108
109
110
111
112
def __init__(self):
    """
    Constructor; the operation has no parameters
    """
    super().__init__()

residual

Residual

Bases: Flow

Invertible residual net block, wrapper to the implementation of Chen et al., see sources

Source code in normflows/flows/residual.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class Residual(Flow):
    """
    Invertible residual net block, wrapper to the implementation of Chen et al.,
    see [sources](https://github.com/rtqichen/residual-flows)
    """

    def __init__(
        self,
        net,
        reverse=True,
        reduce_memory=True,
        geom_p=0.5,
        lamb=2.0,
        n_power_series=None,
        exact_trace=False,
        brute_force=False,
        n_samples=1,
        n_exact_terms=2,
        n_dist="geometric"
    ):
        """Constructor

        Args:
          net: Neural network, must be Lipschitz continuous with L < 1
          reverse: Flag, if true the map ```f(x) = x + net(x)``` is applied in the inverse pass, otherwise it is done in forward
          reduce_memory: Flag, if true the Neumann series gradient estimator is used and gradients are computed in the forward pass to save memory
          geom_p: Success probability of the geometric distribution used to sample the number of power series terms (if n_dist is "geometric")
          lamb: Rate of the Poisson distribution used to sample the number of power series terms (if n_dist is "poisson")
          n_power_series: Number of terms in the power series; if None, the number of terms is sampled, giving an unbiased estimate
          exact_trace: Flag, if true the trace of the Jacobian is computed exactly
          brute_force: Flag, if true the Jacobian is computed exactly in 2D
          n_samples: Number of samples used to estimate power series
          n_exact_terms: Number of terms always included in the power series
          n_dist: Distribution used for the power series, either "geometric" or "poisson"
        """
        super().__init__()
        self.reverse = reverse
        self.iresblock = iResBlock(
            net,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=reduce_memory,
            grad_in_forward=reduce_memory,
            exact_trace=exact_trace,
            geom_p=geom_p,
            lamb=lamb,
            n_power_series=n_power_series,
            brute_force=brute_force,
            n_dist=n_dist,
        )

    def forward(self, z):
        # iResBlock returns logp -/+ logdet for its forward/inverse; with logp = 0
        # the negation turns that into the log determinant of this pass
        if self.reverse:
            z, log_det = self.iresblock.inverse(z, 0)
        else:
            z, log_det = self.iresblock.forward(z, 0)
        return z, -log_det.view(-1)

    def inverse(self, z):
        if self.reverse:
            z, log_det = self.iresblock.forward(z, 0)
        else:
            z, log_det = self.iresblock.inverse(z, 0)
        return z, -log_det.view(-1)
__init__(net, reverse=True, reduce_memory=True, geom_p=0.5, lamb=2.0, n_power_series=None, exact_trace=False, brute_force=False, n_samples=1, n_exact_terms=2, n_dist='geometric')

Constructor

Parameters:

Name Type Description Default
net

Neural network, must be Lipschitz continuous with L < 1

required
reverse

Flag, if true the map f(x) = x + net(x) is applied in the inverse pass, otherwise it is done in forward

True
reduce_memory

Flag, if true Neumann series and precomputations, for backward pass in forward pass are done

True
geom_p

Parameter of the geometric distribution used for the Neumann series

0.5
lamb

Rate of the Poisson distribution used to sample the number of terms of the Neumann series (if n_dist is "poisson")

2.0
n_power_series

Number of terms in the Neumann series

None
exact_trace

Flag, if true the trace of the Jacobian is computed exactly

False
brute_force

Flag, if true the Jacobian is computed exactly in 2D

False
n_samples

Number of samples used to estimate power series

1
n_exact_terms

Number of terms always included in the power series

2
n_dist

Distribution used for the power series, either "geometric" or "poisson"

'geometric'
Source code in normflows/flows/residual.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    net,
    reverse=True,
    reduce_memory=True,
    geom_p=0.5,
    lamb=2.0,
    n_power_series=None,
    exact_trace=False,
    brute_force=False,
    n_samples=1,
    n_exact_terms=2,
    n_dist="geometric"
):
    """Constructor

    Args:
      net: Neural network, must be Lipschitz continuous with L < 1
      reverse: Flag, if true the map ```f(x) = x + net(x)``` is applied in the inverse pass, otherwise it is done in forward
      reduce_memory: Flag, if true the Neumann series gradient estimator is used and gradients are computed in the forward pass to save memory
      geom_p: Success probability of the geometric distribution used to sample the number of power series terms (if n_dist is "geometric")
      lamb: Rate of the Poisson distribution used to sample the number of power series terms (if n_dist is "poisson")
      n_power_series: Number of terms in the power series; if None, the number of terms is sampled, giving an unbiased estimate
      exact_trace: Flag, if true the trace of the Jacobian is computed exactly
      brute_force: Flag, if true the Jacobian is computed exactly in 2D
      n_samples: Number of samples used to estimate power series
      n_exact_terms: Number of terms always included in the power series
      n_dist: Distribution used for the power series, either "geometric" or "poisson"
    """
    super().__init__()
    self.reverse = reverse
    self.iresblock = iResBlock(
        net,
        n_samples=n_samples,
        n_exact_terms=n_exact_terms,
        neumann_grad=reduce_memory,
        grad_in_forward=reduce_memory,
        exact_trace=exact_trace,
        geom_p=geom_p,
        lamb=lamb,
        n_power_series=n_power_series,
        brute_force=brute_force,
        n_dist=n_dist,
    )

iResBlock

Bases: Module

Source code in normflows/flows/residual.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
class iResBlock(nn.Module):
    """Invertible residual block ``y = x + nnet(x)``

    The log determinant of the Jacobian is estimated with a power series that is
    either truncated at a fixed length or randomly truncated (unbiased estimate).
    """

    def __init__(
        self,
        nnet,
        geom_p=0.5,
        lamb=2.0,
        n_power_series=None,
        exact_trace=False,
        brute_force=False,
        n_samples=1,
        n_exact_terms=2,
        n_dist="geometric",
        neumann_grad=True,
        grad_in_forward=False,
    ):
        """
        Args:
            nnet: a nn.Module computing the residual g(x); the fixed-point inverse only converges if it is contractive
            geom_p: success probability of the geometric distribution of the number of series terms (used if n_dist is "geometric")
            lamb: rate of the Poisson distribution of the number of series terms (used if n_dist is "poisson")
            n_power_series: number of power series. If not None, uses a biased approximation to logdet.
            exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.
            brute_force: Computes the exact logdet. Only available for 2D inputs.
            n_samples: number of samples of the series length drawn per estimate
            n_exact_terms: number of series terms always included during training
            n_dist: distribution of the number of series terms, "geometric" or "poisson"
            neumann_grad: if True, use the Neumann series gradient estimator during training
            grad_in_forward: if True, compute gradients in the forward pass during training to save memory
        """
        nn.Module.__init__(self)
        self.nnet = nnet
        self.n_dist = n_dist
        # geom_p is stored as a logit and mapped back with a sigmoid in _logdetgrad
        self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1.0 - geom_p)))
        self.lamb = nn.Parameter(torch.tensor(lamb))
        self.n_samples = n_samples
        self.n_power_series = n_power_series
        self.exact_trace = exact_trace
        self.brute_force = brute_force
        self.n_exact_terms = n_exact_terms
        self.grad_in_forward = grad_in_forward
        self.neumann_grad = neumann_grad

        # store the samples of n.
        self.register_buffer("last_n_samples", torch.zeros(self.n_samples))
        self.register_buffer("last_firmom", torch.zeros(1))
        self.register_buffer("last_secmom", torch.zeros(1))

    def forward(self, x, logpx=None):
        """Compute ``y = x + nnet(x)``; if ``logpx`` is given, also return ``logpx - logdet``"""
        if logpx is None:
            y = x + self.nnet(x)
            return y
        else:
            g, logdetgrad = self._logdetgrad(x)
            return x + g, logpx - logdetgrad

    def inverse(self, y, logpy=None):
        """Invert the block by fixed-point iteration; if ``logpy`` is given, also return ``logpy + logdet``"""
        x = self._inverse_fixed_point(y)
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x)[1]

    def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5):
        """Solve ``y = x + nnet(x)`` for x by iterating ``x <- y - nnet(x)``

        The iteration stops without an error after about 1000 steps even if the
        tolerance has not been reached.
        """
        x, x_prev = y - self.nnet(y), y
        i = 0
        tol = atol + y.abs() * rtol
        while not torch.all((x - x_prev) ** 2 / tol < 1):
            x, x_prev = y - self.nnet(x), x
            i += 1
            if i > 1000:
                break
        return x

    def _logdetgrad(self, x):
        """Returns g(x) and ```logdet|d(x+g(x))/dx|```"""

        with torch.enable_grad():
            if (self.brute_force or not self.training) and (
                x.ndimension() == 2 and x.shape[1] == 2
            ):
                ###########################################
                # Brute-force compute Jacobian determinant.
                ###########################################
                x = x.requires_grad_(True)
                g = self.nnet(x)
                # Brute-force logdet only available for 2D.
                jac = batch_jacobian(g, x)
                batch_dets = (jac[:, 0, 0] + 1) * (jac[:, 1, 1] + 1) - jac[
                    :, 0, 1
                ] * jac[:, 1, 0]
                return g, torch.log(torch.abs(batch_dets)).view(-1, 1)

            if self.n_dist == "geometric":
                geom_p = torch.sigmoid(self.geom_p).item()
                sample_fn = lambda m: geometric_sample(geom_p, m)
                rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
            elif self.n_dist == "poisson":
                lamb = self.lamb.item()
                sample_fn = lambda m: poisson_sample(lamb, m)
                rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)
            else:
                # Only the truncated estimator (fixed n_power_series, training) works
                # without a distribution; fail with a clear error once sampling is needed
                def sample_fn(m):
                    raise ValueError(
                        "Unknown n_dist '{}', expected 'geometric' or 'poisson'".format(
                            self.n_dist
                        )
                    )

                rcdf_fn = None

            if self.training:
                if self.n_power_series is None:
                    # Unbiased estimation.
                    n_samples = sample_fn(self.n_samples)
                    n_power_series = max(n_samples) + self.n_exact_terms
                    coeff_fn = (
                        lambda k: 1
                        / rcdf_fn(k, self.n_exact_terms)
                        * sum(n_samples >= k - self.n_exact_terms)
                        / len(n_samples)
                    )
                else:
                    # Truncated estimation.
                    n_power_series = self.n_power_series
                    coeff_fn = lambda k: 1.0
            else:
                # Unbiased estimation with more exact terms.
                n_samples = sample_fn(self.n_samples)
                n_power_series = max(n_samples) + 20
                coeff_fn = (
                    lambda k: 1
                    / rcdf_fn(k, 20)
                    * sum(n_samples >= k - 20)
                    / len(n_samples)
                )

            if not self.exact_trace:
                ####################################
                # Power series with trace estimator.
                ####################################
                vareps = torch.randn_like(x)

                # Choose the type of estimator.
                if self.training and self.neumann_grad:
                    estimator_fn = neumann_logdet_estimator
                else:
                    estimator_fn = basic_logdet_estimator

                # Do backprop-in-forward to save memory.
                if self.training and self.grad_in_forward:
                    g, logdetgrad = mem_eff_wrapper(
                        estimator_fn,
                        self.nnet,
                        x,
                        n_power_series,
                        vareps,
                        coeff_fn,
                        self.training,
                    )
                else:
                    x = x.requires_grad_(True)
                    g = self.nnet(x)
                    logdetgrad = estimator_fn(
                        g, x, n_power_series, vareps, coeff_fn, self.training
                    )
            else:
                ############################################
                # Power series with exact trace computation.
                ############################################
                x = x.requires_grad_(True)
                g = self.nnet(x)
                jac = batch_jacobian(g, x)
                logdetgrad = batch_trace(jac)
                jac_k = jac
                for k in range(2, n_power_series + 1):
                    jac_k = torch.bmm(jac, jac_k)
                    logdetgrad = logdetgrad + (-1) ** (k + 1) / k * coeff_fn(
                        k
                    ) * batch_trace(jac_k)

            if self.training and self.n_power_series is None:
                # Keep statistics of the last estimate for monitoring
                self.last_n_samples.copy_(
                    torch.tensor(n_samples).to(self.last_n_samples)
                )
                estimator = logdetgrad.detach()
                self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
                self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
            return g, logdetgrad.view(-1, 1)

    def extra_repr(self):
        return "dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}".format(
            self.n_dist,
            self.n_samples,
            self.n_power_series,
            self.neumann_grad,
            self.exact_trace,
            self.brute_force,
        )
__init__(nnet, geom_p=0.5, lamb=2.0, n_power_series=None, exact_trace=False, brute_force=False, n_samples=1, n_exact_terms=2, n_dist='geometric', neumann_grad=True, grad_in_forward=False)

Parameters:

Name Type Description Default
nnet

a nn.Module

required
n_power_series

number of power series. If not None, uses a biased approximation to logdet.

None
exact_trace

if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.

False
brute_force

Computes the exact logdet. Only available for 2D inputs.

False
Source code in normflows/flows/residual.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def __init__(
    self,
    nnet,
    geom_p=0.5,
    lamb=2.0,
    n_power_series=None,
    exact_trace=False,
    brute_force=False,
    n_samples=1,
    n_exact_terms=2,
    n_dist="geometric",
    neumann_grad=True,
    grad_in_forward=False,
):
    """
    Args:
        nnet: a nn.Module
        geom_p: Initial value of the probability p, stored as the learnable
            logit log(p) - log(1 - p); presumably used when n_dist is
            "geometric" (confirm against the sampling code)
        lamb: Initial value of the learnable parameter ``lamb``; presumably the
            rate of a non-geometric n_dist (confirm). Integers are accepted.
        n_power_series: number of power series. If not None, uses a biased approximation to logdet.
        exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.
        brute_force: Computes the exact logdet. Only available for 2D inputs.
        n_samples: Number of samples of n; also the size of the
            ``last_n_samples`` buffer
        n_exact_terms: Stored as ``self.n_exact_terms``; presumably the number
            of leading power-series terms always evaluated (confirm)
        n_dist: Name of the distribution used for n (default "geometric")
        neumann_grad: Stored flag ``self.neumann_grad`` (gradient-estimation option)
        grad_in_forward: Stored flag ``self.grad_in_forward``
    """
    nn.Module.__init__(self)
    self.nnet = nnet
    self.n_dist = n_dist
    # geom_p is kept as an unconstrained logit so it can be optimized freely
    self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1.0 - geom_p)))
    # An integer lamb would yield an integer tensor, which nn.Parameter
    # rejects; promote integers to float (floats are passed through unchanged).
    if isinstance(lamb, (int, np.integer)):
        lamb = float(lamb)
    self.lamb = nn.Parameter(torch.tensor(lamb))
    self.n_samples = n_samples
    self.n_power_series = n_power_series
    self.exact_trace = exact_trace
    self.brute_force = brute_force
    self.n_exact_terms = n_exact_terms
    self.grad_in_forward = grad_in_forward
    self.neumann_grad = neumann_grad

    # store the samples of n.
    self.register_buffer("last_n_samples", torch.zeros(self.n_samples))
    # first and second moment of the last logdet estimator (written during training)
    self.register_buffer("last_firmom", torch.zeros(1))
    self.register_buffer("last_secmom", torch.zeros(1))

stochastic

HamiltonianMonteCarlo

Bases: Flow

Flow layer using the HMC proposal in Stochastic Normalising Flows

See arXiv: 2002.06707

Source code in normflows/flows/stochastic.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class HamiltonianMonteCarlo(Flow):
    """Flow layer using the HMC proposal in Stochastic Normalising Flows

    See [arXiv: 2002.06707](https://arxiv.org/abs/2002.06707)
    """

    def __init__(self, target, steps, log_step_size, log_mass, max_abs_grad=None):
        """Constructor

        Args:
          target: The stationary distribution of this Markov transition, i.e. the target distribution to sample from.
          steps: The number of leapfrog steps
          log_step_size: The log step size used in the leapfrog integrator. shape (dim)
          log_mass: The log_mass determining the variance of the momentum samples. shape (dim)
          max_abs_grad: Maximum absolute value of the gradient of the target distribution's log probability. If set to None then no gradient clipping is applied. Useful for improving numerical stability."""
        super().__init__()
        self.target = target
        self.steps = steps
        self.register_parameter("log_step_size", torch.nn.Parameter(log_step_size))
        self.register_parameter("log_mass", torch.nn.Parameter(log_mass))
        self.max_abs_grad = max_abs_grad

    def forward(self, z):
        """Run one HMC transition: a leapfrog trajectory plus a Metropolis-Hastings correction.

        Args:
          z: Batch of samples; the kinetic energy is summed over dim 1, so z
            is expected to have shape (batch, dim)

        Returns:
          Tuple (z_out, log_det), where log_det = log p(z) - log p(z_out) and p
          is the target density
        """
        mass = torch.exp(self.log_mass)
        # Draw momentum with per-dimension std exp(0.5 * log_mass)
        p = torch.randn_like(z) * torch.exp(0.5 * self.log_mass)

        # leapfrog
        z_new = z.clone()
        p_new = p.clone()
        step_size = torch.exp(self.log_step_size)
        for _ in range(self.steps):
            p_half = p_new - (step_size / 2.0) * -self.gradlogP(z_new)
            z_new = z_new + step_size * (p_half / mass)
            p_new = p_half - (step_size / 2.0) * -self.gradlogP(z_new)

        # Metropolis Hastings correction. The target log-probabilities are
        # evaluated once and reused for log_det instead of being recomputed.
        log_p_new = self.target.log_prob(z_new)
        log_p = self.target.log_prob(z)
        probabilities = torch.exp(
            log_p_new
            - log_p
            - 0.5 * torch.sum(p_new**2 / mass, 1)
            + 0.5 * torch.sum(p**2 / mass, 1)
        )
        uniforms = torch.rand_like(probabilities)
        mask = uniforms < probabilities
        z_out = torch.where(mask.unsqueeze(1), z_new, z)
        log_p_out = torch.where(mask, log_p_new, log_p)

        return z_out, log_p - log_p_out

    def inverse(self, z):
        # The transition is its own inverse in this model
        return self.forward(z)

    def gradlogP(self, z):
        """Gradient of the target log-probability w.r.t. z (detached, optionally clipped)."""
        # enable_grad: the score is needed even when sampling runs under
        # torch.no_grad(), where autograd.grad would otherwise fail.
        with torch.enable_grad():
            z_ = z.detach().requires_grad_()
            logp = self.target.log_prob(z_)
            grad = torch.autograd.grad(logp, z_, grad_outputs=torch.ones_like(logp))[0]
        if self.max_abs_grad:
            grad = torch.clamp(grad, max=self.max_abs_grad, min=-self.max_abs_grad)
        return grad
__init__(target, steps, log_step_size, log_mass, max_abs_grad=None)

Constructor

Parameters:

Name Type Description Default
target

The stationary distribution of this Markov transition, i.e. the target distribution to sample from.

required
steps

The number of leapfrog steps

required
log_step_size

The log step size used in the leapfrog integrator. shape (dim)

required
log_mass

The log_mass determining the variance of the momentum samples. shape (dim)

required
max_abs_grad

Maximum absolute value of the gradient of the target distribution's log probability. If set to None then no gradient clipping is applied. Useful for improving numerical stability.

None
Source code in normflows/flows/stochastic.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __init__(self, target, steps, log_step_size, log_mass, max_abs_grad=None):
    """Constructor

    Args:
      target: Distribution left invariant by this Markov transition, i.e. the
        target distribution to sample from.
      steps: Number of leapfrog steps per transition
      log_step_size: Log of the leapfrog step size, shape (dim); registered as
        a learnable parameter
      log_mass: Log of the mass that sets the momentum variance, shape (dim);
        registered as a learnable parameter
      max_abs_grad: Clip bound for the absolute gradient of the target
        log-probability; None disables clipping"""
    super().__init__()
    # Learnable integrator settings (registration order: step size, then mass)
    self.register_parameter("log_step_size", torch.nn.Parameter(log_step_size))
    self.register_parameter("log_mass", torch.nn.Parameter(log_mass))
    self.target, self.steps = target, steps
    self.max_abs_grad = max_abs_grad

MetropolisHastings

Bases: Flow

Sampling through Metropolis Hastings in Stochastic Normalizing Flow

See arXiv: 2002.06707

Source code in normflows/flows/stochastic.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class MetropolisHastings(Flow):
    """Sampling through Metropolis Hastings in Stochastic Normalizing Flow

    See [arXiv: 2002.06707](https://arxiv.org/abs/2002.06707)
    """

    def __init__(self, target, proposal, steps):
        """Constructor

        Args:
          target: The stationary distribution of this Markov transition, i.e. the target distribution to sample from.
          proposal: Proposal distribution
          steps: Number of MCMC steps to perform
        """
        super().__init__()
        self.target = target
        self.proposal = proposal
        self.steps = steps

    def forward(self, z):
        """Run ``self.steps`` Metropolis-Hastings steps.

        Args:
          z: Batch of samples, batch dimension first (any number of trailing dims)

        Returns:
          Tuple (z, log_det): the updated samples and, per sample, the sum of
          log p(z_old) - log p(z_new) over all accepted moves
        """
        # Initialize number of samples and log(det)
        num_samples = len(z)
        log_det = torch.zeros(num_samples, dtype=z.dtype, device=z.device)
        # Get log(p) for current samples
        log_p = self.target.log_prob(z)
        # Shape that broadcasts the per-sample accept mask over all non-batch
        # dims of z. A bare unsqueeze(1) assumes 2-D z and broadcasts to
        # (n, n) for 1-D z.
        mask_shape = (num_samples,) + (1,) * (z.dim() - 1)
        for _ in range(self.steps):
            # Make proposal and get log(p); log_p_diff is the proposal's
            # log-ratio correction term returned by self.proposal
            z_, log_p_diff = self.proposal(z)
            log_p_ = self.target.log_prob(z_)
            # Make acceptance decision
            w = torch.rand(num_samples, dtype=z.dtype, device=z.device)
            log_w_accept = log_p_ - log_p + log_p_diff
            w_accept = torch.clamp(torch.exp(log_w_accept), max=1)
            accept = w <= w_accept
            # Update samples, log(det), and log(p)
            z = torch.where(accept.view(mask_shape), z_, z)
            log_det_ = log_p - log_p_
            log_det = torch.where(accept, log_det + log_det_, log_det)
            log_p = torch.where(accept, log_p_, log_p)
        return z, log_det

    def inverse(self, z):
        # Equivalent to forward pass
        return self.forward(z)
__init__(target, proposal, steps)

Constructor

Parameters:

Name Type Description Default
target

The stationary distribution of this Markov transition, i.e. the target distribution to sample from.

required
proposal

Proposal distribution

required
steps

Number of MCMC steps to perform

required
Source code in normflows/flows/stochastic.py
12
13
14
15
16
17
18
19
20
21
22
23
def __init__(self, target, proposal, steps):
    """Constructor

    Args:
      target: Distribution left invariant by this Markov chain, i.e. the
        target distribution to sample from.
      proposal: Proposal distribution used to suggest moves
      steps: How many MCMC steps each call performs
    """
    super().__init__()
    # Assigned left to right: target, proposal, steps
    self.target, self.proposal, self.steps = target, proposal, steps

nets

cnn

ConvNet2d

Bases: Module

Convolutional Neural Network with leaky ReLU nonlinearities

Source code in normflows/nets/cnn.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class ConvNet2d(nn.Module):
    """
    Convolutional Neural Network with leaky ReLU nonlinearities
    """

    def __init__(
        self,
        channels,
        kernel_size,
        leaky=0.0,
        init_zeros=True,
        actnorm=False,
        weight_std=None,
    ):
        """Constructor

        Args:
          channels: List of channels of conv layers, first entry is in_channels;
            needs one more entry than kernel_size
          kernel_size: List of kernel sizes, same for height and width; padding
            is kernel_size // 2, so odd sizes preserve the spatial size
          leaky: Negative slope of the leaky ReLU
          init_zeros: Flag whether last layer shall be initialized with zeros
          actnorm: Flag whether activation normalization shall be done after each conv layer except output
          weight_std: If not None, std of the normal distribution used to
            initialize the weights of every hidden conv layer; the output layer
            keeps its default (or zero) initialization
        """
        super().__init__()
        # Build network: hidden blocks of conv [-> ActNorm] -> LeakyReLU
        layers = []
        for i in range(len(kernel_size) - 1):
            conv = nn.Conv2d(
                channels[i],
                channels[i + 1],
                kernel_size[i],
                padding=kernel_size[i] // 2,
                bias=not actnorm,  # no conv bias when ActNorm follows
            )
            if weight_std is not None:
                conv.weight.data.normal_(mean=0.0, std=weight_std)
            layers.append(conv)
            if actnorm:
                layers.append(utils.ActNorm((channels[i + 1], 1, 1)))
            layers.append(nn.LeakyReLU(leaky))
        # Output conv layer (no activation)
        last = len(kernel_size) - 1
        output_conv = nn.Conv2d(
            channels[last],
            channels[last + 1],
            kernel_size[last],
            padding=kernel_size[last] // 2,
        )
        if init_zeros:
            nn.init.zeros_(output_conv.weight)
            nn.init.zeros_(output_conv.bias)
        layers.append(output_conv)
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
__init__(channels, kernel_size, leaky=0.0, init_zeros=True, actnorm=False, weight_std=None)

Constructor

Parameters:

Name Type Description Default
channels

List of channels of conv layers, first entry is in_channels

required
kernel_size

List of kernel sizes, same for height and width

required
leaky

Leaky part of ReLU

0.0
init_zeros

Flag whether last layer shall be initialized with zeros

True
scale_output

Flag whether to scale output with a log scale parameter

required
logscale_factor

Constant factor to be multiplied to log scaling

required
actnorm

Flag whether activation normalization shall be done after each conv layer except output

False
weight_std

Fixed std used to initialize every layer

None
Source code in normflows/nets/cnn.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(
    self,
    channels,
    kernel_size,
    leaky=0.0,
    init_zeros=True,
    actnorm=False,
    weight_std=None,
):
    """Constructor

    Args:
      channels: List of channels of conv layers, first entry is in_channels;
        needs one more entry than kernel_size
      kernel_size: List of kernel sizes, same for height and width; padding
        is kernel_size // 2, so odd sizes preserve the spatial size
      leaky: Negative slope of the leaky ReLU
      init_zeros: Flag whether last layer shall be initialized with zeros
      actnorm: Flag whether activation normalization shall be done after each conv layer except output
      weight_std: If not None, std of the normal distribution used to
        initialize the weights of every hidden conv layer; the output layer
        keeps its default (or zero) initialization
    """
    super().__init__()
    # Build network: hidden blocks of conv [-> ActNorm] -> LeakyReLU
    layers = []
    for i in range(len(kernel_size) - 1):
        conv = nn.Conv2d(
            channels[i],
            channels[i + 1],
            kernel_size[i],
            padding=kernel_size[i] // 2,
            bias=not actnorm,  # no conv bias when ActNorm follows
        )
        if weight_std is not None:
            conv.weight.data.normal_(mean=0.0, std=weight_std)
        layers.append(conv)
        if actnorm:
            layers.append(utils.ActNorm((channels[i + 1], 1, 1)))
        layers.append(nn.LeakyReLU(leaky))
    # Output conv layer (no activation)
    last = len(kernel_size) - 1
    output_conv = nn.Conv2d(
        channels[last],
        channels[last + 1],
        kernel_size[last],
        padding=kernel_size[last] // 2,
    )
    if init_zeros:
        nn.init.zeros_(output_conv.weight)
        nn.init.zeros_(output_conv.bias)
    layers.append(output_conv)
    self.net = nn.Sequential(*layers)

lipschitz

LipschitzCNN

Bases: Module

Convolutional neural network which is Lipschitz continuous with Lipschitz constant L < 1

Source code in normflows/nets/lipschitz.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class LipschitzCNN(nn.Module):
    """
    Convolutional neural network whose layers are each constrained to be
    Lipschitz continuous with Lipschitz constant L < 1
    """

    def __init__(
        self,
        channels,
        kernel_size,
        lipschitz_const=0.97,
        max_lipschitz_iter=5,
        lipschitz_tolerance=None,
        init_zeros=True,
    ):
        """Constructor

        Args:
          channels: Integer list with the number of channels of the layers;
            needs at least len(kernel_size) + 1 entries
          kernel_size: Integer list of kernel sizes, one per layer
          lipschitz_const: Upper bound on the Lipschitz constant of each layer
          max_lipschitz_iter: Maximum number of iterations used to enforce the
            Lipschitz bound; if None, the tolerance is used instead
          lipschitz_tolerance: Float tolerance for the Lipschitz bound when
            max_lipschitz_iter is None, typically 1e-3
          init_zeros: Whether the last layer is initialized approximately with zeros
        """
        super().__init__()

        self.n_layers = len(kernel_size)
        self.channels = channels
        self.kernel_size = kernel_size
        self.lipschitz_const = lipschitz_const
        self.max_lipschitz_iter = max_lipschitz_iter
        self.lipschitz_tolerance = lipschitz_tolerance
        self.init_zeros = init_zeros

        final_idx = self.n_layers - 1

        def make_stage(idx):
            # Each stage applies the activation first, then the constrained conv
            return [
                Swish(),
                InducedNormConv2d(
                    in_channels=channels[idx],
                    out_channels=channels[idx + 1],
                    kernel_size=kernel_size[idx],
                    stride=1,
                    padding=kernel_size[idx] // 2,
                    bias=True,
                    coeff=lipschitz_const,
                    domain=2,
                    codomain=2,
                    n_iterations=max_lipschitz_iter,
                    atol=lipschitz_tolerance,
                    rtol=lipschitz_tolerance,
                    # only the final conv honours init_zeros
                    zero_init=(idx == final_idx) and init_zeros,
                ),
            ]

        stages = [module for idx in range(self.n_layers) for module in make_stage(idx)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked Swish/conv stages to x."""
        out = self.net(x)
        return out
__init__(channels, kernel_size, lipschitz_const=0.97, max_lipschitz_iter=5, lipschitz_tolerance=None, init_zeros=True)

Constructor

Parameters:

Name Type Description Default
channels

Integer list with the number of channels of the layers

required
kernel_size

Integer list of kernel sizes of the layers

required
lipschitz_const

Maximum Lipschitz constant of each layer

0.97
max_lipschitz_iter

Maximum number of iterations used to ensure that layers are Lipschitz continuous with L smaller than set maximum; if None, tolerance is used

5
lipschitz_tolerance

Float, tolerance used to ensure Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3

None
init_zeros

Flag, whether to initialize last layer approximately with zeros

True
Source code in normflows/nets/lipschitz.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def __init__(
    self,
    channels,
    kernel_size,
    lipschitz_const=0.97,
    max_lipschitz_iter=5,
    lipschitz_tolerance=None,
    init_zeros=True,
):
    """Constructor

    Args:
      channels: Integer list with the number of channels of the layers;
        needs at least len(kernel_size) + 1 entries
      kernel_size: Integer list of kernel sizes, one per layer
      lipschitz_const: Upper bound on the Lipschitz constant of each layer
      max_lipschitz_iter: Maximum number of iterations used to enforce the
        Lipschitz bound; if None, the tolerance is used instead
      lipschitz_tolerance: Float tolerance for the Lipschitz bound when
        max_lipschitz_iter is None, typically 1e-3
      init_zeros: Whether the last layer is initialized approximately with zeros
    """
    super().__init__()

    self.n_layers = len(kernel_size)
    self.channels = channels
    self.kernel_size = kernel_size
    self.lipschitz_const = lipschitz_const
    self.max_lipschitz_iter = max_lipschitz_iter
    self.lipschitz_tolerance = lipschitz_tolerance
    self.init_zeros = init_zeros

    final_idx = self.n_layers - 1

    def make_stage(idx):
        # Each stage applies the activation first, then the constrained conv
        return [
            Swish(),
            InducedNormConv2d(
                in_channels=channels[idx],
                out_channels=channels[idx + 1],
                kernel_size=kernel_size[idx],
                stride=1,
                padding=kernel_size[idx] // 2,
                bias=True,
                coeff=lipschitz_const,
                domain=2,
                codomain=2,
                n_iterations=max_lipschitz_iter,
                atol=lipschitz_tolerance,
                rtol=lipschitz_tolerance,
                # only the final conv honours init_zeros
                zero_init=(idx == final_idx) and init_zeros,
            ),
        ]

    stages = [module for idx in range(self.n_layers) for module in make_stage(idx)]
    self.net = nn.Sequential(*stages)

LipschitzMLP

Bases: Module

Fully connected neural net which is Lipschitz continuous with Lipschitz constant L < 1

Source code in normflows/nets/lipschitz.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class LipschitzMLP(nn.Module):
    """Fully connected neural net which is Lipschitz continuous with Lipschitz constant L < 1"""

    def __init__(
        self,
        channels,
        lipschitz_const=0.97,
        max_lipschitz_iter=5,
        lipschitz_tolerance=None,
        init_zeros=True,
    ):
        """Constructor

        Args:
          channels: Integer list with the number of features of each layer,
            input size first; the net has len(channels) - 1 linear layers
          lipschitz_const: Upper bound on the Lipschitz constant of each layer
          max_lipschitz_iter: Maximum number of iterations used to enforce the
            Lipschitz bound; if None, the tolerance is used instead
          lipschitz_tolerance: Float tolerance for the Lipschitz bound when
            max_lipschitz_iter is None, typically 1e-3
          init_zeros: Whether the last layer is initialized approximately with zeros
        """
        super().__init__()

        self.n_layers = len(channels) - 1
        self.channels = channels
        self.lipschitz_const = lipschitz_const
        self.max_lipschitz_iter = max_lipschitz_iter
        self.lipschitz_tolerance = lipschitz_tolerance
        self.init_zeros = init_zeros

        final_idx = self.n_layers - 1

        def make_stage(idx):
            # Each stage applies the activation first, then the constrained linear map
            return [
                Swish(),
                InducedNormLinear(
                    in_features=channels[idx],
                    out_features=channels[idx + 1],
                    coeff=lipschitz_const,
                    domain=2,
                    codomain=2,
                    n_iterations=max_lipschitz_iter,
                    atol=lipschitz_tolerance,
                    rtol=lipschitz_tolerance,
                    # only the final layer honours init_zeros
                    zero_init=(idx == final_idx) and init_zeros,
                ),
            ]

        stages = [module for idx in range(self.n_layers) for module in make_stage(idx)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked Swish/linear stages to x."""
        out = self.net(x)
        return out
__init__(channels, lipschitz_const=0.97, max_lipschitz_iter=5, lipschitz_tolerance=None, init_zeros=True)

Constructor channels: Integer list with the number of channels of the layers lipschitz_const: Maximum Lipschitz constant of each layer max_lipschitz_iter: Maximum number of iterations used to ensure that layers are Lipschitz continuous with L smaller than set maximum; if None, tolerance is used lipschitz_tolerance: Float, tolerance used to ensure Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3 init_zeros: Flag, whether to initialize last layer approximately with zeros

Source code in normflows/nets/lipschitz.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(
    self,
    channels,
    lipschitz_const=0.97,
    max_lipschitz_iter=5,
    lipschitz_tolerance=None,
    init_zeros=True,
):
    """Constructor

    Args:
      channels: Integer list with the number of features of each layer,
        input size first; the net has len(channels) - 1 linear layers
      lipschitz_const: Upper bound on the Lipschitz constant of each layer
      max_lipschitz_iter: Maximum number of iterations used to enforce the
        Lipschitz bound; if None, the tolerance is used instead
      lipschitz_tolerance: Float tolerance for the Lipschitz bound when
        max_lipschitz_iter is None, typically 1e-3
      init_zeros: Whether the last layer is initialized approximately with zeros
    """
    super().__init__()

    self.n_layers = len(channels) - 1
    self.channels = channels
    self.lipschitz_const = lipschitz_const
    self.max_lipschitz_iter = max_lipschitz_iter
    self.lipschitz_tolerance = lipschitz_tolerance
    self.init_zeros = init_zeros

    final_idx = self.n_layers - 1

    def make_stage(idx):
        # Each stage applies the activation first, then the constrained linear map
        return [
            Swish(),
            InducedNormLinear(
                in_features=channels[idx],
                out_features=channels[idx + 1],
                coeff=lipschitz_const,
                domain=2,
                codomain=2,
                n_iterations=max_lipschitz_iter,
                atol=lipschitz_tolerance,
                rtol=lipschitz_tolerance,
                # only the final layer honours init_zeros
                zero_init=(idx == final_idx) and init_zeros,
            ),
        ]

    stages = [module for idx in range(self.n_layers) for module in make_stage(idx)]
    self.net = nn.Sequential(*stages)

projmax_(v)

Inplace argmax on absolute value.

Source code in normflows/nets/lipschitz.py
651
652
653
654
655
656
def projmax_(v):
    """Inplace argmax on absolute value.

    Overwrites ``v`` with a one-hot tensor holding 1 at the position of its
    largest-magnitude entry (first one on ties) and 0 elsewhere.

    Args:
        v: Tensor of any rank; modified in place.

    Returns:
        The same tensor object ``v``.
    """
    # torch.argmax without dim returns an index into the flattened tensor
    ind = torch.argmax(torch.abs(v))
    v.zero_()
    if v.dim() == 1:
        v[ind] = 1
    else:
        # For rank != 1, indexing v directly with the flat index would address
        # the first dimension only; write through a flat view instead (raises
        # for non-contiguous multi-dimensional tensors rather than writing to
        # the wrong place).
        v.view(-1)[ind] = 1
    return v

made

Implementation of MADE. Code taken from https://github.com/bayesiains/nsf

MADE

Bases: Module

Implementation of MADE.

It can use either feedforward blocks or residual blocks (default is residual). Optionally, it can use batch norm or dropout within blocks (default is no).

Source code in normflows/nets/made.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class MADE(nn.Module):
    """Implementation of MADE.

    It can use either feedforward blocks or residual blocks (default is residual).
    Optionally, it can use batch norm or dropout within blocks (default is no).

    The network is an initial masked layer, ``num_blocks`` masked hidden blocks
    and a final masked layer with ``features * output_multiplier`` outputs.
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        output_multiplier=1,
        use_residual_blocks=True,
        random_mask=False,
        permute_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        preprocessing=None,
    ):
        """Constructor

        Args:
          features: Number of autoregressive input features
          hidden_features: Width of the hidden masked layers
          context_features: Size of an optional context input; if None, no
            context layer is created and forward must be called without context
            (MaskedFeedforwardBlock raises NotImplementedError if it is set)
          num_blocks: Number of hidden blocks between initial and final layer
          output_multiplier: The final layer has features * output_multiplier outputs
          use_residual_blocks: Use MaskedResidualBlock if True, else MaskedFeedforwardBlock
          random_mask: Passed to the masked layers; cannot be combined with
            residual blocks
          permute_mask: If True, the input degrees are randomly permuted
          activation: Activation function used inside the blocks
          dropout_probability: Dropout probability used inside the blocks
          use_batch_norm: Flag whether the blocks use batch norm
          preprocessing: Module applied to the inputs first; identity if None

        Raises:
          ValueError: If use_residual_blocks and random_mask are both True
        """
        if use_residual_blocks and random_mask:
            raise ValueError("Residual blocks can't be used with random masks.")
        super().__init__()

        # Preprocessing
        if preprocessing is None:
            self.preprocessing = torch.nn.Identity()
        else:
            self.preprocessing = preprocessing

        # NOTE: layers are constructed in a fixed order; each construction (and
        # the randperm below) draws from the torch RNG, so reordering changes
        # the initialization.

        # Initial layer.
        input_degrees_ = _get_input_degrees(features)
        if permute_mask:
            input_degrees_ = input_degrees_[torch.randperm(features)]
        self.initial_layer = MaskedLinear(
            in_degrees=input_degrees_,
            out_features=hidden_features,
            autoregressive_features=features,
            random_mask=random_mask,
            is_output=False,
        )

        # Unmasked projection of the context, added to the initial layer output
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, hidden_features)

        # Residual blocks.
        blocks = []
        if use_residual_blocks:
            block_constructor = MaskedResidualBlock
        else:
            block_constructor = MaskedFeedforwardBlock
        # Each block's input degrees are the previous layer's output degrees
        prev_out_degrees = self.initial_layer.degrees
        for _ in range(num_blocks):
            blocks.append(
                block_constructor(
                    in_degrees=prev_out_degrees,
                    autoregressive_features=features,
                    context_features=context_features,
                    random_mask=random_mask,
                    activation=activation,
                    dropout_probability=dropout_probability,
                    use_batch_norm=use_batch_norm,
                )
            )
            prev_out_degrees = blocks[-1].degrees
        self.blocks = nn.ModuleList(blocks)

        # Final layer. Its output degrees reuse the (possibly permuted) input
        # degrees, so output ordering matches input ordering.
        self.final_layer = MaskedLinear(
            in_degrees=prev_out_degrees,
            out_features=features * output_multiplier,
            autoregressive_features=features,
            random_mask=random_mask,
            is_output=True,
            out_degrees_=input_degrees_,
        )

    def forward(self, inputs, context=None):
        """Apply the masked network.

        Args:
          inputs: Input tensor; passed through ``self.preprocessing`` first
          context: Optional context tensor; requires the model to have been
            built with context_features (otherwise self.context_layer is missing)

        Returns:
          Output of the final masked layer (features * output_multiplier features)
        """
        outputs = self.preprocessing(inputs)
        outputs = self.initial_layer(outputs)
        if context is not None:
            outputs += self.context_layer(context)
        for block in self.blocks:
            outputs = block(outputs, context)
        outputs = self.final_layer(outputs)
        return outputs

MaskedFeedforwardBlock

Bases: Module

A feedforward block based on a masked linear module.

NOTE In this implementation, the number of output features is taken to be equal to the number of input features.

Source code in normflows/nets/made.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class MaskedFeedforwardBlock(nn.Module):
    """A feedforward block based on a masked linear module.

    Computes dropout(activation(masked_linear(batch_norm(x)))), where the batch
    norm step is optional.

    **NOTE** In this implementation, the number of output features is taken to be equal to the number of input features.
    """

    def __init__(
        self,
        in_degrees,
        autoregressive_features,
        context_features=None,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        super().__init__()
        width = len(in_degrees)

        # Optional batch norm applied to the block input
        self.batch_norm = nn.BatchNorm1d(width, eps=1e-3) if use_batch_norm else None

        # Context conditioning is not supported by this block
        if context_features is not None:
            raise NotImplementedError()

        # Square masked linear layer; its degrees become the block's degrees
        self.linear = MaskedLinear(
            in_degrees=in_degrees,
            out_features=width,
            autoregressive_features=autoregressive_features,
            random_mask=random_mask,
            is_output=False,
        )
        self.degrees = self.linear.degrees

        self.activation = activation
        self.dropout = nn.Dropout(p=dropout_probability)

    def forward(self, inputs, context=None):
        if context is not None:
            raise NotImplementedError()

        normed = self.batch_norm(inputs) if self.batch_norm else inputs
        hidden = self.activation(self.linear(normed))
        return self.dropout(hidden)

MaskedLinear

Bases: Linear

A linear module with a masked weight matrix.

Source code in normflows/nets/made.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class MaskedLinear(nn.Linear):
    """A linear module with a masked weight matrix.

    The mask is a (out_features, in_features) 0/1 matrix stored as a buffer,
    together with the output ``degrees``. Hidden layers connect output o to
    input i when degree[o] >= in_degree[i]; output layers require a strictly
    greater degree.
    """

    def __init__(
        self,
        in_degrees,
        out_features,
        autoregressive_features,
        random_mask,
        is_output,
        bias=True,
        out_degrees_=None,
    ):
        super().__init__(
            in_features=len(in_degrees), out_features=out_features, bias=bias
        )
        mask, degrees = self._get_mask_and_degrees(
            in_degrees=in_degrees,
            out_features=out_features,
            autoregressive_features=autoregressive_features,
            random_mask=random_mask,
            is_output=is_output,
            out_degrees_=out_degrees_,
        )
        self.register_buffer("mask", mask)
        self.register_buffer("degrees", degrees)

    @classmethod
    def _get_mask_and_degrees(
        cls,
        in_degrees,
        out_features,
        autoregressive_features,
        random_mask,
        is_output,
        out_degrees_=None,
    ):
        """Build the connectivity mask and the output degrees.

        Raises:
          ValueError: For an output layer whose out_features is not a multiple
            of autoregressive_features (the tiled degrees would not match the
            weight shape).
        """
        if is_output:
            if out_features % autoregressive_features != 0:
                raise ValueError(
                    "Output layer needs out_features ({}) to be a multiple of "
                    "autoregressive_features ({}).".format(
                        out_features, autoregressive_features
                    )
                )
            if out_degrees_ is None:
                out_degrees_ = _get_input_degrees(autoregressive_features)
            # Repeat the per-feature degrees once per output group
            out_degrees = tile(out_degrees_, out_features // autoregressive_features)
            # Strict inequality: an output never sees its own input
            mask = (out_degrees[..., None] > in_degrees).float()

        else:
            if random_mask:
                min_in_degree = torch.min(in_degrees).item()
                min_in_degree = min(min_in_degree, autoregressive_features - 1)
                out_degrees = torch.randint(
                    low=min_in_degree,
                    high=autoregressive_features,
                    size=[out_features],
                    dtype=torch.long,
                )
            else:
                # Cycle degrees through min_ .. max_ + min_ - 1
                max_ = max(1, autoregressive_features - 1)
                min_ = min(1, autoregressive_features - 1)
                out_degrees = torch.arange(out_features) % max_ + min_
            mask = (out_degrees[..., None] >= in_degrees).float()

        return mask, out_degrees

    def forward(self, x):
        return F.linear(x, self.weight * self.mask, self.bias)

MaskedResidualBlock

Bases: Module

A residual block containing masked linear modules.

Source code in normflows/nets/made.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
class MaskedResidualBlock(nn.Module):
    """A residual block containing masked linear modules.

    Computes ``inputs + residual(inputs)`` where the residual branch is
    (optional batch norm) -> activation -> MaskedLinear -> (optional batch norm)
    -> activation -> dropout -> MaskedLinear, optionally gated with a GLU driven
    by a context vector.
    """

    def __init__(
        self,
        in_degrees,
        autoregressive_features,
        context_features=None,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        zero_initialization=True,
    ):
        """Constructor

        Args:
          in_degrees: Degrees of the input units; its length is the feature size
          autoregressive_features: Number of autoregressive features
          context_features: Size of the context vector, or None for no context
          random_mask: Must be False; random masks are not supported here
          activation: Nonlinearity applied in the residual branch
          dropout_probability: Dropout rate before the second masked layer
          use_batch_norm: Flag, if true, batch norm precedes each activation
          zero_initialization: Flag, if true, the last masked layer is
            initialized uniformly in [-1e-3, 1e-3]

        Raises:
          ValueError: if random_mask is true
          RuntimeError: if the output degrees fall below the input degrees
        """
        if random_mask:
            raise ValueError("Masked residual block can't be used with random masks.")
        super().__init__()
        features = len(in_degrees)

        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList(
                nn.BatchNorm1d(features, eps=1e-3) for _ in range(2)
            )

        # Chain two masked layers; each consumes the degrees produced by the previous one.
        masked_layers = []
        current_degrees = in_degrees
        for _ in range(2):
            layer = MaskedLinear(
                in_degrees=current_degrees,
                out_features=features,
                autoregressive_features=autoregressive_features,
                random_mask=False,
                is_output=False,
            )
            masked_layers.append(layer)
            current_degrees = layer.degrees
        self.linear_layers = nn.ModuleList(masked_layers)
        self.degrees = current_degrees
        # The residual sum is only autoregressive if no output degree drops below its input.
        if not bool(torch.all(self.degrees >= in_degrees)):
            raise RuntimeError(
                "In a masked residual block, the output degrees can't be"
                " less than the corresponding input degrees."
            )

        self.activation = activation
        self.dropout = nn.Dropout(p=dropout_probability)

        if zero_initialization:
            last_layer = self.linear_layers[-1]
            init.uniform_(last_layer.weight, a=-1e-3, b=1e-3)
            init.uniform_(last_layer.bias, a=-1e-3, b=1e-3)

    def forward(self, inputs, context=None):
        """Return ``inputs`` plus the (optionally context-gated) residual branch."""
        hidden = inputs
        last_index = len(self.linear_layers) - 1
        for idx, linear in enumerate(self.linear_layers):
            if self.use_batch_norm:
                hidden = self.batch_norm_layers[idx](hidden)
            hidden = self.activation(hidden)
            if idx == last_index:
                hidden = self.dropout(hidden)
            hidden = linear(hidden)
        if context is not None:
            gate_input = torch.cat((hidden, self.context_layer(context)), dim=1)
            hidden = F.glu(gate_input, dim=1)
        return inputs + hidden

made_test

Tests for MADE. Code partially taken from https://github.com/bayesiains/nsf

mlp

MLP

Bases: Module

A multilayer perceptron with Leaky ReLU nonlinearities

Source code in normflows/nets/mlp.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class MLP(nn.Module):
    """
    A multilayer perceptron with Leaky ReLU nonlinearities
    """

    def __init__(
        self,
        layers,
        leaky=0.0,
        score_scale=None,
        output_fn=None,
        output_scale=None,
        init_zeros=False,
        dropout=None,
    ):
        """
        layers: list of layer sizes from start to end
        leaky: slope of the leaky part of the ReLU, if 0.0, standard ReLU is used
        score_scale: Factor multiplied onto the scores, i.e. the output before output_fn (only used if output_fn is specified)
        output_fn: String, function to be applied to the output, either None, "sigmoid", "relu", "tanh", or "clampexp"
        output_scale: Factor multiplied onto the result of output_fn (only used if output_fn is specified), i.e. ```output_scale * output_fn(score_scale * out)```
        init_zeros: Flag, if true, weights and biases of last layer are initialized with zeros (helpful for deep models, see [arXiv 1807.03039](https://arxiv.org/abs/1807.03039))
        dropout: Float, if specified, dropout is done before last layer; if None, no dropout is done

        Raises:
          NotImplementedError: if output_fn is not None and not one of the supported names
        """
        super().__init__()
        net = nn.ModuleList([])
        for k in range(len(layers) - 2):
            net.append(nn.Linear(layers[k], layers[k + 1]))
            net.append(nn.LeakyReLU(leaky))
        if dropout is not None:
            net.append(nn.Dropout(p=dropout))
        net.append(nn.Linear(layers[-2], layers[-1]))
        if init_zeros:
            nn.init.zeros_(net[-1].weight)
            nn.init.zeros_(net[-1].bias)
        if output_fn is not None:
            if score_scale is not None:
                net.append(utils.ConstScaleLayer(score_scale))
            if output_fn == "sigmoid":
                net.append(nn.Sigmoid())
            elif output_fn == "relu":
                net.append(nn.ReLU())
            elif output_fn == "tanh":
                net.append(nn.Tanh())
            elif output_fn == "clampexp":
                net.append(utils.ClampExp())
            else:
                # Fail loudly instead of silently building a net without the output function.
                raise NotImplementedError("This output function is not implemented.")
            if output_scale is not None:
                net.append(utils.ConstScaleLayer(output_scale))
        self.net = nn.Sequential(*net)

    def forward(self, x):
        return self.net(x)
__init__(layers, leaky=0.0, score_scale=None, output_fn=None, output_scale=None, init_zeros=False, dropout=None)

layers: list of layer sizes from start to end leaky: slope of the leaky part of the ReLU, if 0.0, standard ReLU is used score_scale: Factor multiplied onto the scores, i.e. the output before output_fn (only used if output_fn is specified). output_fn: String, function to be applied to the output, either None, "sigmoid", "relu", "tanh", or "clampexp" output_scale: Factor multiplied onto the result of output_fn (only used if output_fn is specified), i.e. output_scale * output_fn(score_scale * out) init_zeros: Flag, if true, weights and biases of last layer are initialized with zeros (helpful for deep models, see arXiv 1807.03039) dropout: Float, if specified, dropout is done before last layer; if None, no dropout is done

Source code in normflows/nets/mlp.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(
    self,
    layers,
    leaky=0.0,
    score_scale=None,
    output_fn=None,
    output_scale=None,
    init_zeros=False,
    dropout=None,
):
    """
    layers: list of layer sizes from start to end
    leaky: slope of the leaky part of the ReLU, if 0.0, standard ReLU is used
    score_scale: Factor multiplied onto the scores, i.e. the output before output_fn (only used if output_fn is specified)
    output_fn: String, function to be applied to the output, either None, "sigmoid", "relu", "tanh", or "clampexp"
    output_scale: Factor multiplied onto the result of output_fn (only used if output_fn is specified), i.e. ```output_scale * output_fn(score_scale * out)```
    init_zeros: Flag, if true, weights and biases of last layer are initialized with zeros (helpful for deep models, see [arXiv 1807.03039](https://arxiv.org/abs/1807.03039))
    dropout: Float, if specified, dropout is done before last layer; if None, no dropout is done

    Raises:
      NotImplementedError: if output_fn is not None and not one of the supported names
    """
    super().__init__()
    net = nn.ModuleList([])
    for k in range(len(layers) - 2):
        net.append(nn.Linear(layers[k], layers[k + 1]))
        net.append(nn.LeakyReLU(leaky))
    if dropout is not None:
        net.append(nn.Dropout(p=dropout))
    net.append(nn.Linear(layers[-2], layers[-1]))
    if init_zeros:
        nn.init.zeros_(net[-1].weight)
        nn.init.zeros_(net[-1].bias)
    if output_fn is not None:
        if score_scale is not None:
            net.append(utils.ConstScaleLayer(score_scale))
        if output_fn == "sigmoid":
            net.append(nn.Sigmoid())
        elif output_fn == "relu":
            net.append(nn.ReLU())
        elif output_fn == "tanh":
            net.append(nn.Tanh())
        elif output_fn == "clampexp":
            net.append(utils.ClampExp())
        else:
            # Fail loudly instead of silently building a net without the output function.
            raise NotImplementedError("This output function is not implemented.")
        if output_scale is not None:
            net.append(utils.ConstScaleLayer(output_scale))
    self.net = nn.Sequential(*net)

resnet

ResidualBlock

Bases: Module

A general-purpose residual block. Works only with 1-dim inputs.

Source code in normflows/nets/resnet.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class ResidualBlock(nn.Module):
    """A general-purpose residual block. Works only with 1-dim inputs.

    Computes ``inputs + residual(inputs)`` where the residual branch is
    (optional batch norm) -> activation -> linear -> (optional batch norm)
    -> activation -> dropout -> linear, optionally gated with a GLU driven by
    a context vector.
    """

    def __init__(
        self,
        features,
        context_features,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        zero_initialization=True,
    ):
        """Constructor

        Args:
          features: Size of the input (and output) feature dimension
          context_features: Size of the context vector, or None for no context
          activation: Nonlinearity applied in the residual branch
          dropout_probability: Dropout rate before the second linear layer
          use_batch_norm: Flag, if true, batch norm precedes each activation
          zero_initialization: Flag, if true, the last linear layer is
            initialized uniformly in [-1e-3, 1e-3]
        """
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.activation = activation

        # Registration order (batch norm, context, linear, dropout) fixes the
        # order of parameters() and state_dict keys.
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList(
                nn.BatchNorm1d(features, eps=1e-3) for _ in range(2)
            )
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)
        self.linear_layers = nn.ModuleList(
            nn.Linear(features, features) for _ in range(2)
        )
        self.dropout = nn.Dropout(p=dropout_probability)

        if zero_initialization:
            last_layer = self.linear_layers[-1]
            init.uniform_(last_layer.weight, a=-1e-3, b=1e-3)
            init.uniform_(last_layer.bias, a=-1e-3, b=1e-3)

    def forward(self, inputs, context=None):
        """Return ``inputs`` plus the (optionally context-gated) residual branch."""
        hidden = inputs
        last_index = len(self.linear_layers) - 1
        for idx, linear in enumerate(self.linear_layers):
            if self.use_batch_norm:
                hidden = self.batch_norm_layers[idx](hidden)
            hidden = self.activation(hidden)
            if idx == last_index:
                hidden = self.dropout(hidden)
            hidden = linear(hidden)
        if context is not None:
            gate_input = torch.cat((hidden, self.context_layer(context)), dim=1)
            hidden = F.glu(gate_input, dim=1)
        return inputs + hidden

ResidualNet

Bases: Module

A general-purpose residual network. Works only with 1-dim inputs.

Source code in normflows/nets/resnet.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class ResidualNet(nn.Module):
    """A general-purpose residual network. Works only with 1-dim inputs.

    Structure: (optional preprocessing) -> linear to hidden size (with the
    context concatenated to the input, if given) -> ``num_blocks`` residual
    blocks -> linear to the output size.
    """

    def __init__(
        self,
        in_features,
        out_features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        preprocessing=None,
    ):
        """Constructor

        Args:
          in_features: Number of input features
          out_features: Number of output features
          hidden_features: Width of the hidden residual blocks
          context_features: Size of the context vector, or None for no context
          num_blocks: Number of residual blocks
          activation: Nonlinearity used inside the residual blocks
          dropout_probability: Dropout rate inside the residual blocks
          use_batch_norm: Flag, if true, residual blocks use batch norm
          preprocessing: Optional callable applied to the inputs first
        """
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features
        self.preprocessing = preprocessing

        # The context, if any, is concatenated to the input of the first layer.
        first_in = in_features if context_features is None else in_features + context_features
        self.initial_layer = nn.Linear(first_in, hidden_features)
        self.blocks = nn.ModuleList(
            [
                ResidualBlock(
                    features=hidden_features,
                    context_features=context_features,
                    activation=activation,
                    dropout_probability=dropout_probability,
                    use_batch_norm=use_batch_norm,
                )
                for _ in range(num_blocks)
            ]
        )
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        """Map ``inputs`` (and optional ``context``) to the network output."""
        temps = inputs if self.preprocessing is None else self.preprocessing(inputs)
        if context is not None:
            temps = torch.cat((temps, context), dim=1)
        hidden = self.initial_layer(temps)
        for block in self.blocks:
            hidden = block(hidden, context=context)
        return self.final_layer(hidden)

sampling

hais

HAIS

Class which performs HAIS

Source code in normflows/sampling/hais.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class HAIS:
    """
    Class which performs HAIS (Hamiltonian annealed importance sampling)
    """

    def __init__(self, betas, prior, target, num_leapfrog, step_size, log_mass):
        r"""
        Args:
          betas: Annealing schedule, the jth target is ```f_j(x) = f_0(x)^{\beta_j} f_n(x)^{1-\beta_j}``` where the target is proportional to f_0 and the prior is proportional to f_n. The number of intermediate steps is inferred from the shape of betas. Should be of the form 1 = \beta_0 > \beta_1 > ... > \beta_n = 0
          prior: The prior distribution to start the HAIS chain.
          target: The target distribution from which we would like to draw weighted samples.
          num_leapfrog: Number of leapfrog steps in the HMC transitions.
          step_size: step_size to use for HMC transitions; must be a tensor, since its log is taken with ``torch.log``.
          log_mass: log_mass to use for HMC transitions.
        """
        # Raw docstring above: in a normal string "\b" in "\beta" is a backspace.
        self.prior = prior
        self.target = target
        self.layers = []
        n = betas.shape[0] - 1
        # One HMC transition per intermediate target beta_{n-1}, ..., beta_1;
        # the endpoints are handled by the prior sample and the final target
        # log-prob in `sample`.
        for i in range(n - 1, 0, -1):
            intermediate_target = distributions.LinearInterpolation(
                self.target, self.prior, betas[i]
            )
            self.layers += [
                flows.HamiltonianMonteCarlo(
                    intermediate_target, num_leapfrog, torch.log(step_size), log_mass
                )
            ]

    def sample(self, num_samples):
        """Run HAIS to draw samples from the target with appropriate weights.

        Args:
          num_samples: The number of samples to draw.

        Returns:
          Tuple of the samples and their log importance weights
        """
        # The second output of the prior is presumably its log-density; it is
        # subtracted from the log weights.
        samples, log_weights = self.prior.forward(num_samples)
        log_weights = -log_weights
        for layer in self.layers:
            samples, log_weights_addition = layer.forward(samples)
            log_weights += log_weights_addition
        log_weights += self.target.log_prob(samples)
        return samples, log_weights
__init__(betas, prior, target, num_leapfrog, step_size, log_mass)

Parameters:

Name Type Description Default
betas

Annealing schedule, the jth target is $f_j(x) = f_0(x)^{\beta_j} f_n(x)^{1-\beta_j}$ where the target is proportional to $f_0$ and the prior is proportional to $f_n$. The number of intermediate steps is inferred from the shape of betas. Should be of the form $1 = \beta_0 > \beta_1 > \dots > \beta_n = 0$

required
prior

The prior distribution to start the HAIS chain.

required
target

The target distribution from which we would like to draw weighted samples.

required
num_leapfrog

Number of leapfrog steps in the HMC transitions.

required
step_size

step_size to use for HMC transitions.

required
log_mass

log_mass to use for HMC transitions.

required
Source code in normflows/sampling/hais.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(self, betas, prior, target, num_leapfrog, step_size, log_mass):
    r"""
    Args:
      betas: Annealing schedule, the jth target is ```f_j(x) = f_0(x)^{\beta_j} f_n(x)^{1-\beta_j}``` where the target is proportional to f_0 and the prior is proportional to f_n. The number of intermediate steps is inferred from the shape of betas. Should be of the form 1 = \beta_0 > \beta_1 > ... > \beta_n = 0
      prior: The prior distribution to start the HAIS chain.
      target: The target distribution from which we would like to draw weighted samples.
      num_leapfrog: Number of leapfrog steps in the HMC transitions.
      step_size: step_size to use for HMC transitions; must be a tensor, since its log is taken with ``torch.log``.
      log_mass: log_mass to use for HMC transitions.
    """
    # Raw docstring above: in a normal string "\b" in "\beta" is a backspace.
    self.prior = prior
    self.target = target
    self.layers = []
    n = betas.shape[0] - 1
    # One HMC transition per intermediate target beta_{n-1}, ..., beta_1.
    for i in range(n - 1, 0, -1):
        intermediate_target = distributions.LinearInterpolation(
            self.target, self.prior, betas[i]
        )
        self.layers += [
            flows.HamiltonianMonteCarlo(
                intermediate_target, num_leapfrog, torch.log(step_size), log_mass
            )
        ]
sample(num_samples)

Run HAIS to draw samples from the target with appropriate weights.

Parameters:

Name Type Description Default
num_samples

The number of samples to draw.

required
Source code in normflows/sampling/hais.py
37
38
39
40
41
42
43
44
45
46
47
48
49
def sample(self, num_samples):
    """Run HAIS to draw samples from the target with appropriate weights.

    Args:
      num_samples: The number of samples to draw.

    Returns:
      Tuple of the samples and their log importance weights
    """
    # The second output of the prior is presumably its log-density; it is
    # subtracted from the log weights.
    samples, log_weights = self.prior.forward(num_samples)
    log_weights = -log_weights
    for layer in self.layers:
        samples, log_weights_addition = layer.forward(samples)
        log_weights += log_weights_addition
    log_weights += self.target.log_prob(samples)
    return samples, log_weights

transforms

Logit

Bases: Flow

Logit mapping of image tensor, see RealNVP paper

logit(alpha + (1 - 2 * alpha) * x) where logit(x) = log(x / (1 - x))
Source code in normflows/transforms.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Logit(flows.Flow):
    """Logit mapping of image tensor, see RealNVP paper

    ```
    logit(alpha + (1 - 2 * alpha) * x) where logit(x) = log(x / (1 - x))
    ```

    ``inverse`` applies the formula above to data ``x``; ``forward`` applies
    its exact inverse ``(sigmoid(z) - alpha) / (1 - 2 * alpha)``.
    """

    def __init__(self, alpha=0.05):
        """Constructor

        Args:
          alpha: Alpha parameter, see above
        """
        super().__init__()
        self.alpha = alpha

    def forward(self, z):
        beta = 1 - 2 * self.alpha
        sum_dims = list(range(1, z.dim()))
        # d sigmoid(z)/dz = sigmoid(z) * sigmoid(-z); dividing by beta adds -log(beta) per element.
        ls = torch.sum(torch.nn.functional.logsigmoid(z), dim=sum_dims)
        mls = torch.sum(torch.nn.functional.logsigmoid(-z), dim=sum_dims)
        log_det = -np.log(beta) * np.prod([*z.shape[1:]]) + ls + mls
        z = (torch.sigmoid(z) - self.alpha) / beta
        return z, log_det

    def inverse(self, z):
        beta = 1 - 2 * self.alpha
        z = self.alpha + beta * z
        logz = torch.log(z)
        log1mz = torch.log(1 - z)
        z = logz - log1mz
        sum_dims = list(range(1, z.dim()))
        # d logit(y)/dy = 1 / (y * (1 - y)), times beta from the affine squeeze.
        log_det = (
            np.log(beta) * np.prod([*z.shape[1:]])
            - torch.sum(logz, dim=sum_dims)
            - torch.sum(log1mz, dim=sum_dims)
        )
        return z, log_det

__init__(alpha=0.05)

Constructor

Parameters:

Name Type Description Default
alpha

Alpha parameter, see above

0.05
Source code in normflows/transforms.py
17
18
19
20
21
22
23
24
def __init__(self, alpha=0.05):
    """Create the logit transform.

    Args:
      alpha: Alpha parameter of ``logit(alpha + (1 - 2 * alpha) * x)``
    """
    super().__init__()
    self.alpha = alpha

Shift

Bases: Flow

Shift data by a fixed constant

Default is -0.5 to shift data from interval [0, 1] to [-0.5, 0.5]

Source code in normflows/transforms.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class Shift(flows.Flow):
    """Shift data by a fixed constant

    Default is -0.5 to shift data from
    interval [0, 1] to [-0.5, 0.5]

    ``inverse`` adds ``shift`` and ``forward`` subtracts it; both have a zero
    log-determinant. The input tensor is never modified in place.
    """

    def __init__(self, shift=-0.5):
        """Constructor

        Args:
          shift: Shift to apply to the data
        """
        super().__init__()
        self.shift = shift

    def forward(self, z):
        # Out-of-place on purpose: `z -= ...` would mutate the caller's tensor
        # and fails on leaf tensors that require grad.
        z = z - self.shift
        log_det = torch.zeros(z.shape[0], dtype=z.dtype,
                              device=z.device)
        return z, log_det

    def inverse(self, z):
        # Out-of-place on purpose, see forward.
        z = z + self.shift
        log_det = torch.zeros(z.shape[0], dtype=z.dtype,
                              device=z.device)
        return z, log_det

__init__(shift=-0.5)

Constructor

Parameters:

Name Type Description Default
shift

Shift to apply to the data

-0.5
Source code in normflows/transforms.py
57
58
59
60
61
62
63
64
def __init__(self, shift=-0.5):
    """Create the shift transform.

    Args:
      shift: Constant offset; ``inverse`` adds it and ``forward`` subtracts it
    """
    super().__init__()
    self.shift = shift

utils

eval

bitsPerDim(model, x, y=None, trans='logit', trans_param=[0.05])

Computes the bits per dim for a batch of data

Parameters:

Name Type Description Default
model

Model to compute bits per dim for

required
x

Batch of data

required
y

Class labels for batch of data if base distribution is class conditional

None
trans

Transformation to be applied to images during training

'logit'
trans_param

List of parameters of the transformation

[0.05]

Returns:

Type Description

Bits per dim for data batch under model

Source code in normflows/utils/eval.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def bitsPerDim(model, x, y=None, trans="logit", trans_param=(0.05,)):
    """Computes the bits per dim for a batch of data

    For ``trans="logit"``, ``x`` is taken to be in logit space, i.e. assumed to
    be ``logit(alpha + (1 - 2 * alpha) * u)`` of pixel data ``u`` in [0, 1]
    (as produced by ``transforms.Logit(alpha)``) with ``alpha = trans_param[0]``.
    The result is converted back to bits per dim of 8-bit pixel data.

    Args:
      model: Model to compute bits per dim for
      x: Batch of data
      y: Class labels for batch of data if base distribution is class conditional
      trans: Transformation to be applied to images during training
      trans_param: Sequence of parameters of the transformation

    Returns:
      Bits per dim for data batch under model

    Raises:
      NotImplementedError: if trans is not "logit"
    """
    dims = torch.prod(torch.tensor(x.size()[1:]))
    if trans != "logit":
        raise NotImplementedError(
            "The transformation " + trans + " is not implemented."
        )
    if y is None:
        log_q = model.log_prob(x)
    else:
        log_q = model.log_prob(x, y)
    alpha = trans_param[0]
    sum_dims = list(range(1, x.dim()))
    ls = torch.nn.LogSigmoid()
    # log2 of the sigmoid Jacobian, sigmoid(x) * sigmoid(-x), summed per sample.
    sig_ = torch.sum(ls(x) / np.log(2), sum_dims)
    sig_ += torch.sum(ls(-x) / np.log(2), sum_dims)
    # Undo the affine squeeze (scale 1 - 2 * alpha, matching transforms.Logit)
    # and add log2(256) = 8 bits for the 8-bit discretization.
    b = -log_q / dims / np.log(2) - np.log2(1 - 2 * alpha) + 8
    b += sig_ / dims
    return b

bitsPerDimDataset(model, data_loader, class_cond=True, trans='logit', trans_param=[0.05])

Computes average bits per dim for an entire dataset given by a data loader

Parameters:

Name Type Description Default
model

Model to compute bits per dim for

required
data_loader

Data loader of dataset

required
class_cond

Flag indicating whether model is class_conditional

True
trans

Transformation to be applied to images during training

'logit'
trans_param

List of parameters of the transformation

[0.05]

Returns:

Type Description

Average bits per dim for dataset

Source code in normflows/utils/eval.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def bitsPerDimDataset(
    model, data_loader, class_cond=True, trans="logit", trans_param=(0.05,)
):
    """Computes average bits per dim for an entire dataset given by a data loader

    Samples whose bits per dim evaluate to NaN are excluded from both the sum
    and the count.

    Args:
      model: Model to compute bits per dim for
      data_loader: Data loader of dataset, yielding ``(x, y)`` pairs
      class_cond: Flag indicating whether model is class_conditional
      trans: Transformation to be applied to images during training
      trans_param: Sequence of parameters of the transformation

    Returns:
      Average bits per dim for dataset
    """
    n = 0
    b_cum = 0
    with torch.no_grad():
        for x, y in data_loader:
            b_ = bitsPerDim(
                model, x, y.to(x.device) if class_cond else None, trans, trans_param
            )
            b_np = b_.to("cpu").numpy()
            b_cum += np.nansum(b_np)
            n += len(x) - np.sum(np.isnan(b_np))
        b = b_cum / n
    return b

masks

create_alternating_binary_mask(features, even=True)

Creates a binary mask of a given dimension which alternates its masking.

Parameters:

Name Type Description Default
features

Dimension of mask.

required
even

If True, even values are assigned 1s, odd 0s. If False, vice versa.

True

Returns:

Type Description

Alternating binary mask of type torch.Tensor.

Source code in normflows/utils/masks.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def create_alternating_binary_mask(features, even=True):
    """Build a uint8 mask of length ``features`` whose entries alternate 1/0.

    Args:
      features: Dimension of mask.
      even: If True, even positions are 1 and odd positions 0; otherwise reversed.

    Returns:
      Alternating binary mask of type torch.Tensor.
    """
    first_one = 0 if even else 1
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[first_one::2] = 1
    return mask

create_mid_split_binary_mask(features)

Creates a binary mask of a given dimension which splits its masking at the midpoint.

Parameters:

Name Type Description Default
features

Dimension of mask.

required

Returns:

Type Description

Binary mask split at midpoint of type torch.Tensor

Source code in normflows/utils/masks.py
20
21
22
23
24
25
26
27
28
29
30
31
32
def create_mid_split_binary_mask(features):
    """Build a uint8 mask whose first half (rounded up) is 1 and the rest 0.

    Args:
      features: Dimension of mask.

    Returns:
      Binary mask split at midpoint of type torch.Tensor
    """
    ones_count = (features + 1) // 2
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[:ones_count] = 1
    return mask

create_random_binary_mask(features, seed=None)

Creates a random binary mask of a given dimension with half of its entries randomly set to 1s.

Parameters:

Name Type Description Default
features

Dimension of mask.

required
seed

Seed to be used

None

Returns:

Type Description

Binary mask with half of its entries set to 1s, of type torch.Tensor.

Source code in normflows/utils/masks.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def create_random_binary_mask(features, seed=None):
    """Build a uint8 mask with ceil(features / 2) randomly chosen entries set to 1.

    Args:
      features: Dimension of mask.
      seed: Seed for a private torch.Generator; if None, the global RNG is used

    Returns:
      Binary mask with half of its entries set to 1s, of type torch.Tensor.
    """
    ones_count = (features + 1) // 2
    rng = None
    if seed is not None:
        rng = torch.Generator()
        rng.manual_seed(seed)
    chosen = torch.multinomial(
        input=torch.ones(features, dtype=torch.float),
        num_samples=ones_count,
        replacement=False,
        generator=rng,
    )
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[chosen] = 1
    return mask

nn

ActNorm

Bases: Module

ActNorm layer with just one forward pass

Source code in normflows/utils/nn.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class ActNorm(nn.Module):
    """
    ActNorm layer with just one forward pass

    Wraps ``flows.ActNorm`` for use as a plain network layer: ``forward``
    returns only the transformed tensor and discards the second output of the
    flow (presumably its log-determinant).
    """
    def __init__(self, shape):
        """Constructor

        Args:
          shape: Same as shape in flows.ActNorm
        """
        super().__init__()
        self.actNorm = flows.ActNorm(shape)

    def forward(self, input):
        out, _ = self.actNorm(input)
        return out
__init__(shape)

Constructor

Parameters:

Name Type Description Default
shape

Same as shape in flows.ActNorm

required
logscale_factor

Not accepted by this constructor (listed in error; `__init__` only takes `shape`)

required
Source code in normflows/utils/nn.py
30
31
32
33
34
35
36
37
38
39
def __init__(self, shape):
    """Constructor

    Args:
      shape: Same as shape in flows.ActNorm
    """
    super().__init__()
    self.actNorm = flows.ActNorm(shape)

ClampExp

Bases: Module

Nonlinearity min(exp(x), 1)

Source code in normflows/utils/nn.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class ClampExp(nn.Module):
    """
    Nonlinearity min(exp(x), 1)
    """

    def __init__(self):
        """Constructor

        Takes no arguments; the nonlinearity has no parameters.
        """
        super().__init__()

    def forward(self, x):
        # Clamp value built on x's device/dtype so torch.min needs no casting.
        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.min(torch.exp(x), one)
__init__()

Constructor

Parameters:

Name Type Description Default
lam

Lambda parameter (listed in error; `__init__` takes no arguments and the nonlinearity has no parameters)

required
Source code in normflows/utils/nn.py
51
52
53
54
55
56
57
def __init__(self):
    """Constructor

    Takes no arguments; the nonlinearity has no parameters.
    """
    super().__init__()

ConstScaleLayer

Bases: Module

Scaling features by a fixed factor

Source code in normflows/utils/nn.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class ConstScaleLayer(nn.Module):
    """
    Multiply features by a constant, non-trainable factor
    """

    def __init__(self, scale=1.0):
        """Constructor

        Args:
          scale: Factor the features are multiplied with; stored as the buffer
            ``scale`` so it moves with the module across devices
        """
        super().__init__()
        factor = torch.tensor(scale)
        # Plain attribute aliasing the initial buffer tensor (not moved by .to()).
        self.scale_cpu = factor
        self.register_buffer("scale", factor)

    def forward(self, input):
        scaled = input * self.scale
        return scaled
__init__(scale=1.0)

Constructor

Parameters:

Name Type Description Default
scale

Scale to apply to features

1.0
Source code in normflows/utils/nn.py
12
13
14
15
16
17
18
19
20
def __init__(self, scale=1.0):
    """Constructor

    Args:
      scale: Factor the features are multiplied with; stored as the buffer
        ``scale`` so it moves with the module across devices
    """
    super().__init__()
    factor = torch.tensor(scale)
    # Plain attribute aliasing the initial buffer tensor (not moved by .to()).
    self.scale_cpu = factor
    self.register_buffer("scale", factor)

PeriodicFeaturesCat

Bases: Module

Converts a specified part of the input to periodic features by replacing those features f with [sin(scale * f), cos(scale * f)].

Note that this increases the number of features (each selected feature becomes two) and their order is changed.

Source code in normflows/utils/nn.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class PeriodicFeaturesCat(nn.Module):
    """
    Converts a specified part of the input to periodic features by
    replacing those features f with [sin(scale * f), cos(scale * f)].

    Note that this increases the number of features (each selected feature
    becomes two) and their order is changed: the output is
    ``[sin(scale * f[ind]), cos(scale * f[ind]), remaining features]``.
    """

    def __init__(self, ndim, ind, scale=1.0):
        """
        Constructor
        :param ndim: Int, number of dimensions
        :param ind: Iterable or tensor, indices of input elements to convert to
        periodic features
        :param scale: Scalar or iterable, used to scale inputs before
        converting them to periodic features
        """
        super(PeriodicFeaturesCat, self).__init__()

        # Set up indices and permutations
        self.ndim = ndim
        if torch.is_tensor(ind):
            # Public equivalent of the private torch._cast_Long.
            self.register_buffer("ind", ind.to(torch.long))
        else:
            self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))

        # Indices of the features that are passed through unchanged.
        ind_ = [i for i in range(self.ndim) if i not in self.ind]
        self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))

        if torch.is_tensor(scale):
            self.register_buffer("scale", scale)
        else:
            self.scale = scale

    def forward(self, inputs):
        inputs_ = inputs[..., self.ind]
        inputs_ = self.scale * inputs_
        inputs_sin = torch.sin(inputs_)
        inputs_cos = torch.cos(inputs_)
        out = torch.cat((inputs_sin, inputs_cos,
                         inputs[..., self.ind_]), -1)
        return out
__init__(ndim, ind, scale=1.0)

Constructor :param ndim: Int, number of dimensions :param ind: Iterable, indices of input elements to convert to periodic features :param scale: Scalar or iterable, used to scale inputs before converting them to periodic features

Source code in normflows/utils/nn.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def __init__(self, ndim, ind, scale=1.0):
    """Constructor

    Args:
      ndim (int): number of dimensions
      ind (iterable): indices of input elements to convert to periodic features
      scale: Scalar or iterable, used to scale inputs before converting them to periodic features
    """
    super(PeriodicFeaturesCat, self).__init__()

    # Set up indices and permutations
    self.ndim = ndim
    if torch.is_tensor(ind):
        # Public replacement for the private ``torch._cast_Long``;
        # returns ``ind`` itself if it is already a long tensor.
        self.register_buffer("ind", ind.to(torch.long))
    else:
        self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))

    # Indices of the features that are passed through unchanged
    ind_ = [i for i in range(self.ndim) if i not in self.ind]
    self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))

    if torch.is_tensor(scale):
        self.register_buffer("scale", scale)
    else:
        self.scale = scale

PeriodicFeaturesElementwise

Bases: Module

Converts a specified part of the input to periodic features by replacing those features f with w1 * sin(scale * f) + w2 * cos(scale * f).

Note that this operation is done elementwise and, therefore, some information about the feature can be lost.

Source code in normflows/utils/nn.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class PeriodicFeaturesElementwise(nn.Module):
    """
    Converts a specified part of the input to periodic features by
    replacing those features f with
    w1 * sin(scale * f) + w2 * cos(scale * f).

    Note that this operation is done elementwise and, therefore,
    some information about the feature can be lost. The number and the
    order of the features are preserved.
    """

    def __init__(self, ndim, ind, scale=1.0, bias=False, activation=None):
        """Constructor

        Args:
          ndim (int): number of dimensions
          ind (iterable): indices of input elements to convert to periodic features
          scale: Scalar or iterable, used to scale inputs before converting them to periodic features
          bias: Flag, whether to add a bias
          activation: Function or None, activation function to be applied
        """
        super(PeriodicFeaturesElementwise, self).__init__()

        # Set up indices and permutations
        self.ndim = ndim
        if torch.is_tensor(ind):
            # Public replacement for the private ``torch._cast_Long``;
            # returns ``ind`` itself if it is already a long tensor.
            self.register_buffer("ind", ind.to(torch.long))
        else:
            self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))

        # Indices of the features that are passed through unchanged
        ind_ = [i for i in range(self.ndim) if i not in self.ind]
        self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))

        # forward() concatenates [converted features, untouched features],
        # i.e. the features in the order given by perm_. For a permutation,
        # argsort yields its inverse, which restores the original order.
        perm_ = torch.cat((self.ind, self.ind_))
        self.register_buffer("inv_perm", torch.argsort(perm_))

        # One (w1, w2) pair per converted feature, initialised to 1
        self.weights = nn.Parameter(torch.ones(len(self.ind), 2))
        if torch.is_tensor(scale):
            self.register_buffer("scale", scale)
        else:
            self.scale = scale

        self.apply_bias = bias
        if self.apply_bias:
            self.bias = nn.Parameter(torch.zeros(len(self.ind)))

        if activation is None:
            self.activation = torch.nn.Identity()
        else:
            self.activation = activation

    def forward(self, inputs):
        """Apply the weighted periodic map to the selected features.

        Args:
          inputs: Tensor whose last dimension indexes the features

        Returns:
          Tensor of the same shape as ``inputs`` in which the selected
          features are replaced by
          ``activation(w1 * sin(scale * f) + w2 * cos(scale * f) [+ bias])``
        """
        inputs_ = self.scale * inputs[..., self.ind]
        inputs_ = self.weights[:, 0] * torch.sin(inputs_) + self.weights[
            :, 1
        ] * torch.cos(inputs_)
        if self.apply_bias:
            inputs_ = inputs_ + self.bias
        inputs_ = self.activation(inputs_)
        out = torch.cat((inputs_, inputs[..., self.ind_]), -1)
        # Undo the [converted, untouched] reordering
        return out[..., self.inv_perm]
__init__(ndim, ind, scale=1.0, bias=False, activation=None)

Constructor

Parameters:

Name Type Description Default
ndim int

number of dimensions

required
ind iterable

indices of input elements to convert to periodic features

required
scale

Scalar or iterable, used to scale inputs before converting them to periodic features

1.0
bias

Flag, whether to add a bias

False
activation

Function or None, activation function to be applied

None
Source code in normflows/utils/nn.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def __init__(self, ndim, ind, scale=1.0, bias=False, activation=None):
    """Constructor

    Args:
      ndim (int): number of dimensions
      ind (iterable): indices of input elements to convert to periodic features
      scale: Scalar or iterable, used to scale inputs before converting them to periodic features
      bias: Flag, whether to add a bias
      activation: Function or None, activation function to be applied
    """
    super(PeriodicFeaturesElementwise, self).__init__()

    # Set up indices and permutations
    self.ndim = ndim
    if torch.is_tensor(ind):
        # Public replacement for the private ``torch._cast_Long``;
        # returns ``ind`` itself if it is already a long tensor.
        self.register_buffer("ind", ind.to(torch.long))
    else:
        self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))

    # Indices of the features that are passed through unchanged
    ind_ = [i for i in range(self.ndim) if i not in self.ind]
    self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))

    # perm_ is the feature order [converted, untouched]; for a permutation,
    # argsort yields its inverse, which restores the original order.
    perm_ = torch.cat((self.ind, self.ind_))
    self.register_buffer("inv_perm", torch.argsort(perm_))

    # One (w1, w2) pair per converted feature, initialised to 1
    self.weights = nn.Parameter(torch.ones(len(self.ind), 2))
    if torch.is_tensor(scale):
        self.register_buffer("scale", scale)
    else:
        self.scale = scale

    self.apply_bias = bias
    if self.apply_bias:
        self.bias = nn.Parameter(torch.zeros(len(self.ind)))

    if activation is None:
        self.activation = torch.nn.Identity()
    else:
        self.activation = activation

sum_except_batch(x, num_batch_dims=1)

Sums all elements of x except for the first num_batch_dims dimensions.

Source code in normflows/utils/nn.py
190
191
192
193
def sum_except_batch(x, num_batch_dims=1):
    """Sums all elements of `x` except for the first `num_batch_dims` dimensions.

    Args:
      x: Tensor to reduce
      num_batch_dims: Number of leading dimensions to keep

    Returns:
      Tensor of shape ``x.shape[:num_batch_dims]``. If there are no
      dimensions left to sum over, a copy of ``x`` is returned.
    """
    reduce_dims = list(range(num_batch_dims, x.ndimension()))
    if not reduce_dims:
        # torch.sum with an empty ``dim`` list reduces over *all* dimensions,
        # which would collapse the batch dimensions. Return a copy so that
        # in-place updates of the result never alias the input.
        return x.clone()
    return torch.sum(x, dim=reduce_dims)

optim

clear_grad(model)

Set gradients of model parameters to None, as this speeds up training.

See youtube

Parameters:

Name Type Description Default
model

Model to clear gradients of

required
Source code in normflows/utils/optim.py
16
17
18
19
20
21
22
23
24
25
def clear_grad(model):
    """Drop the gradient of every parameter of ``model`` by setting it to None.

    Resetting gradients to None (instead of filling them with zeros) is
    meant to speed up training, see
    [youtube](https://www.youtube.com/watch?v=9mS1fIYj1So)

    Args:
      model: Model whose parameter gradients are cleared
    """
    params = model.parameters()
    for p in params:
        p.grad = None

set_requires_grad(module, flag)

Sets the requires_grad flag of all parameters of a torch.nn.Module

Parameters:

Name Type Description Default
module

torch.nn.Module

required
flag

Flag to set requires_grad to

required
Source code in normflows/utils/optim.py
 4
 5
 6
 7
 8
 9
10
11
12
13
def set_requires_grad(module, flag):
    """Assign ``flag`` to ``requires_grad`` of every parameter of ``module``.

    Args:
      module: torch.nn.Module whose parameters are updated
      flag: Value written to each parameter's ``requires_grad`` attribute
    """
    params = list(module.parameters())
    for p in params:
        p.requires_grad = flag

preprocessing

Jitter

Transform for dataloader, adds uniform jitter noise to data

Source code in normflows/utils/preprocessing.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Jitter:
    """Dataloader transform that adds uniform noise in ``[0, scale)`` to the data
    (for positive ``scale``)."""

    def __init__(self, scale=1.0 / 256):
        """Constructor

        Args:
          scale: Width of the interval the uniform noise is drawn from
        """
        self.scale = scale

    def __call__(self, x):
        # torch.rand_like draws from U[0, 1); stretch it to the noise width
        noise = self.scale * torch.rand_like(x)
        return x + noise
__init__(scale=1.0 / 256)

Constructor

Parameters:

Name Type Description Default
scale

Scaling factor for noise

1.0 / 256
Source code in normflows/utils/preprocessing.py
31
32
33
34
35
36
37
def __init__(self, scale=1.0 / 256):
    """Store the scale of the uniform jitter noise.

    Args:
      scale: Factor the U[0, 1) noise is multiplied by
    """
    self.scale = scale

Logit

Transform for dataloader

logit(alpha + (1 - alpha) * x) where logit(x) = log(x / (1 - x))
Source code in normflows/utils/preprocessing.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Logit:
    """Transform for dataloader

    ```
    logit(alpha + (1 - alpha) * x) where logit(x) = log(x / (1 - x))
    ```
    """

    def __init__(self, alpha=0):
        """Constructor

        Args:
          alpha: Offset in ``logit(alpha + (1 - alpha) * x)``
        """
        self.alpha = alpha

    def __call__(self, x):
        a = self.alpha
        # Squeeze x affinely before taking the logit
        shifted = a + (1 - a) * x
        return torch.log(shifted / (1 - shifted))

    def inverse(self, x):
        """Undo ``__call__``: sigmoid, then invert the affine squeeze."""
        a = self.alpha
        return (torch.sigmoid(x) - a) / (1 - a)
__init__(alpha=0)

Constructor

Parameters:

Name Type Description Default
alpha

Offset alpha in logit(alpha + (1 - alpha) * x), see above

0
Source code in normflows/utils/preprocessing.py
12
13
14
15
16
17
18
def __init__(self, alpha=0):
    """Store the offset of the logit transform.

    Args:
      alpha: Offset in ``logit(alpha + (1 - alpha) * x)``
    """
    self.alpha = alpha

Scale

Transform for dataloader, multiplies the data by a constant factor

Source code in normflows/utils/preprocessing.py
45
46
47
48
49
50
51
52
53
54
55
56
57
class Scale:
    """Transform for dataloader, multiplies the data by a constant factor"""

    def __init__(self, scale=255.0 / 256.0):
        """Constructor

        Args:
          scale: Factor the data is multiplied by
        """
        self.scale = scale

    def __call__(self, x):
        return x * self.scale
__init__(scale=255.0 / 256.0)

Constructor

Parameters:

Name Type Description Default
scale

Factor the data is multiplied by

255.0 / 256.0
Source code in normflows/utils/preprocessing.py
48
49
50
51
52
53
54
def __init__(self, scale=255.0 / 256.0):
    """Constructor

    Args:
      scale: Factor the data is multiplied by
    """
    self.scale = scale