How-to guide

In this project we put every functional change into YAML files. Every aspect of tuning, dataset selection, and trainer configuration can be changed simply by editing parameters in the YAML file. Our pipeline uses Hydra to load and manage the configuration from these YAML structures. First we show the structure of the YAML file, then we present the three sample YAML files used in this project. At the end we also list the major modules used in the pipeline.

If you plan to run the train.py script, your YAML file needs to be in the conf folder under the name config.yaml, since Hydra loads this file for the script. If you want to make changes, just apply them to that file; to keep the previous version around for now, simply rename it 🥸.

your_project/
│
├── conf/
│   ├── config.yaml
│   └── legacy.yaml
│
├── modules/
│   ├── __init__.py
│   ├── config.py
│   ├── losses.py
│   ├── models.py
│   ├── training.py
│   ├── utils.py
│   └── vqgan.py
│
└── train.py

YAML Structures

The main YAML config file should follow the base structure LoadConfig defined in modules/config.py.

LoadConfig

Load configuration class to store the configuration of a train model and data. Main configuration class to be used for training.

Parameters:

    train (TrainConfig): training configuration. Required.
    data (DataConfig): data configuration. Required.
Source code in modules/config.py
@dataclass
class LoadConfig:
    """Load configuration class to store the configuration of a train model and data.
    Main configuration class to be used for training.

    Arguments:
        train (TrainConfig): training configuration.
        data (DataConfig): data configuration.
    """

    train: TrainConfig
    data: DataConfig

    def __post_init__(self):
        self.train = (
            TrainConfig(**self.train) if self.train is not None else TrainConfig()  # type: ignore
        )
        self.data = (
            DataConfig(**self.data) if self.data is not None else DataConfig()  # type: ignore
        )

        # set resolution
        self.train.disc_hparams.resolution = self.data.size
        self.train.model_hparams.resolution = self.data.size

The config structure has two parameters: data, which specifies the dataset and dataloader parameters, and train, which specifies the architecture and training parameters.
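
For orientation, here is a minimal sketch of what a conf/config.yaml following LoadConfig could look like. The values are illustrative placeholders, not the project defaults; copy working settings from the sample configs instead.

# Illustrative skeleton of conf/config.yaml (placeholder values).
train:
  model_name: vqgan_example      # hypothetical name, used for saving and logging
  save_dir: checkpoints/
  log_dir: logs/
  # ... remaining TrainConfig fields (see TrainConfig below)
data:
  dataset_name: voc
  dataset_root: data/
  size: 224
  # ... remaining DataConfig fields (see DataConfig below)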

DataConfig

Data configuration class.

Parameters:

    train_params (DataParams): training data parameters. Check DataParams for more details. Required.
    test_params (DataParams): testing data parameters. Check DataParams for more details. Required.
    dataset_name (str): name of the dataset to use. Currently only supports "voc". Default: ''.
    dataset_root (str): root directory of the dataset. Default: ''.
    transform (optional, dict): transform to apply to the dataset. Default None for no transform. The transform dict comes from the albumentations library; check albumentations for more details. Loading the transform should proceed with the albumentations.from_dict method.
    size (int): size of image width and height. Default: 224.
Source code in modules/config.py
@dataclass
class DataConfig:
    """Data configuration class.
    Arguments:
        train_params (DataParams): training data parameters. Check DataParams for more details.
        test_params (DataParams): testing data parameters. Check DataParams for more details.
        dataset_name (str): name of the dataset to use. Currently only supports "voc".
        dataset_root (str): root directory of the dataset.
        transform (optional, dict): transform to apply to the dataset. Default None for no transf.
            Transform dict comes from albumentations library. Check albumentations for more details.
            Loading transform should proceed with albumentations.from_dict method.
        size (int): size of image width and height.

    """

    train_params: DataParams
    test_params: DataParams
    dataset_name: str = ""
    dataset_root: str = ""
    transform: Optional[Dict[str, Any]] = None
    size: int = 224

    def __post_init__(self):
        # set train_params and test_params
        self.train_params = DataParams(**self.train_params)  # type: ignore
        self.test_params = DataParams(**self.test_params)  # type: ignore

DataConfig tells the pipeline everything it needs to know about downloading TensorFlow Datasets and processing them. Augmentation information for the pipeline is based on the albumentations framework; please refer to it for additional changes. This config relies on train_params and test_params, which set the shuffling and batch size for the train and test splits.
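
As a concrete but hypothetical sketch of the data section, assuming the fields above: the dataset paths and batch sizes are placeholders, and a real transform would be a dict serialized by albumentations (loadable with albumentations.from_dict).

# Illustrative data section matching DataConfig (placeholder values).
data:
  dataset_name: voc            # currently only "voc" is supported
  dataset_root: data/voc/
  size: 224
  transform: null              # or a serialized albumentations transform dict
  train_params:
    batch_size: 8
    shuffle: true
  test_params:
    batch_size: 8
    shuffle: false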

DataParams

Train and test data parameters.

Parameters:

    batch_size (int): batch size for training. Required.
    shuffle (bool): whether to shuffle the dataset. Required.
Source code in modules/config.py
@dataclass
class DataParams:
    """Train and test data parameters.
    Arguments:
        batch_size (int): batch size for training.
        shuffle (bool): whether to shuffle the dataset.
    """

    batch_size: int
    shuffle: bool

TrainConfig

Configuration class to store the configuration of a train model.

Parameters:

    model_name (str): name of the model to train. Used for saving and logging. Required.
    model_hparams (VQGANConfig): model hyperparameters. Check VQGANConfig for more details. Required.
    disc_hparams (DiscConfig): discriminator hyperparameters. Check DiscConfig for more details. Required.
    save_dir (str): directory to save the model. Required.
    log_dir (str): directory to save the logs for tensorboard. Required.
    check_val_every_n_epoch (int): number of epochs to run validation. Required.
    log_img_every_n_epoch (int): number of epochs to log images. Required.
    input_shape (Tuple[int, int, int]): shape of the input image (H, W, C). Required.
    codebook_weight (float): weight for the codebook loss (Quantizer part). Required.
    monitor (str): metric to monitor for saving the best model. Required.
    recon_loss (str): reconstruction loss to use. Can be one of l1, l2, combo, mape. Required.
    disc_loss (str): discriminator loss to use. Can be vanilla or hinge. Required.
    disc_weight (float): weight for the discriminator loss. Required.
    num_epochs (int): number of epochs to train. Required.
    dtype (str): dtype to use for training. Supported: float64, float32, float16, bfloat16. Required.
    distributed (bool): whether to use distributed training. Required.
    seed (int): seed for random number generation. Required.
    optimizer (str): optimizer to use for training. Structure needs to be an optax optimizer (check optax for more details) with a '__target__' parameter specifying the optax optimizer and a 'kwargs' parameter passed to the optimizer. Check config_test.yaml for an example. Required.
    optimizer_disc (str): optimizer to use for discriminator training. Similar to optimizer. Required.
    disc_start (int): number of epochs to pass before starting to use the discriminator. Required.
    temp_scheduler (optional): temperature scheduler to use for training. Similar to optimizer but uses an optax scheduler with a '__target__' parameter specifying the optax scheduler. If None, no scheduler is used. Check config_test.yaml for an example. Required.
Source code in modules/config.py
@dataclass
class TrainConfig:
    """Configuration class to store the configuration of a train model.

    Arguments:
        model_name (str): name of the model to train. Used for saving and logging.
        model_hparams (VQGANConfig): model hyperparameters. Check VQGANConfig for more details.
        disc_hparams (DiscConfig): discriminator hyperparameters.
            Check DiscConfig for more details.
        save_dir (str): directory to save the model.
        log_dir (str): directory to save the logs for tensorboard.
        check_val_every_n_epoch (int): number of epochs to run validation.
        log_img_every_n_epoch (int): number of epochs to log images.
        input_shape (Tuple[int, int, int]): shape of the input image (H, W, C).
        codebook_weight (float): weight for the codebook loss (Quantizer part).
        monitor (str): metric to monitor for saving best model.
        recon_loss (str): reconstruction loss to use. Can be one of `l1`, `l2`, `combo`, `mape`.
        disc_loss (str): discriminator loss to use. Can be `vanilla` or `hinge`.
        disc_weight (float): weight for the discriminator loss.
        num_epochs (int): number of epochs to train.
        dtype (str): dtype to use for training.
            Supported: `float64`, `float32`, `float16`, `bfloat16`.
        distributed (bool): whether to use distributed training.
        seed (int): seed for random number generation.
        optimizer (str): optimizer to use for training. Structure needs to be
            optax Optimizer (Check optax for more details) with '__target__' parameter,
            for specifying optax optimizer, and 'kwargs' parameter for passing to optimizer.
            Check config_test.yaml for example.
        optimizer_disc (str): optimizer to use for discriminator training. Similar to optimizer.
        disc_start (int): number of epochs to pass before starting to use the discriminator.
        temp_scheduler (optional): temperature scheduler to use for training. Similar to optimizer
            but uses optax scheduler with '__target__' parameter, for specifying optax scheduler.
            If None, then no scheduler is used. Check config_test.yaml for example.
    """

    model_name: str
    model_hparams: VQGANConfig
    disc_hparams: DiscConfig
    save_dir: str
    log_dir: str
    check_val_every_n_epoch: int
    log_img_every_n_epoch: int
    input_shape: Tuple[int, ...]
    codebook_weight: float
    monitor: str
    recon_loss: str
    disc_loss: str
    disc_weight: float
    num_epochs: int
    dtype: jnp.dtype
    distributed: bool
    seed: int
    optimizer: optax.GradientTransformation
    optimizer_disc: optax.GradientTransformation
    disc_start: int
    temp_scheduler: Optional[Callable]

    def __post_init__(self):
        # load model hparams
        self.model_hparams = (
            VQGANConfig(**self.model_hparams) if self.model_hparams is not None else VQGANConfig()
        )
        if not isinstance(self.model_hparams, VQGANConfig):
            raise TypeError("model_hparams could not create VQGANConfig")
        # load disc hparams
        self.disc_hparams = (
            DiscConfig(**self.disc_hparams) if self.disc_hparams is not None else DiscConfig()
        )
        if not isinstance(self.disc_hparams, DiscConfig):
            raise TypeError("disc_hparams could not create DiscConfig")
        # convert shape list to tuple shape
        self.input_shape = tuple(self.input_shape)
        if len(self.input_shape) != 3:
            raise ValueError(f"input_shape: {self.input_shape} should be of length 3")
        # set dtype
        if self.dtype == "float64":
            self.dtype = jnp.float64
        elif self.dtype == "float32":
            self.dtype = jnp.float32
        elif self.dtype == "float16":
            self.dtype = jnp.float16
        elif self.dtype == "bfloat16":
            self.dtype = jnp.bfloat16
        else:
            raise ValueError(
                f"""Invalid dtype {self.dtype}
                             expected one of float64, float32, float16, bfloat16"""
            )
        # instantiate the optimizer
        self.optimizer = instantiate(self.optimizer)
        if not isinstance(self.optimizer, optax.GradientTransformation):
            raise TypeError("optimizer should be optax GradientTransformation dict to instantiate")
        # instantiate the optimizer for discriminator
        self.optimizer_disc = instantiate(self.optimizer_disc)
        if not isinstance(self.optimizer_disc, optax.GradientTransformation):
            raise TypeError(
                "optimizer_disc should be optax GradientTransformation dict to instantiate"
            )
        # if optimizer is a dict, instantiate it
        if self.temp_scheduler is not None:
            self.temp_scheduler: Callable = instantiate(self.temp_scheduler)
            if not hasattr(self.temp_scheduler, "__call__"):
                raise TypeError("temp_scheduler should be a callable or None")

TrainConfig is the main config for setting up the trainer. The most important parameters are model_name, save_dir, log_dir, dtype, seed, and distributed. dtype, seed, and distributed are also used by the datasets (for now we support only false for distributed). save_dir and log_dir are the paths for model checkpointing and TensorBoard logging. model_name is the name of the model referenced when saving and logging to TensorBoard, so you need to keep an eye on this parameter. optimizer and temp_scheduler are parameters instantiated by Hydra, and for these we use objects from optax (please refer to the samples). model_hparams contains all the parameters for the VQGAN module architecture, and disc_hparams contains the parameters for the Discriminator.
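
Below is a hedged sketch of the Hydra-instantiated pieces of the train section. The docstring refers to a '__target__' key, while Hydra's instantiate conventionally uses _target_, so mirror whatever key and nesting the sample configs actually use; the optax optimizer and scheduler choices and their values here are only illustrative.

# Illustrative train section fragment (placeholder names and values).
train:
  model_name: vqgan_example
  dtype: float32
  seed: 42
  distributed: false
  disc_start: 1
  optimizer:
    __target__: optax.adamw            # check the sample configs for the exact key spelling
    kwargs:
      learning_rate: 4.5e-6
      b1: 0.9
      b2: 0.95
  optimizer_disc:
    __target__: optax.adamw
    kwargs:
      learning_rate: 4.5e-6
  temp_scheduler:
    __target__: optax.linear_schedule  # anneals the Gumbel temperature over training
    kwargs:
      init_value: 1.0
      end_value: 0.1
      transition_steps: 100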

DiscConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a Discriminator model. Dataclass for storing is based on PretrainedConfig from transformers package.

Parameters:

    input_last_dim (int): last dimension of the input sample in Discriminator. Default: 3.
    output_last_dim (int): last dimension of the output sample in Discriminator. Default: 1.
    resolution (int): resolution of the input image (256x256). Default: 256.
    ndf (int): number of filters in the first layer of Discriminator. Default: 64.
    n_layers (int): number of layers in Discriminator. Default: 3.
Source code in modules/config.py
class DiscConfig(PretrainedConfig):
    """Configuration class to store the configuration of a Discriminator model.
    Dataclass for storing is based on `PretrainedConfig` from `transformers` package.

    Args:
        input_last_dim (int): last dimension of the input sample in Discriminator.
        output_last_dim (int): last dimension of the output sample in Discriminator.
        resolution (int): resolution of the input image (256x256).
        ndf (int): number of filters in the first layer of Discriminator.
        n_layers (int): number of layers in Discriminator.
    """

    def __init__(
        self,
        input_last_dim: int = 3,
        output_last_dim: int = 1,
        resolution: int = 256,
        ndf: int = 64,
        n_layers: int = 3,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.input_last_dim = input_last_dim
        self.output_last_dim = output_last_dim
        self.resolution = resolution
        self.ndf = ndf
        self.n_layers = n_layers

VQGANConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a VQGAN model. Dataclass for storing is based on PretrainedConfig from transformers package.

Parameters:

    ch (int): number of channels. Default: 128.
    out_ch (int): number of output channels (RGB). Default: 3.
    in_channels (int): number of input channels (RGB). Default: 3.
    num_res_blocks (int): number of residual blocks. Default: 2.
    resolution (int): resolution of the input image (256x256). Default: 256.
    z_channels (int): number of channels in the latent space. Default: 256.
    ch_mult (Tuple[int]): channel multiplier for each layer. Default: (1, 1, 2, 2, 4).
    attn_resolutions (Tuple[int]): resolutions at which to apply attention. Default: (16,).
    n_embed (int): number of embeddings, unique codes in the latent space. Default: 1024.
    embed_dim (int): dimension of embedding from Encoder. Default: 256.
    dropout (float): dropout rate. Default: 0.0.
    double_z (bool): whether to double the latent space. Default: False.
    resamp_with_conv (bool): whether to use convolutions for upsampling. Default: True.
    use_gumbel (bool): whether to use gumbel softmax for quantization. Default: False.
    gumb_temp (float): temperature for gumbel softmax. Default: 1.0.
    act_name (str): activation function name to use. Default: 'swish'.
    give_pre_end (bool): whether to give the pre-end layer for the decoder. Default: False.
    kwargs (Any): keyword arguments passed along to the super class. Default: {}.
Source code in modules/config.py
class VQGANConfig(PretrainedConfig):
    """Configuration class to store the configuration of a VQGAN model.
    Dataclass for storing is based on `PretrainedConfig` from `transformers` package.

    Args:
        ch (int): number of channels.
        out_ch (int): number of output channels (RGB).
        in_channels (int): number of input channels (RGB).
        num_res_blocks (int): number of residual blocks.
        resolution (int): resolution of the input image (256x256).
        z_channels (int): number of channels in the latent space.
        ch_mult (Tuple[int]): channel multiplier for each layer.
        attn_resolutions (Tuple[int]): resolutions at which to apply attention.
        n_embed (int): number of embeddings, unique codes in the latent space.
        embed_dim (int): dimension of embedding from Encoder.
        dropout (float): dropout rate.
        double_z (bool): whether to double the latent space.
        resamp_with_conv (bool): whether to use convolutions for upsampling.
        use_gumbel (bool): whether to use gumbel softmax for quantization.
        gumb_temp (float): temperature for gumbel softmax.
        act_name (str): activation function name to use.
        give_pre_end (bool): whether to give the pre-end layer for the decoder.
        kwargs: keyword arguments passed along to the super class.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = tuple([1, 1, 2, 2, 4]),
        attn_resolutions: Tuple[int] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        use_gumbel: bool = False,
        gumb_temp: float = 1.0,
        beta: float = 0.25,
        kl_weight: float = 5e-4,
        act_name: str = "swish",
        give_pre_end: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.use_gumbel = use_gumbel
        self.gumb_temp = gumb_temp
        self.beta = beta
        self.kl_weight = kl_weight
        self.act_name = act_name
        self.give_pre_end = give_pre_end
        self.num_resolutions = len(ch_mult)

The major VQGANConfig parameter here is use_gumbel. VQGAN can be trained with the Gumbel-Max trick (see the original paper), which turns the bottleneck into a distribution over codes from which we take the argmax to assign a code.
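
For example, switching on the Gumbel variant could look roughly like this inside the train section (an illustrative fragment; see gumbel.yaml for the settings actually used):

# Illustrative fragment: enabling Gumbel-softmax quantization.
train:
  model_hparams:
    use_gumbel: true
    gumb_temp: 1.0        # starting temperature; can be annealed via temp_scheduler
    n_embed: 8192         # placeholder codebook size
    embed_dim: 256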

Samples

We provide three samples of data and train configs:

- config.yaml: my training config on the Imagenette dataset.
- gumbel.yaml: official training config with the Gumbel trick on the ImageNet dataset.
- imagenet.yaml: official training config on the ImageNet dataset.

Major Modules

The major modules used in the pipeline are:

TrainerVQGan

TrainerVQGan lives in modules.training; this module is responsible for training the VQGAN.

Bases: TrainerModule

Helper functions for training VQGAN.

Source code in modules/training.py
class TrainerVQGan(TrainerModule):
    """Helper functions for training VQGAN."""

    def __init__(self, module_config: config.TrainConfig):
        # Initialize parent class
        self.module_config = module_config
        self.model_name = self.module_config.model_name
        if self.module_config.recon_loss == "l2":
            self.recon_loss_fn = losses.l2_loss
        elif self.module_config.recon_loss == "l1":
            self.recon_loss_fn = losses.l1_loss
        elif self.module_config.recon_loss == "combo":
            self.recon_loss_fn = losses.combo_loss
        elif self.module_config.recon_loss == "mape":
            self.recon_loss_fn = losses.mape_loss
        else:
            logger.warning(
                f"""Reconstruction loss function {self.module_config.recon_loss} not supported.
                Will be used default l1 loss instead."""
            )
            self.recon_loss_fn = losses.l1_loss
        # Train state for discriminator
        self.model_disc: FlaxPreTrainedModel = vqgan.VQGanDiscriminator(
            self.module_config.disc_hparams
        )
        if self.module_config.disc_loss == "vanilla":
            self.disc_loss_fn = losses.disc_loss_vanilla
        elif self.module_config.disc_loss == "hinge":
            self.disc_loss_fn = losses.disc_loss_hinge
        else:
            logger.warning(
                f"""Discriminator loss function {self.module_config.disc_loss} not supported.
                Will be used default hinge loss instead."""
            )
            self.disc_loss_fn = losses.disc_loss_hinge
        self.state_disc = TrainStateDisc(
            step=0,
            apply_fn=self.model_disc.__call__,
            params=self.model_disc.params["params"],
            batch_stats=self.model_disc.params["batch_stats"],
            tx=None,
            opt_state=None,
        )
        # Log dir discriminator loss
        self.save_dir_disc: str = os.path.join(
            self.module_config.save_dir, f"{self.model_name}_disc/"
        )
        self.temp_scheduler: Optional[Callable] = self.module_config.temp_scheduler

        super().__init__(module_config=module_config, model_class=vqgan.VQModel)

    def temperature_scheduling(self, epoch: int) -> float:
        """Temperature scheduling.
        Args:
            epoch (int): Current epoch.
        Returns:
            Temperature.
        """
        if self.temp_scheduler is not None:
            return self.model.update_temperature(self.temp_scheduler(epoch))
        else:
            return self.module_config.model_hparams.gumb_temp

    def init_optimizer(self):
        """Initialize optimizer and scheduler also for discriminator.
        By default, we decrease the learning rate with cosine annealing.
        """
        optimizer: optax.GradientTransformation = self.module_config.optimizer
        optimizer_disc: optax.GradientTransformation = self.module_config.optimizer_disc
        self.create_train_stat_full(optimizer, optimizer_disc)

    def create_train_stat_full(
        self,
        optimizer: optax.GradientTransformation,
        optimizer_disc: optax.GradientTransformation,
    ):
        """Initialize training state.
        Args:
            optimizer (optax.GradientTransformation): Optimizer for generator.
            optimizer_disc (optax.GradientTransformation): Optimizer for discriminator.
        """
        self.state = train_state.TrainState.create(
            apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
        )
        self.state_disc = TrainStateDisc.create(
            apply_fn=self.state_disc.apply_fn,
            params=self.state_disc.params,
            batch_stats=self.state_disc.batch_stats,
            tx=optimizer_disc,
        )

    def save_model(self, step: Optional[int] = None):
        """Save current model.
        Args:
            step (int, optional): Current step. Defaults to None.
        """
        step = step if step is not None else self.module_config.num_epochs
        super().save_model(step=step)
        checkpoints.save_checkpoint(
            ckpt_dir=self.save_dir_disc, target=self.state_disc, step=step, overwrite=True
        )

    def load_model(self):
        """Load model."""
        super().load_model()
        state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir_disc, target=None)
        self.state_disc = TrainStateDisc(
            apply_fn=self.state_disc.apply_fn,
            params=state_dict["params"],
            batch_stats=state_dict["batch_stats"],
            step=state_dict["step"],
            tx=self.state_disc.tx if self.state_disc.tx else self.module_config.optimizer_disc,
            opt_state=state_dict["opt_state"],
        )
        self.model_disc.params["params"] = self.state_disc.params
        self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

    def checkpoint_exists(self) -> bool:
        """Check whether a pretrained model exist.
        Returns:
            True if model and discriminator exists, False otherwise.
        """
        main_model: bool = os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0
        disc_model: bool = os.path.exists(self.save_dir_disc) and len(os.listdir(self.log_dir)) > 0
        return main_model and disc_model

    def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
        """Train model for one epoch, and log avg metrics.
        Args:
            data_loader (utils.DataLoader): Data loader to train on.
        Returns:
            Dictionary with all metrics.
        """
        metrics: Dict[str, float] = defaultdict(float)
        metrics_disc = defaultdict(float)
        metrics_disc["step"] = 0.0
        new_temp: float = self.temperature_scheduling(epoch - 1)
        for batch in tqdm(data_loader(), desc="Training", leave=False):
            batch_metrics: Dict[str, float]
            train_outs = self.train_step(  # type: ignore
                state=self.state,
                disc_state=self.state_disc,
                batch=batch,
                rng=self.main_rng,
                optimizer_idx=0,
                disc_use=self.module_config.disc_start > epoch,
                distributed=self.module_config.distributed,
            )
            self.state, self.state_disc, self.main_rng, batch_metrics = train_outs
            # Update metrics
            for key, value in batch_metrics.items():
                metrics[key] += value

            if self.module_config.disc_start > epoch:
                batch_metrics_disc: Dict[str, float]
                train_outs = self.train_step(  # type: ignore
                    state=self.state,
                    disc_state=self.state_disc,
                    batch=batch,
                    rng=self.main_rng,
                    optimizer_idx=1,
                    disc_use=True,
                    distributed=self.module_config.distributed,
                )
                self.state, self.state_disc, self.main_rng, batch_metrics_disc = train_outs
                # Update metrics discriminator
                for key, value in batch_metrics_disc.items():
                    metrics_disc[key] += value
                metrics_disc["step"] += 1.0

            # ensure that model have actual parameters
            self.model.params = self.state.params
            self.model_disc.params["params"] = self.state_disc.params
            self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

        count = len(data_loader)
        metrics = {key: metrics[key] / count for key in metrics}
        metrics["temp"] = new_temp
        count_disc = metrics_disc["step"]
        del metrics_disc["step"]
        metrics_disc_resized: Dict[str, float] = {
            key: metrics_disc[key] / count_disc for key in metrics_disc
        }
        # merge metrics
        for key, value in metrics_disc_resized.items():
            metrics[key] = value

        return metrics

    def create_functions(self):
        """Create training and eval functions."""
        recon_loss_fn: Callable = self.recon_loss_fn
        disc_loss_fn: Callable = self.disc_loss_fn

        def calculate_loss_autoencoder(
            params: FrozenDict[str, Any],
            batch: jnp.ndarray,
            train: bool,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            disc_variables: TrainStateDisc,
        ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
            """Function to calculate the loss autoencoder for a batch of images."""
            new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
            outs = self.model(
                batch,
                params=params,
                dropout_rng=dropout_apply_rng,
                gumble_rng=gumble_apply_rng,
                train=train,
            )
            x_recon, z_q, codebook_loss, indices = outs
            # for now we will use l1 loss than it will be combined with perceptual loss
            rec_loss = recon_loss_fn(x_recon, batch)
            nll_loss = jnp.mean(rec_loss)

            # Generator loss (autoencode)
            outs = self.model_disc(
                x_recon,
                params=disc_variables.params,
                batch_stats=disc_variables.batch_stats,
                train=train,
            )
            logits_fake, new_model_state = outs if train else (outs, None)
            # Original loss is
            # g_loss = -jnp.mean(logits_fake)
            # But we think that based on disc for generator should work we will use minimax loss
            # This loss is none negative and tries to maximize the probability of the fake_logits.
            g_loss = jnp.mean(jnp.maximum(1.0 - logits_fake, 0.0))
            disc_factor = 0.0
            if disc_use:
                disc_factor = self.module_config.disc_weight
            disc_factor = jax.lax.cond(
                disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
            )
            loss = (
                nll_loss
                + disc_factor * g_loss
                + self.module_config.codebook_weight * codebook_loss.mean()
            )

            metrics = {
                "total_loss": loss,
                "quant_loss": jnp.mean(codebook_loss),
                "nll_loss": nll_loss,
                "rec_loss": jnp.mean(rec_loss),
                "g_loss": g_loss,
            }

            return loss, (metrics, new_rng, new_model_state)

        def calculate_loss_disc(
            params: FrozenDict[str, Any],
            batch: jnp.ndarray,
            train: bool,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            batch_stats: FrozenDict[str, Any],
            model_params: Optional[FrozenDict[str, Any]],
        ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
            """Function to calculate the loss discriminator for a batch of images."""
            new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
            outs = self.model(
                batch,
                params=model_params,
                dropout_rng=dropout_apply_rng,
                gumble_rng=gumble_apply_rng,
                train=train,
            )
            x_recon, z_q, codebook_loss, indices = outs

            # Discriminator loss
            outs = self.model_disc(batch, params=params, batch_stats=batch_stats, train=train)
            logits_real, new_model_state = outs if train else (outs, None)
            outs = self.model_disc(x_recon, params=params, batch_stats=batch_stats, train=train)
            logits_fake, new_model_state = outs if train else (outs, None)
            disc_factor = jax.lax.cond(
                disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
            )
            loss = disc_factor * disc_loss_fn(logits_real, logits_fake)
            metrics = {
                "disc_loss": loss,
                "logits_real": logits_real.mean(),
                "logits_fake": logits_fake.mean(),
            }

            return loss, (metrics, new_rng, new_model_state)

        def train_step_autoencoder(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            distributed: bool = False,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train step for autoencoder."""
            loss_fn = partial(
                calculate_loss_autoencoder,
                batch=batch,
                train=True,
                rng=rng,
                disc_use=disc_use,
                disc_variables=disc_state,
            )
            (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
                loss_fn, has_aux=True
            )(state.params)
            # if distributed training, average grads
            if distributed:
                grads = jax.lax.pmean(grads, axis_name="batch")
            # Update parameters
            state = state.apply_gradients(grads=grads)
            return state, disc_state, new_rng, metrics

        def train_step_disc(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            distributed: bool = False,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train step for discriminator."""
            loss_fn = partial(
                calculate_loss_disc,
                batch=batch,
                train=True,
                rng=rng,
                disc_use=disc_use,
                batch_stats=disc_state.batch_stats,
                model_params=state.params,
            )
            (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
                loss_fn, has_aux=True
            )(disc_state.params)
            # if distributed training, average grads
            if distributed:
                grads = jax.lax.pmean(grads, axis_name="batch")
            # Update parameters, batch statistics
            disc_state = disc_state.apply_gradients(
                grads=grads, batch_stats=new_model_state["batch_stats"]
            )
            return state, disc_state, new_rng, metrics

        def train_step(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            optimizer_idx: int,
            disc_use: bool,
            distributed: bool,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train model on a single batch."""
            # calculate loss
            if optimizer_idx == 0:
                outs = train_step_autoencoder(state, disc_state, batch, rng, disc_use, distributed)
            else:
                outs = train_step_disc(state, disc_state, batch, rng, disc_use, distributed)
            # outs = jax.lax.cond(
            #  optimizer_idx == 0,
            #    lambda _: train_step_autoencoder(state,
            #                                     disc_state,
            #                                     batch,
            #                                     rng,
            #                                     disc_use,
            #                                     distributed),
            #    lambda _: train_step_disc(state,
            #                              disc_state,
            #                              batch,
            #                              rng,
            #                              disc_use,
            #                              distributed),
            #    None)
            state, disc_state, new_rng, metrics = outs
            return state, disc_state, new_rng, metrics

        def eval_step(
            state: train_state.TrainState,
            batch: jnp.ndarray,
            disc_state: TrainStateDisc,
            rng: Union[Any, jnp.ndarray],
        ) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
            """Evaluate model on a single batch."""
            _, (metrics, new_rng, _) = calculate_loss_autoencoder(
                state.params,
                batch=batch,
                train=False,
                rng=rng,
                disc_use=False,
                disc_variables=disc_state,
            )
            return new_rng, metrics

        # pmap or jit for efficiency
        if self.module_config.distributed:
            self.train_step = jax.pmap(  # type: ignore
                train_step, axis_name="batch", static_broadcasted_argnums=(4, 5, 6)
            )
            self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
        else:
            self.train_step = jax.jit(train_step, static_argnums=(4, 5, 6))  # type: ignore
            self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore

checkpoint_exists()

Check whether a pretrained model exist.

Returns:

    bool: True if model and discriminator exists, False otherwise.

Source code in modules/training.py
def checkpoint_exists(self) -> bool:
    """Check whether a pretrained model exist.
    Returns:
        True if model and discriminator exists, False otherwise.
    """
    main_model: bool = os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0
    disc_model: bool = os.path.exists(self.save_dir_disc) and len(os.listdir(self.log_dir)) > 0
    return main_model and disc_model

create_functions()

Create training and eval functions.

Source code in modules/training.py
def create_functions(self):
    """Create training and eval functions."""
    recon_loss_fn: Callable = self.recon_loss_fn
    disc_loss_fn: Callable = self.disc_loss_fn

    def calculate_loss_autoencoder(
        params: FrozenDict[str, Any],
        batch: jnp.ndarray,
        train: bool,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        disc_variables: TrainStateDisc,
    ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
        """Function to calculate the loss autoencoder for a batch of images."""
        new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
        outs = self.model(
            batch,
            params=params,
            dropout_rng=dropout_apply_rng,
            gumble_rng=gumble_apply_rng,
            train=train,
        )
        x_recon, z_q, codebook_loss, indices = outs
        # for now we will use l1 loss than it will be combined with perceptual loss
        rec_loss = recon_loss_fn(x_recon, batch)
        nll_loss = jnp.mean(rec_loss)

        # Generator loss (autoencode)
        outs = self.model_disc(
            x_recon,
            params=disc_variables.params,
            batch_stats=disc_variables.batch_stats,
            train=train,
        )
        logits_fake, new_model_state = outs if train else (outs, None)
        # Original loss is
        # g_loss = -jnp.mean(logits_fake)
        # But we think that based on disc for generator should work we will use minimax loss
        # This loss is none negative and tries to maximize the probability of the fake_logits.
        g_loss = jnp.mean(jnp.maximum(1.0 - logits_fake, 0.0))
        disc_factor = 0.0
        if disc_use:
            disc_factor = self.module_config.disc_weight
        disc_factor = jax.lax.cond(
            disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
        )
        loss = (
            nll_loss
            + disc_factor * g_loss
            + self.module_config.codebook_weight * codebook_loss.mean()
        )

        metrics = {
            "total_loss": loss,
            "quant_loss": jnp.mean(codebook_loss),
            "nll_loss": nll_loss,
            "rec_loss": jnp.mean(rec_loss),
            "g_loss": g_loss,
        }

        return loss, (metrics, new_rng, new_model_state)

    def calculate_loss_disc(
        params: FrozenDict[str, Any],
        batch: jnp.ndarray,
        train: bool,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        batch_stats: FrozenDict[str, Any],
        model_params: Optional[FrozenDict[str, Any]],
    ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
        """Function to calculate the loss discriminator for a batch of images."""
        new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
        outs = self.model(
            batch,
            params=model_params,
            dropout_rng=dropout_apply_rng,
            gumble_rng=gumble_apply_rng,
            train=train,
        )
        x_recon, z_q, codebook_loss, indices = outs

        # Discriminator loss
        outs = self.model_disc(batch, params=params, batch_stats=batch_stats, train=train)
        logits_real, new_model_state = outs if train else (outs, None)
        outs = self.model_disc(x_recon, params=params, batch_stats=batch_stats, train=train)
        logits_fake, new_model_state = outs if train else (outs, None)
        disc_factor = jax.lax.cond(
            disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
        )
        loss = disc_factor * disc_loss_fn(logits_real, logits_fake)
        metrics = {
            "disc_loss": loss,
            "logits_real": logits_real.mean(),
            "logits_fake": logits_fake.mean(),
        }

        return loss, (metrics, new_rng, new_model_state)

    def train_step_autoencoder(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        distributed: bool = False,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train step for autoencoder."""
        loss_fn = partial(
            calculate_loss_autoencoder,
            batch=batch,
            train=True,
            rng=rng,
            disc_use=disc_use,
            disc_variables=disc_state,
        )
        (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(state.params)
        # if distributed training, average grads
        if distributed:
            grads = jax.lax.pmean(grads, axis_name="batch")
        # Update parameters
        state = state.apply_gradients(grads=grads)
        return state, disc_state, new_rng, metrics

    def train_step_disc(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        distributed: bool = False,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train step for discriminator."""
        loss_fn = partial(
            calculate_loss_disc,
            batch=batch,
            train=True,
            rng=rng,
            disc_use=disc_use,
            batch_stats=disc_state.batch_stats,
            model_params=state.params,
        )
        (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(disc_state.params)
        # if distributed training, average grads
        if distributed:
            grads = jax.lax.pmean(grads, axis_name="batch")
        # Update parameters, batch statistics
        disc_state = disc_state.apply_gradients(
            grads=grads, batch_stats=new_model_state["batch_stats"]
        )
        return state, disc_state, new_rng, metrics

    def train_step(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        optimizer_idx: int,
        disc_use: bool,
        distributed: bool,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train model on a single batch."""
        # calculate loss
        if optimizer_idx == 0:
            outs = train_step_autoencoder(state, disc_state, batch, rng, disc_use, distributed)
        else:
            outs = train_step_disc(state, disc_state, batch, rng, disc_use, distributed)
        # outs = jax.lax.cond(
        #  optimizer_idx == 0,
        #    lambda _: train_step_autoencoder(state,
        #                                     disc_state,
        #                                     batch,
        #                                     rng,
        #                                     disc_use,
        #                                     distributed),
        #    lambda _: train_step_disc(state,
        #                              disc_state,
        #                              batch,
        #                              rng,
        #                              disc_use,
        #                              distributed),
        #    None)
        state, disc_state, new_rng, metrics = outs
        return state, disc_state, new_rng, metrics

    def eval_step(
        state: train_state.TrainState,
        batch: jnp.ndarray,
        disc_state: TrainStateDisc,
        rng: Union[Any, jnp.ndarray],
    ) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
        """Evaluate model on a single batch."""
        _, (metrics, new_rng, _) = calculate_loss_autoencoder(
            state.params,
            batch=batch,
            train=False,
            rng=rng,
            disc_use=False,
            disc_variables=disc_state,
        )
        return new_rng, metrics

    # pmap or jit for efficiency
    if self.module_config.distributed:
        self.train_step = jax.pmap(  # type: ignore
            train_step, axis_name="batch", static_broadcasted_argnums=(4, 5, 6)
        )
        self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
    else:
        self.train_step = jax.jit(train_step, static_argnums=(4, 5, 6))  # type: ignore
        self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore

create_train_stat_full(optimizer, optimizer_disc)

Initialize training state.

Parameters:

    optimizer (optax.GradientTransformation): Optimizer for generator. Required.
    optimizer_disc (optax.GradientTransformation): Optimizer for discriminator. Required.
Source code in modules/training.py
def create_train_stat_full(
    self,
    optimizer: optax.GradientTransformation,
    optimizer_disc: optax.GradientTransformation,
):
    """Initialize training state.
    Args:
        optimizer (optax.GradientTransformation): Optimizer for generator.
        optimizer_disc (optax.GradientTransformation): Optimizer for discriminator.
    """
    self.state = train_state.TrainState.create(
        apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
    )
    self.state_disc = TrainStateDisc.create(
        apply_fn=self.state_disc.apply_fn,
        params=self.state_disc.params,
        batch_stats=self.state_disc.batch_stats,
        tx=optimizer_disc,
    )

init_optimizer()

Initialize optimizer and scheduler also for discriminator. By default, we decrease the learning rate with cosine annealing.

Source code in modules/training.py
def init_optimizer(self):
    """Initialize optimizer and scheduler also for discriminator.
    By default, we decrease the learning rate with cosine annealing.
    """
    optimizer: optax.GradientTransformation = self.module_config.optimizer
    optimizer_disc: optax.GradientTransformation = self.module_config.optimizer_disc
    self.create_train_stat_full(optimizer, optimizer_disc)

load_model()

Load model.

Source code in modules/training.py
def load_model(self):
    """Load model."""
    super().load_model()
    state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir_disc, target=None)
    self.state_disc = TrainStateDisc(
        apply_fn=self.state_disc.apply_fn,
        params=state_dict["params"],
        batch_stats=state_dict["batch_stats"],
        step=state_dict["step"],
        tx=self.state_disc.tx if self.state_disc.tx else self.module_config.optimizer_disc,
        opt_state=state_dict["opt_state"],
    )
    self.model_disc.params["params"] = self.state_disc.params
    self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

save_model(step=None)

Save current model.

Parameters:

    step (int, optional): Current step. Defaults to None.
Source code in modules/training.py
def save_model(self, step: Optional[int] = None):
    """Save current model.
    Args:
        step (int, optional): Current step. Defaults to None.
    """
    step = step if step is not None else self.module_config.num_epochs
    super().save_model(step=step)
    checkpoints.save_checkpoint(
        ckpt_dir=self.save_dir_disc, target=self.state_disc, step=step, overwrite=True
    )

temperature_scheduling(epoch)

Temperature scheduling.

Parameters:

    epoch (int): Current epoch. Required.

Returns:

    float: Temperature.

Source code in modules/training.py
def temperature_scheduling(self, epoch: int) -> float:
    """Temperature scheduling.
    Args:
        epoch (int): Current epoch.
    Returns:
        Temperature.
    """
    if self.temp_scheduler is not None:
        return self.model.update_temperature(self.temp_scheduler(epoch))
    else:
        return self.module_config.model_hparams.gumb_temp

train_epoch(data_loader, epoch)

Train model for one epoch, and log avg metrics.

Parameters:

    data_loader (utils.DataLoader): Data loader to train on. Required.

Returns:

    Dict[str, float]: Dictionary with all metrics.

Source code in modules/training.py
def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
    """Train model for one epoch, and log avg metrics.
    Args:
        data_loader (utils.DataLoader): Data loader to train on.
    Returns:
        Dictionary with all metrics.
    """
    metrics: Dict[str, float] = defaultdict(float)
    metrics_disc = defaultdict(float)
    metrics_disc["step"] = 0.0
    new_temp: float = self.temperature_scheduling(epoch - 1)
    for batch in tqdm(data_loader(), desc="Training", leave=False):
        batch_metrics: Dict[str, float]
        train_outs = self.train_step(  # type: ignore
            state=self.state,
            disc_state=self.state_disc,
            batch=batch,
            rng=self.main_rng,
            optimizer_idx=0,
            disc_use=self.module_config.disc_start > epoch,
            distributed=self.module_config.distributed,
        )
        self.state, self.state_disc, self.main_rng, batch_metrics = train_outs
        # Update metrics
        for key, value in batch_metrics.items():
            metrics[key] += value

        if self.module_config.disc_start > epoch:
            batch_metrics_disc: Dict[str, float]
            train_outs = self.train_step(  # type: ignore
                state=self.state,
                disc_state=self.state_disc,
                batch=batch,
                rng=self.main_rng,
                optimizer_idx=1,
                disc_use=True,
                distributed=self.module_config.distributed,
            )
            self.state, self.state_disc, self.main_rng, batch_metrics_disc = train_outs
            # Update metrics discriminator
            for key, value in batch_metrics_disc.items():
                metrics_disc[key] += value
            metrics_disc["step"] += 1.0

        # ensure that model have actual parameters
        self.model.params = self.state.params
        self.model_disc.params["params"] = self.state_disc.params
        self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

    count = len(data_loader)
    metrics = {key: metrics[key] / count for key in metrics}
    metrics["temp"] = new_temp
    count_disc = metrics_disc["step"]
    del metrics_disc["step"]
    metrics_disc_resized: Dict[str, float] = {
        key: metrics_disc[key] / count_disc for key in metrics_disc
    }
    # merge metrics
    for key, value in metrics_disc_resized.items():
        metrics[key] = value

    return metrics

VQGANPreTrainedModel

VQGANPreTrainedModel lives in modules.vqgan and is responsible for the VQ autoencoder architecture. This class is based on FlaxPreTrainedModel, which gives us the ability to push the architecture to the Hugging Face Hub.

Bases: FlaxPreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Attributes:

    module_class (nn.Module): a class derived from nn.Module that defines the model's core computation.
    config_class (PretrainedConfig): a class derived from PretrainedConfig.

Source code in modules/vqgan.py
class VQGANPreTrainedModel(FlaxPreTrainedModel):
    """An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.

    Attributes:
        module_class (nn.Module): a class derived from nn.Module
            that defines the model's core computation.
        config_class (PretrainedConfig): a class derived from PretrainedConfig

    """

    module_class: nn.Module
    config_class: PretrainedConfig

    def __init__(
        self,
        config: PretrainedConfig = VQGANConfig(),
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Initialize the model.

        Args:
            config (PretrainedConfig, optional): the config of the model. Defaults to VQGANConfig.
            input_shape (Tuple, optional): the input shape of the model.
                Defaults to (1, 256, 256, 3).
            seed (int, optional): the seed of the model. Defaults to 0.
            dtype (jnp.dtype, optional): the dtype of the computation. Defaults to jnp.float32.
            _do_init (bool, optional): whether to initialize the model. Defaults to True.
        """
        self._missing_keys: Set[str] = set()
        if not isinstance(config, self.config_class):
            raise ValueError(f"config: {config} has to be an instance of {self.config_class}")
        if self.module_class is None:
            raise NotImplementedError("module_class should be defined in derived classes")
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        self.seed = seed
        self.dtype = dtype
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        rng: Union[Any, jnp.ndarray],
        input_shape: Tuple,
        params: Optional[FrozenDict[str, Any]] = None,
    ) -> FrozenDict[str, Any]:
        """Initialize the weights of the model. Get the params

        Args:
            rng (Union[Any,jnp.ndarray]): the random number generator.
            input_shape (Tuple): the input shape of the model.
            params (FrozenDict, optional): the params of the model. Defaults to None.

        Returns:
            initialized params of the model.
        """
        # initialize model
        input_x = jnp.zeros(input_shape, dtype=self.dtype)
        params_rng, dropout_rng, gumble_rng = jax.random.split(rng, num=3)
        rngs: Dict[str, Union[Any, jnp.ndarray]] = {
            "params": params_rng,
            "dropout": dropout_rng,
            "gumbel": gumble_rng,
        }

        random_params = self.module.init(rngs, input_x, True)["params"]

        # If params provided find unitialized params and replace with provided params
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # type: ignore
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def encode(
        self,
        pixel_values: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
        gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
        """Encode the input.
        Args:
            pixel_values (jnp.ndarray): the input to the encoder.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
            train (bool, optional): Training or inference mode. Defaults to False.
        """
        # Handle any PRNG if needed
        rngs: Dict[str, Union[Any, jnp.ndarray]] = (
            {"dropout": dropout_rng} if dropout_rng is not None else {}
        )
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
        return self.module.apply(
            {"params": params or self.params},
            pixel_values,
            not train,
            rngs=rngs,
            method=self.module.encode,
        )

    def decode(
        self,
        z: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
        gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
        train: bool = False,
    ) -> jnp.ndarray:
        """Decode the latent vector.

        Args:
            z (jnp.ndarray): the latent vector.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
            train (bool, optional): Training or inference mode. Defaults to False.

        Returns:
            the decoded image.
        """
        # Handle any PRNG if needed
        rngs: Dict[str, Union[Any, jnp.ndarray]] = (
            {"dropout": dropout_rng} if dropout_rng is not None else {}
        )
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
        return self.module.apply(
            {"params": params or self.params},
            z,
            not train,
            rngs=rngs,
            method=self.module.decode,
        )

    def decode_code(
        self,
        indices: jnp.ndarray,
        z_shape: Tuple[int, ...],
        params: Optional[FrozenDict] = None,
    ) -> jnp.ndarray:
        """Decode the indices.

        Args:
            indices (jnp.ndarray): the indices.
            z_shape (Tuple[int, ...]): the shape of the latent vector.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.

        Returns:
            the decoded image from indices.
        """
        return self.module.apply(
            {"params": params or self.params},
            indices,
            z_shape,
            method=self.module.decode_code,
        )

    def update_temperature(self, temperature: float, params: Optional[FrozenDict] = None) -> float:
        """Update the temperature of the model.
        Args:
            temperature (float): the temperature to update to.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        Returns:
            the updated temperature.
        """
        new_temperature = self.module.apply(
            {"params": params or self.params},
            temperature,
            method=self.module.update_temperature,
        )
        return new_temperature

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[jnp.ndarray] = None,
        gumble_rng: Optional[jnp.ndarray] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
        """Encode and decode the input.

        Args:
            pixel_values (jnp.ndarray): the input to the encoder.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Optional[jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Optional[jnp.ndarray], optional): the gumbel rng. Defaults to None.
                If gumble_rng is None then the defult rng is used and produce deterministic results.
            train (bool, optional): Training or inference mode. Defaults to False.

        Returns:
                the encoded latent vector,
                the decoded image,
                the log prob of the latent vector,
                the indices of the latent vector.
        """
        # Check dtype
        pixel_values = (
            pixel_values.astype(self.dtype) if pixel_values.dtype != self.dtype else pixel_values
        )
        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else self.key
        return self.module.apply(
            {"params": params or self.params}, pixel_values, not train, rngs=rngs
        )

__call__(pixel_values, params=None, dropout_rng=None, gumble_rng=None, train=False)

Encode and decode the input.

Parameters:

Name Type Description Default
pixel_values jnp.ndarray

the input to the encoder.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Optional[jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Optional[jnp.ndarray]

the gumbel rng. Defaults to None. If gumble_rng is None then the default rng is used, producing deterministic results.

None
train bool

Training or inference mode. Defaults to False.

False

Returns:

Type Description
jnp.ndarray

the encoded latent vector,

jnp.ndarray

the decoded image,

float

the log prob of the latent vector,

jnp.ndarray

the indices of the latent vector.

Source code in modules/vqgan.py
def __call__(
    self,
    pixel_values: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[jnp.ndarray] = None,
    gumble_rng: Optional[jnp.ndarray] = None,
    train: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
    """Encode and decode the input.

    Args:
        pixel_values (jnp.ndarray): the input to the encoder.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Optional[jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Optional[jnp.ndarray], optional): the gumbel rng. Defaults to None.
            If gumble_rng is None then the defult rng is used and produce deterministic results.
        train (bool, optional): Training or inference mode. Defaults to False.

    Returns:
            the encoded latent vector,
            the decoded image,
            the log prob of the latent vector,
            the indices of the latent vector.
    """
    # Check dtype
    pixel_values = (
        pixel_values.astype(self.dtype) if pixel_values.dtype != self.dtype else pixel_values
    )
    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else self.key
    return self.module.apply(
        {"params": params or self.params}, pixel_values, not train, rngs=rngs
    )
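
A hedged usage sketch of the forward pass. model is assumed to be an already-initialized concrete subclass (see the construction sketch under __init__ below), and the unpacking order follows the docstring above:

import jax
import jax.numpy as jnp

images = jnp.zeros((1, 256, 256, 3), dtype=jnp.float32)  # dummy batch matching input_shape
dropout_rng = jax.random.PRNGKey(0)

latent, reconstruction, log_prob, indices = model(
    images,
    dropout_rng=dropout_rng,  # optional; mainly relevant when train=True
    train=False,              # inference mode
)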

__init__(config=VQGANConfig(), input_shape=(1, 256, 256, 3), seed=0, dtype=jnp.float32, _do_init=True, **kwargs)

Initialize the model.

Parameters:

Name Type Description Default
config PretrainedConfig

the config of the model. Defaults to VQGANConfig.

VQGANConfig()
input_shape Tuple

the input shape of the model. Defaults to (1, 256, 256, 3).

(1, 256, 256, 3)
seed int

the seed of the model. Defaults to 0.

0
dtype jnp.dtype

the dtype of the computation. Defaults to jnp.float32.

jnp.float32
_do_init bool

whether to initialize the model. Defaults to True.

True
Source code in modules/vqgan.py
def __init__(
    self,
    config: PretrainedConfig = VQGANConfig(),
    input_shape: Tuple = (1, 256, 256, 3),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    **kwargs,
):
    """Initialize the model.

    Args:
        config (PretrainedConfig, optional): the config of the model. Defaults to VQGANConfig.
        input_shape (Tuple, optional): the input shape of the model.
            Defaults to (1, 256, 256, 3).
        seed (int, optional): the seed of the model. Defaults to 0.
        dtype (jnp.dtype, optional): the dtype of the computation. Defaults to jnp.float32.
        _do_init (bool, optional): whether to initialize the model. Defaults to True.
    """
    self._missing_keys: Set[str] = set()
    if not isinstance(config, self.config_class):
        raise ValueError(f"config: {config} has to be an instance of {self.config_class}")
    if self.module_class is None:
        raise NotImplementedError("module_class should be defined in derived classes")
    module = self.module_class(config=config, dtype=dtype, **kwargs)
    self.seed = seed
    self.dtype = dtype
    super().__init__(
        config,
        module,
        input_shape=input_shape,
        seed=seed,
        dtype=dtype,
        _do_init=_do_init,
    )
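
Since module_class must be set by a derived class, the constructor is normally reached through a concrete subclass. A minimal sketch, assuming a Flax module VQModule defined elsewhere in modules/vqgan.py and that VQGANConfig is importable from the same module:

import jax.numpy as jnp
from modules.vqgan import VQGANPreTrainedModel, VQGANConfig  # import path for VQGANConfig assumed

class MyVQModel(VQGANPreTrainedModel):
    module_class = VQModule   # assumed concrete nn.Module implementing the autoencoder
    config_class = VQGANConfig

model = MyVQModel(
    config=VQGANConfig(),
    input_shape=(1, 224, 224, 3),  # match data.size from the YAML config
    seed=42,
    dtype=jnp.float32,
)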

decode(z, params=None, dropout_rng=None, gumble_rng=None, train=False)

Decode the latent vector.

Parameters:

Name Type Description Default
z jnp.ndarray

the latent vector.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Union[Any, jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Union[Any, jnp.ndarray]

the gumbel rng. Defaults to None.

None
train bool

Training or inference mode. Defaults to False.

False

Returns:

Type Description
jnp.ndarray

the decoded image.

Source code in modules/vqgan.py
def decode(
    self,
    z: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
    gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
    train: bool = False,
) -> jnp.ndarray:
    """Decode the latent vector.

    Args:
        z (jnp.ndarray): the latent vector.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
        train (bool, optional): Training or inference mode. Defaults to False.

    Returns:
        the decoded image.
    """
    # Handle any PRNG if needed
    rngs: Dict[str, Union[Any, jnp.ndarray]] = (
        {"dropout": dropout_rng} if dropout_rng is not None else {}
    )
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
    return self.module.apply(
        {"params": params or self.params},
        z,
        not train,
        rngs=rngs,
        method=self.module.decode,
    )

decode_code(indices, z_shape, params=None)

Decode the indices.

Parameters:

Name Type Description Default
indices jnp.ndarray

the indices.

required
z_shape Tuple[int, ...]

the shape of the latent vector.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None

Returns:

Type Description
jnp.ndarray

the decoded image from indices.

Source code in modules/vqgan.py
def decode_code(
    self,
    indices: jnp.ndarray,
    z_shape: Tuple[int, ...],
    params: Optional[FrozenDict] = None,
) -> jnp.ndarray:
    """Decode the indices.

    Args:
        indices (jnp.ndarray): the indices.
        z_shape (Tuple[int, ...]): the shape of the latent vector.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.

    Returns:
        the decoded image from indices.
    """
    return self.module.apply(
        {"params": params or self.params},
        indices,
        z_shape,
        method=self.module.decode_code,
    )
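
decode_code is handy when only the discrete codebook indices were stored. A hedged sketch; the shapes below are assumptions that depend on the encoder's downsampling factor and embedding size:

import jax.numpy as jnp

indices = jnp.zeros((1, 16 * 16), dtype=jnp.int32)  # assumed flattened grid of codebook ids
z_shape = (1, 16, 16, 256)                          # assumed latent shape (batch, h, w, embed_dim)

reconstruction = model.decode_code(indices, z_shape)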

encode(pixel_values, params=None, dropout_rng=None, gumble_rng=None, train=False)

Encode the input.

Parameters:

Name Type Description Default
pixel_values jnp.ndarray

the input to the encoder.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Union[Any, jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Union[Any, jnp.ndarray]

the gumbel rng. Defaults to None.

None
train bool

Training or inference mode. Defaults to False.

False
Source code in modules/vqgan.py
def encode(
    self,
    pixel_values: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
    gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
    train: bool = False,
) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
    """Encode the input.
    Args:
        pixel_values (jnp.ndarray): the input to the encoder.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
        train (bool, optional): Training or inference mode. Defaults to False.
    """
    # Handle any PRNG if needed
    rngs: Dict[str, Union[Any, jnp.ndarray]] = (
        {"dropout": dropout_rng} if dropout_rng is not None else {}
    )
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
    return self.module.apply(
        {"params": params or self.params},
        pixel_values,
        not train,
        rngs=rngs,
        method=self.module.encode,
    )
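
A hedged round trip through the latent space. The docstring does not spell out the meaning of the second returned value, so it is kept as an opaque auxiliary output here:

import jax.numpy as jnp

images = jnp.zeros((1, 256, 256, 3), dtype=jnp.float32)

latent, aux, indices = model.encode(images, train=False)
reconstruction = model.decode(latent, train=False)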

init_weights(rng, input_shape, params=None)

Initialize the weights of the model. Get the params

Parameters:

Name Type Description Default
rng Union[Any, jnp.ndarray]

the random number generator.

required
input_shape Tuple

the input shape of the model.

required
params FrozenDict

the params of the model. Defaults to None.

None

Returns:

Type Description
FrozenDict[str, Any]

initialized params of the model.

Source code in modules/vqgan.py
def init_weights(
    self,
    rng: Union[Any, jnp.ndarray],
    input_shape: Tuple,
    params: Optional[FrozenDict[str, Any]] = None,
) -> FrozenDict[str, Any]:
    """Initialize the weights of the model. Get the params

    Args:
        rng (Union[Any,jnp.ndarray]): the random number generator.
        input_shape (Tuple): the input shape of the model.
        params (FrozenDict, optional): the params of the model. Defaults to None.

    Returns:
        initialized params of the model.
    """
    # initialize model
    input_x = jnp.zeros(input_shape, dtype=self.dtype)
    params_rng, dropout_rng, gumble_rng = jax.random.split(rng, num=3)
    rngs: Dict[str, Union[Any, jnp.ndarray]] = {
        "params": params_rng,
        "dropout": dropout_rng,
        "gumbel": gumble_rng,
    }

    random_params = self.module.init(rngs, input_x, True)["params"]

    # If params provided find unitialized params and replace with provided params
    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]  # type: ignore
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
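
When the model is created with _do_init=False, the parameters can be materialized later with init_weights, for example:

import jax

params = model.init_weights(jax.random.PRNGKey(0), input_shape=(1, 256, 256, 3))
model.params = params  # attach the freshly initialized parameters to the model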

update_temperature(temperature, params=None)

Update the temperature of the model.

Parameters:

Name Type Description Default
temperature float

the temperature to update to.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None

Returns:

Type Description
float

the updated temperature.

Source code in modules/vqgan.py
def update_temperature(self, temperature: float, params: Optional[FrozenDict] = None) -> float:
    """Update the temperature of the model.
    Args:
        temperature (float): the temperature to update to.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
    Returns:
        the updated temperature.
    """
    new_temperature = self.module.apply(
        {"params": params or self.params},
        temperature,
        method=self.module.update_temperature,
    )
    return new_temperature
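
This is most likely what the trainer's temperature_scheduling relies on. A minimal sketch of annealing the gumbel temperature between epochs; the decay factor is an illustrative choice:

temperature = 1.0
for epoch in range(1, 11):
    temperature = model.update_temperature(temperature * 0.95)  # push the decayed value into the module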

VQGanDiscriminator

VQGanDiscriminator in modules.vqgan, responsible for the discriminator architecture. This class is based on FlaxPreTrainedModel, which gives us the ability to push the architecture to the Hugging Face Hub.

Bases: FlaxPreTrainedModel

VQGAN discriminator model.

Attributes:

Name Type Description
module_class nn.Module

the discriminator module class (NLayerDiscriminator).

Source code in modules/vqgan.py
class VQGanDiscriminator(FlaxPreTrainedModel):
    """VQGAN discriminator model.

    Attributes:
        module_class (nn.Module): the discriminator module class (NLayerDiscriminator).
    """

    module_class: nn.Module = NLayerDiscriminator  # type: ignore

    def __init__(
        self,
        config: DiscConfig = DiscConfig(),
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        self._missing_keys: Set[str] = set()
        module = self.module_class(
            ndf=config.ndf,
            n_layers=config.n_layers,
            output_dim=config.output_last_dim,
            dtype=dtype,
            **kwargs,
        )
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        params_rng: jnp.ndarray,
        input_shape: Tuple,
        params: Optional[FrozenDict[str, Any]] = None,
    ) -> FrozenDict[str, Any]:
        # initialize model
        input_x = jnp.zeros(input_shape, dtype=self.dtype)
        random_params = self.module.init(params_rng, input_x, True)

        # If params provided find unitialized params and replace with provided params
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # type: ignore
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        batch_stats: Optional[FrozenDict] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
        # Handle any PRNG if needed
        if batch_stats is not None and params is not None:
            dict_params = {"params": params, "batch_stats": batch_stats}
        elif batch_stats is not None:
            dict_params = {"params": self.params["params"], "batch_stats": batch_stats}
        elif params is not None:
            dict_params = {"params": params, "batch_stats": self.params["batch_stats"]}
        else:
            dict_params = {
                "params": self.params["params"],
                "batch_stats": self.params["batch_stats"],
            }
        return self.module.apply(
            dict_params, input, train=train, mutable=["batch_stats"] if train else False
        )
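
A hedged usage sketch of the discriminator. With train=True the batch-norm statistics are mutable, so the call also returns the updated variable collection; the import path for DiscConfig is an assumption:

import jax.numpy as jnp
from modules.vqgan import VQGanDiscriminator, DiscConfig  # DiscConfig import path assumed

disc = VQGanDiscriminator(DiscConfig(), input_shape=(1, 256, 256, 3))
images = jnp.zeros((4, 256, 256, 3), dtype=jnp.float32)

logits = disc(images, train=False)           # inference: patch logits only
logits, mutated = disc(images, train=True)   # training: also returns the updated batch_stats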

TensorflowDataset

TensorflowDataset in modules.utils, responsible for loading TensorFlow datasets and preparing them. This class is based on BaseDataset.

Bases: BaseDataset

Tensorflow dataset.

Source code in modules/utils.py
class TensorflowDataset(BaseDataset):
    """Tensorflow dataset."""

    def load_dataset(self, train: bool) -> tf.data.Dataset:
        """Load the dataset.
        Args:
            train: If the dataset is for training.
        Returns:
            The dataset."""

        # if you get error 'Too many open files' one can resolve it doing what this issue proposed
        # https://github.com/tensorflow/datasets/issues/1441#issuecomment-581660890
        # Below is the code to resolve it
        # import resource
        # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
        # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
        split = "train" if train else "validation"
        try:
            ds = tfds.load(
                name=self.dataset_name, split=split, as_supervised=True, data_dir=self.root
            ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
        except Exception as e:
            if split == "validation":
                ds = tfds.load(
                    name=self.dataset_name, split="test", as_supervised=True, data_dir=self.root
                ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
            else:
                raise e

        return ds.cache()
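
A hedged sketch of building the train and test datasets from a hydra-loaded config. cfg is assumed to be the LoadConfig object produced from conf/config.yaml, so cfg.data is a DataConfig with a valid tfds dataset_name:

import jax.numpy as jnp
from modules.utils import TensorflowDataset

train_set = TensorflowDataset(train=True, dtype=jnp.float32, config=cfg.data)
test_set = TensorflowDataset(train=False, dtype=jnp.float32, config=cfg.data)
print(len(train_set), len(test_set))  # number of raw examples per split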

load_dataset(train)

Load the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
def load_dataset(self, train: bool) -> tf.data.Dataset:
    """Load the dataset.
    Args:
        train: If the dataset is for training.
    Returns:
        The dataset."""

    # if you get error 'Too many open files' one can resolve it doing what this issue proposed
    # https://github.com/tensorflow/datasets/issues/1441#issuecomment-581660890
    # Below is the code to resolve it
    # import resource
    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
    split = "train" if train else "validation"
    try:
        ds = tfds.load(
            name=self.dataset_name, split=split, as_supervised=True, data_dir=self.root
        ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
    except Exception as e:
        if split == "validation":
            ds = tfds.load(
                name=self.dataset_name, split="test", as_supervised=True, data_dir=self.root
            ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
        else:
            raise e

    return ds.cache()

BaseDataset

BaseDataset in modules.utils, the abstract base class that concrete datasets such as TensorflowDataset derive from.

Bases: ABC

Load the dataset. Abstract method.

Source code in modules/utils.py
class BaseDataset(ABC):
    """Load the dataset. Abstract method."""

    def __init__(self, train: bool, dtype: jnp.dtype, config: DataConfig) -> None:
        """Set the dataset.
        Args:
            train: If the dataset is for training.
            config: The config for the dataset.
        """
        self.root: str = config.dataset_root
        self.use_transforms: bool = True if train else False
        if self.use_transforms and config.transform is None:
            raise ValueError("Transforms must be provided for training.")
        self.transforms = A.from_dict(config.transform)
        self.image_size: int = config.size
        self.dataset_name = config.dataset_name
        self.dtype = dtype
        self.params = config.train_params if train else config.test_params
        self.dataset: tf.data.Dataset = self.load_dataset(train)
        assert len(self.dataset) > 0, "Dataset is empty."

    @abstractmethod
    def load_dataset(self, train: bool) -> tf.data.Dataset:
        """Load the dataset.
        Args:
            train: If the dataset is for training.
        Returns:
            The dataset.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def _preprocess(self, image: tf.Tensor) -> tf.Tensor:
        """Preprocess the image.
        Args:
            image: The image to preprocess.
        Returns:
            The preprocessed image.
        """

        def aug_fn(image: tf.Tensor) -> tf.Tensor:
            data = {"image": image}
            aug_data = self.transforms(**data)
            aug_img = aug_data["image"]
            aug_img = tf.cast(aug_img, self.dtype) / 255.0
            aug_img = (aug_img - IMAGENET_STANDARD_MEAN) / IMAGENET_STANDARD_STD
            aug_img = tf.image.resize(aug_img, (self.image_size, self.image_size))
            return aug_img

        if self.use_transforms:
            image = tf.numpy_function(func=aug_fn, inp=[image], Tout=self.dtype)
        else:
            image = tf.cast(image, self.dtype) / 255.0
            image = (image - IMAGENET_STANDARD_MEAN) / IMAGENET_STANDARD_STD
            image = tf.image.resize(image, (self.image_size, self.image_size))
        return image

    def get_dataset(self) -> tf.data.Dataset:
        """Return the dataset.
        Returns:
            The dataset.
        """
        dataset = self.dataset.map(self._preprocess)
        dataset = dataset.shuffle(self.params.batch_size * 16) if self.params.shuffle else dataset
        return dataset
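
To plug in a new data source, only load_dataset has to be implemented. A minimal hypothetical subclass reading JPEG files from dataset_root:

import tensorflow as tf
from modules.utils import BaseDataset

class FolderDataset(BaseDataset):
    """Hypothetical dataset reading JPEG files from <dataset_root>/<split>/."""

    def load_dataset(self, train: bool) -> tf.data.Dataset:
        split = "train" if train else "val"
        paths = sorted(tf.io.gfile.glob(f"{self.root}/{split}/*.jpg"))
        images = tf.data.Dataset.from_tensor_slices(paths).map(
            lambda path: tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        )
        return images.cache()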

__init__(train, dtype, config)

Set the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required
config DataConfig

The config for the dataset.

required
Source code in modules/utils.py
def __init__(self, train: bool, dtype: jnp.dtype, config: DataConfig) -> None:
    """Set the dataset.
    Args:
        train: If the dataset is for training.
        config: The config for the dataset.
    """
    self.root: str = config.dataset_root
    self.use_transforms: bool = True if train else False
    if self.use_transforms and config.transform is None:
        raise ValueError("Transforms must be provided for training.")
    self.transforms = A.from_dict(config.transform)
    self.image_size: int = config.size
    self.dataset_name = config.dataset_name
    self.dtype = dtype
    self.params = config.train_params if train else config.test_params
    self.dataset: tf.data.Dataset = self.load_dataset(train)
    assert len(self.dataset) > 0, "Dataset is empty."

__len__()

Return the length of the dataset.

Source code in modules/utils.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)

get_dataset()

Return the dataset.

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
def get_dataset(self) -> tf.data.Dataset:
    """Return the dataset.
    Returns:
        The dataset.
    """
    dataset = self.dataset.map(self._preprocess)
    dataset = dataset.shuffle(self.params.batch_size * 16) if self.params.shuffle else dataset
    return dataset

load_dataset(train) abstractmethod

Load the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
@abstractmethod
def load_dataset(self, train: bool) -> tf.data.Dataset:
    """Load the dataset.
    Args:
        train: If the dataset is for training.
    Returns:
        The dataset.
    """
    raise NotImplementedError

DataLoader

DataLoader in modules.utils, responsible for wrapping datasets and creating batches, similar to the PyTorch DataLoader.

Dataloader similar as in pytorch.

Source code in modules/utils.py
class DataLoader:
    """Dataloader similar as in pytorch."""

    def __init__(self, dataset: BaseDataset, distributed: bool) -> None:
        """Create a data loader.
        Args:
            dataset (BaseDataset): The dataset to load.
            distributed (bool): If the data is distributed.
        """
        self.dataset_placeholder = dataset
        self.dist = distributed

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset_placeholder) // self.dataset_placeholder.params.batch_size

    def __call__(self, *args: Any, **kwds: Any) -> Iterable:
        """Return the dataset in dataloader style."""
        ds = self.dataset_placeholder.get_dataset()
        batch_size = self.dataset_placeholder.params.batch_size
        if self.dist:
            per_core_bs, remainder = divmod(batch_size, len(jax.devices()))
            assert remainder == 0
            ds = ds.batch(per_core_bs, drop_remainder=True).batch(
                len(jax.devices()), drop_remainder=True
            )
        else:
            ds = ds.batch(batch_size, drop_remainder=True)

        # ds = map(lambda x: x._numpy(), ds.prefetch(tf.data.AUTOTUNE))
        # data = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
        ds = tfds.as_numpy(ds.prefetch(tf.data.AUTOTUNE))
        ds = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
        return ds
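
A hedged usage sketch, with train_set built as in the TensorflowDataset example above. Note that the loader object is called to obtain the iterable of numpy batches, exactly as train_epoch does:

from modules.utils import DataLoader

train_loader = DataLoader(train_set, distributed=False)
print(len(train_loader))  # number of full batches per epoch

for batch in train_loader():  # each batch is a numpy array of shape (batch_size, size, size, 3)
    first_batch = batch
    break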

__call__(*args, **kwds)

Return the dataset in dataloader style.

Source code in modules/utils.py
def __call__(self, *args: Any, **kwds: Any) -> Iterable:
    """Return the dataset in dataloader style."""
    ds = self.dataset_placeholder.get_dataset()
    batch_size = self.dataset_placeholder.params.batch_size
    if self.dist:
        per_core_bs, remainder = divmod(batch_size, len(jax.devices()))
        assert remainder == 0
        ds = ds.batch(per_core_bs, drop_remainder=True).batch(
            len(jax.devices()), drop_remainder=True
        )
    else:
        ds = ds.batch(batch_size, drop_remainder=True)

    # ds = map(lambda x: x._numpy(), ds.prefetch(tf.data.AUTOTUNE))
    # data = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
    ds = tfds.as_numpy(ds.prefetch(tf.data.AUTOTUNE))
    ds = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
    return ds

__init__(dataset, distributed)

Create a data loader.

Parameters:

Name Type Description Default
dataset BaseDataset

The dataset to load.

required
distributed bool

If the data is distributed.

required
Source code in modules/utils.py
def __init__(self, dataset: BaseDataset, distributed: bool) -> None:
    """Create a data loader.
    Args:
        dataset (BaseDataset): The dataset to load.
        distributed (bool): If the data is distributed.
    """
    self.dataset_placeholder = dataset
    self.dist = distributed

__len__()

Return the length of the dataset.

Source code in modules/utils.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset_placeholder) // self.dataset_placeholder.params.batch_size