API Reference

Config

DataConfig dataclass

Data configuration class.

Parameters:

train_params (DataParams): training data parameters. Check DataParams for more details. Required.
test_params (DataParams): testing data parameters. Check DataParams for more details. Required.
dataset_name (str): name of the dataset to use. Currently only supports "voc". Default: ''.
dataset_root (str): root directory of the dataset. Default: ''.
transform (dict, optional): transform to apply to the dataset; None applies no transform. The transform dict comes from the albumentations library (check albumentations for more details) and is loaded with the albumentations.from_dict method. Default: None.
size (int): width and height of the input image. Default: 224.
Source code in modules/config.py
@dataclass
class DataConfig:
    """Data configuration class.
    Arguments:
        train_params (DataParams): training data parameters. Check DataParams for more details.
        test_params (DataParams): testing data parameters. Check DataParams for more details.
        dataset_name (str): name of the dataset to use. Currently only supports "voc".
        dataset_root (str): root directory of the dataset.
        transform (optional, dict): transform to apply to the dataset. Default None for no transform.
            The transform dict comes from the albumentations library. Check albumentations for more details.
            Load the transform with the albumentations.from_dict method.
        size (int): size of image width and height.

    """

    train_params: DataParams
    test_params: DataParams
    dataset_name: str = ""
    dataset_root: str = ""
    transform: Optional[Dict[str, Any]] = None
    size: int = 224

    def __post_init__(self):
        # set train_params and test_params
        self.train_params = DataParams(**self.train_params)  # type: ignore
        self.test_params = DataParams(**self.test_params)  # type: ignore
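
A DataConfig can be built from plain dicts; __post_init__ converts the nested params into DataParams. A minimal sketch, assuming DataConfig is importable from modules.config and the transform is serialized with albumentations' to_dict (the dataset_root path is hypothetical):

import albumentations as A
from modules.config import DataConfig

# serialize a transform the way DataConfig expects (albumentations dict format)
transform_dict = A.to_dict(A.Compose([A.HorizontalFlip(p=0.5)]))

config = DataConfig(
    train_params={"batch_size": 8, "shuffle": True},   # converted to DataParams
    test_params={"batch_size": 8, "shuffle": False},
    dataset_name="voc",
    dataset_root="/data/voc",
    transform=transform_dict,
    size=224,
)
# the transform can later be restored with A.from_dict(config.transform)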

DataParams dataclass

Train and test data parameters.

Parameters:

batch_size (int): batch size for training. Required.
shuffle (bool): whether to shuffle the dataset. Required.
Source code in modules/config.py
@dataclass
class DataParams:
    """Train and test data parameters.
    Arguments:
        batch_size (int): batch size for training.
        shuffle (bool): whether to shuffle the dataset.
    """

    batch_size: int
    shuffle: bool

DiscConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a Discriminator model. Dataclass for storing is based on PretrainedConfig from transformers package.

Parameters:

input_last_dim (int): last dimension of the input sample in Discriminator. Default: 3.
output_last_dim (int): last dimension of the output sample in Discriminator. Default: 1.
resolution (int): resolution of the input image (256x256). Default: 256.
ndf (int): number of filters in the first layer of Discriminator. Default: 64.
n_layers (int): number of layers in Discriminator. Default: 3.
Source code in modules/config.py
class DiscConfig(PretrainedConfig):
    """Configuration class to store the configuration of a Discriminator model.
    Dataclass for storing is based on `PretrainedConfig` from `transformers` package.

    Args:
        input_last_dim (int): last dimension of the input sample in Discriminator.
        output_last_dim (int): last dimension of the output sample in Discriminator.
        resolution (int): resolution of the input image (256x256).
        ndf (int): number of filters in the first layer of Discriminator.
        n_layers (int): number of layers in Discriminator.
    """

    def __init__(
        self,
        input_last_dim: int = 3,
        output_last_dim: int = 1,
        resolution: int = 256,
        ndf: int = 64,
        n_layers: int = 3,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.input_last_dim = input_last_dim
        self.output_last_dim = output_last_dim
        self.resolution = resolution
        self.ndf = ndf
        self.n_layers = n_layers
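
Since DiscConfig extends PretrainedConfig, the usual transformers serialization round-trip applies. A short sketch (the ./disc_config directory is hypothetical):

from modules.config import DiscConfig

config = DiscConfig(resolution=128, n_layers=4)
config.save_pretrained("./disc_config")            # writes config.json
restored = DiscConfig.from_pretrained("./disc_config")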

LoadConfig dataclass

Load configuration class to store the configuration of a train model and data. Main configuration class to be used for training.

Parameters:

train (TrainConfig): training configuration. Required.
data (DataConfig): data configuration. Required.
Source code in modules/config.py
@dataclass
class LoadConfig:
    """Load configuration class to store the configuration of a train model and data.
    Main configuration class to be used for training.

    Arguments:
        train (TrainConfig): training configuration.
        data (DataConfig): data configuration.
    """

    train: TrainConfig
    data: DataConfig

    def __post_init__(self):
        self.train = (
            TrainConfig(**self.train) if self.train is not None else TrainConfig()  # type: ignore
        )
        self.data = (
            DataConfig(**self.data) if self.data is not None else DataConfig()  # type: ignore
        )

        # set resolution
        self.train.disc_hparams.resolution = self.data.size
        self.train.model_hparams.resolution = self.data.size
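
LoadConfig is typically constructed from a nested mapping such as a parsed YAML file; __post_init__ builds the TrainConfig and DataConfig sub-configs and copies data.size into both resolution fields. A minimal sketch, assuming the file follows the config_test.yaml layout:

import yaml
from modules.config import LoadConfig

with open("config_test.yaml") as f:
    raw = yaml.safe_load(f)           # {"train": {...}, "data": {...}}

config = LoadConfig(**raw)
assert config.train.model_hparams.resolution == config.data.size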

TrainConfig dataclass

Configuration class to store the configuration of a train model.

Parameters:

model_name (str): name of the model to train. Used for saving and logging. Required.
model_hparams (VQGANConfig): model hyperparameters. Check VQGANConfig for more details. Required.
disc_hparams (DiscConfig): discriminator hyperparameters. Check DiscConfig for more details. Required.
save_dir (str): directory to save the model. Required.
log_dir (str): directory to save the logs for tensorboard. Required.
check_val_every_n_epoch (int): run validation every N epochs. Required.
log_img_every_n_epoch (int): log images every N epochs. Required.
input_shape (Tuple[int, int, int]): shape of the input image (H, W, C). Required.
codebook_weight (float): weight for the codebook loss (Quantizer part). Required.
monitor (str): metric to monitor for saving the best model. Required.
recon_loss (str): reconstruction loss to use. Can be one of l1, l2, comb, mape. Required.
disc_loss (str): discriminator loss to use. Can be vanilla or hinge. Required.
disc_weight (float): weight for the discriminator loss. Required.
num_epochs (int): number of epochs to train. Required.
dtype (str): dtype to use for training. Supported: float64, float32, float16, bfloat16. Required.
distributed (bool): whether to use distributed training. Required.
seed (int): seed for random number generation. Required.
optimizer (dict): optimizer to use for training. The structure needs to describe an optax optimizer (check optax for more details), with a '__target__' parameter specifying the optax optimizer and a 'kwargs' parameter passed to the optimizer. Check config_test.yaml for an example. Required.
optimizer_disc (dict): optimizer to use for discriminator training. Similar to optimizer. Required.
disc_start (int): number of epochs to wait before starting to use the discriminator. Required.
temp_scheduler (optional): temperature scheduler to use for training. Similar to optimizer, but uses an optax scheduler with a '__target__' parameter specifying the optax scheduler. If None, no scheduler is used. Check config_test.yaml for an example. Required.
Source code in modules/config.py
@dataclass
class TrainConfig:
    """Configuration class to store the configuration of a train model.

    Arguments:
        model_name (str): name of the model to train. Used for saving and logging.
        model_hparams (VQGANConfig): model hyperparameters. Check VQGANConfig for more details.
        disc_hparams (DiscConfig): discriminator hyperparameters.
            Check DiscConfig for more details.
        save_dir (str): directory to save the model.
        log_dir (str): directory to save the logs for tensorboard.
        check_val_every_n_epoch (int): number of epochs to run validation.
        log_img_every_n_epoch (int): number of epochs to log images.
        input_shape (Tuple[int, int, int]): shape of the input image (H, W, C).
        codebook_weight (float): weight for the codebook loss (Quantizer part).
        monitor (str): metric to monitor for saving best model.
        recon_loss (str): reconstruction loss to use. Can be one of `l1`, `l2`, `comb`, `mape`.
        disc_loss (str): discriminator loss to use. Can be `vanilla` or `hinge`.
        disc_weight (float): weight for the discriminator loss.
        num_epochs (int): number of epochs to train.
        dtype (str): dtype to use for training.
            Supported: `float64`, `float32`, `float16`, `bfloat16`.
        distributed (bool): whether to use distributed training.
        seed (int): seed for random number generation.
        optimizer (dict): optimizer to use for training. The structure needs to describe
            an optax optimizer (check optax for more details) with a '__target__' parameter
            specifying the optax optimizer, and a 'kwargs' parameter passed to the optimizer.
            Check config_test.yaml for an example.
        optimizer_disc (dict): optimizer to use for discriminator training. Similar to optimizer.
        disc_start (int): number of epochs to wait before starting to use the discriminator.
        temp_scheduler (optional): temperature scheduler to use for training. Similar to optimizer,
            but uses an optax scheduler with a '__target__' parameter specifying the optax scheduler.
            If None, no scheduler is used. Check config_test.yaml for an example.
    """

    model_name: str
    model_hparams: VQGANConfig
    disc_hparams: DiscConfig
    save_dir: str
    log_dir: str
    check_val_every_n_epoch: int
    log_img_every_n_epoch: int
    input_shape: Tuple[int, ...]
    codebook_weight: float
    monitor: str
    recon_loss: str
    disc_loss: str
    disc_weight: float
    num_epochs: int
    dtype: jnp.dtype
    distributed: bool
    seed: int
    optimizer: optax.GradientTransformation
    optimizer_disc: optax.GradientTransformation
    disc_start: int
    temp_scheduler: Optional[Callable]

    def __post_init__(self):
        # load model hparams
        self.model_hparams = (
            VQGANConfig(**self.model_hparams) if self.model_hparams is not None else VQGANConfig()
        )
        if not isinstance(self.model_hparams, VQGANConfig):
            raise TypeError("model_hparams could not create VQGANConfig")
        # load disc hparams
        self.disc_hparams = (
            DiscConfig(**self.disc_hparams) if self.disc_hparams is not None else DiscConfig()
        )
        if not isinstance(self.disc_hparams, DiscConfig):
            raise TypeError("disc_hparams could not create DiscConfig")
        # convert shape list to a tuple
        self.input_shape = tuple(self.input_shape)
        if len(self.input_shape) != 3:
            raise ValueError(f"input_shape: {self.input_shape} should be of length 3")
        # set dtype
        if self.dtype == "float64":
            self.dtype = jnp.float64
        elif self.dtype == "float32":
            self.dtype = jnp.float32
        elif self.dtype == "float16":
            self.dtype = jnp.float16
        elif self.dtype == "bfloat16":
            self.dtype = jnp.bfloat16
        else:
            raise ValueError(
                f"""Invalid dtype {self.dtype}
                             expected one of float64, float32, float16, bfloat16"""
            )
        # instantiate the optimizer
        self.optimizer = instantiate(self.optimizer)
        if not isinstance(self.optimizer, optax.GradientTransformation):
            raise TypeError("optimizer should be optax GradientTransformation dict to instantiate")
        # instantiate the optimizer for discriminator
        self.optimizer_disc = instantiate(self.optimizer_disc)
        if not isinstance(self.optimizer_disc, optax.GradientTransformation):
            raise TypeError(
                "optimizer_disc should be optax GradientTransformation dict to instantiate"
            )
        # if a temperature scheduler is given, instantiate it
        if self.temp_scheduler is not None:
            self.temp_scheduler: Callable = instantiate(self.temp_scheduler)
            if not hasattr(self.temp_scheduler, "__call__"):
                raise TypeError("temp_scheduler should be a callable or None")
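
The optimizer, optimizer_disc and temp_scheduler fields arrive as dicts and are resolved by instantiate in __post_init__. A hedged sketch of the expected shape, following the '__target__'/'kwargs' convention described above (config_test.yaml in the repository is the authoritative reference):

optimizer_cfg = {
    "__target__": "optax.adamw",      # dotted path to an optax optimizer
    "kwargs": {"learning_rate": 1e-4, "weight_decay": 1e-2},
}

# after TrainConfig.__post_init__, config.optimizer is the instantiated
# optax.GradientTransformation, ready for config.optimizer.init(params)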

VQGANConfig

Bases: PretrainedConfig

Configuration class to store the configuration of a VQGAN model. Dataclass for storing is based on PretrainedConfig from transformers package.

Parameters:

ch (int): number of channels. Default: 128.
out_ch (int): number of output channels (RGB). Default: 3.
in_channels (int): number of input channels (RGB). Default: 3.
num_res_blocks (int): number of residual blocks. Default: 2.
resolution (int): resolution of the input image (256x256). Default: 256.
z_channels (int): number of channels in the latent space. Default: 256.
ch_mult (Tuple[int]): channel multiplier for each layer. Default: (1, 1, 2, 2, 4).
attn_resolutions (Tuple[int]): resolutions at which to apply attention. Default: (16,).
n_embed (int): number of embeddings, unique codes in the latent space. Default: 1024.
embed_dim (int): dimension of embedding from Encoder. Default: 256.
dropout (float): dropout rate. Default: 0.0.
double_z (bool): whether to double the number of channels in the latent space. Default: False.
resamp_with_conv (bool): whether to use convolutions for upsampling. Default: True.
use_gumbel (bool): whether to use gumbel softmax for quantization. Default: False.
gumb_temp (float): temperature for the gumbel softmax. Default: 1.0.
beta (float): commitment loss weight for the quantizer. Default: 0.25.
kl_weight (float): weight of the KL term when using gumbel softmax quantization. Default: 5e-4.
act_name (str): activation function name to use. Default: 'swish'.
give_pre_end (bool): whether to return the pre-end layer output of the decoder. Default: False.
kwargs (Any): keyword arguments passed along to the super class. Default: {}.
Source code in modules/config.py
class VQGANConfig(PretrainedConfig):
    """Configuration class to store the configuration of a VQGAN model.
    Dataclass for storing is based on `PretrainedConfig` from `transformers` package.

    Args:
        ch (int): number of channels.
        out_ch (int): number of output channels (RGB).
        in_channels (int): number of input channels (RGB).
        num_res_blocks (int): number of residual blocks.
        resolution (int): resolution of the input image (256x256).
        z_channels (int): number of channels in the latent space.
        ch_mult (Tuple[int]): channel multiplier for each layer.
        attn_resolutions (Tuple[int]): resolutions at which to apply attention.
        n_embed (int): number of embeddings, unique codes in the latent space.
        embed_dim (int): dimension of embedding from Encoder.
        dropout (float): dropout rate.
        double_z (bool): whether to double the number of channels in the latent space.
        resamp_with_conv (bool): whether to use convolutions for upsampling.
        use_gumbel (bool): whether to use gumbel softmax for quantization.
        gumb_temp (float): temperature for the gumbel softmax.
        beta (float): commitment loss weight for the quantizer.
        kl_weight (float): weight of the KL term when using gumbel softmax quantization.
        act_name (str): activation function name to use.
        give_pre_end (bool): whether to give the pre-end layer for the decoder.
        kwargs: keyword arguments passed along to the super class.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = tuple([1, 1, 2, 2, 4]),
        attn_resolutions: Tuple[int] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        use_gumbel: bool = False,
        gumb_temp: float = 1.0,
        beta: float = 0.25,
        kl_weight: float = 5e-4,
        act_name: str = "swish",
        give_pre_end: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.use_gumbel = use_gumbel
        self.gumb_temp = gumb_temp
        self.beta = beta
        self.kl_weight = kl_weight
        self.act_name = act_name
        self.give_pre_end = give_pre_end
        self.num_resolutions = len(ch_mult)
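
Overriding a few fields is enough to change the architecture; num_resolutions is derived from ch_mult:

from modules.config import VQGANConfig

config = VQGANConfig(resolution=128, ch_mult=(1, 2, 4), use_gumbel=True)
print(config.num_resolutions)         # 3, one entry per element of ch_mult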

Losses

combo_loss(predictions, targets)

Compute a combined l1 and l2 loss: l1 if l2 < 0.5 else l2.

Parameters:

predictions (jnp.ndarray): Predictions from the model. Required.
targets (jnp.ndarray): Targets for the model. Required.

Returns:

jnp.ndarray: Reconstruction loss.

Source code in modules/losses.py
def combo_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Compute combined l1 and l2 loss : l1 if l2 < 0.5 else l2.
    Args:
        predictions (jnp.ndarray): Predictions from the model.
        targets (jnp.ndarray): Targets for the model.
    Returns:
        Reconstruction loss.
    """
    l1 = jnp.abs(predictions - targets)  # absolute error
    l2 = (predictions - targets) ** 2  # squared error
    return jnp.where(l2 < 0.5, l1, l2)
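
All reconstruction losses in this module return elementwise arrays, so a scalar training loss still needs a reduction. A small sketch:

import jax.numpy as jnp
from modules.losses import combo_loss

preds = jnp.array([0.2, 1.5])
targets = jnp.array([0.0, 0.0])
per_element = combo_loss(preds, targets)   # [0.2, 2.25]: l1 below the 0.5 cut, l2 above
loss = per_element.mean()                  # scalar loss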

disc_loss_hinge(real, fake)

Compute the discriminator loss for a hinge GAN. Real and fake logits contribute equally to the loss.

Parameters:

real (jnp.ndarray): Real images, received from dataset. Required.
fake (jnp.ndarray): Fake images, produced by generator. Required.

Returns:

jnp.ndarray: Discriminator loss.

Source code in modules/losses.py
def disc_loss_hinge(real: jnp.ndarray, fake: jnp.ndarray) -> jnp.ndarray:
    """Compute discriminator loss for hinge GAN.
    Real and fake logits influence the loss the same.
    Args:
        real: Real images, received from dataset.
        fake: Fake images, produced by generator.
    Returns:
        Discriminator loss.
    """
    real_loss = jnp.mean(jnp.maximum(1.0 - real, 0.0))
    loss_fake = jnp.mean(jnp.maximum(1.0 + fake, 0.0))
    return 0.5 * (real_loss + loss_fake)

disc_loss_vanilla(real, fake)

Compute the discriminator loss for a vanilla GAN. Wrong fake logits impact the loss more than bad real logits.

Parameters:

real (jnp.ndarray): Real images, received from dataset. Required.
fake (jnp.ndarray): Fake images, produced by generator. Required.

Returns:

jnp.ndarray: Discriminator loss.

Source code in modules/losses.py
def disc_loss_vanilla(real: jnp.ndarray, fake: jnp.ndarray) -> jnp.ndarray:
    """Compute discriminator loss for vanilla GAN.
    Wrong fake logits impact more the loss than the bad real logits.
    Args:
        real: Real images, received from dataset.
        fake: Fake images, produced by generator.
    Returns:
        Discriminator loss.
    """
    real_loss = jnp.mean(jax.nn.softplus(-real))
    generated_loss = jnp.mean(jax.nn.softplus(fake))
    return 0.5 * (real_loss + generated_loss)
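
Both discriminator losses take raw logits rather than probabilities. A quick comparison on the same inputs:

import jax.numpy as jnp
from modules.losses import disc_loss_hinge, disc_loss_vanilla

real_logits = jnp.array([2.0, 1.0])    # discriminator scores for real images
fake_logits = jnp.array([-1.5, 0.5])   # discriminator scores for generated images

hinge = disc_loss_hinge(real_logits, fake_logits)
vanilla = disc_loss_vanilla(real_logits, fake_logits)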

l1_loss(predictions, targets)

Compute L1 loss.

Parameters:

predictions (jnp.ndarray): Predictions from the model. Required.
targets (jnp.ndarray): Targets for the model. Required.

Returns:

jnp.ndarray: Reconstruction loss.

Source code in modules/losses.py
def l1_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Compute L1 loss.
    Args:
        predictions (jnp.ndarray): Predictions from the model.
        targets (jnp.ndarray): Targets for the model.
    Returns:
        Reconstruction loss.
    """
    return jnp.abs(predictions - targets)

l2_loss(predictions, targets)

Compute L2 loss.

Parameters:

predictions (jnp.ndarray): Predictions from the model. Required.
targets (jnp.ndarray): Targets for the model. Required.

Returns:

jnp.ndarray: Reconstruction loss.

Source code in modules/losses.py
def l2_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Compute L2 loss.
    Args:
        predictions (jnp.ndarray): Predictions from the model.
        targets (jnp.ndarray): Targets for the model.
    Returns:
        Reconstruction loss.
    """
    return (predictions - targets) ** 2

mape_loss(predictions, targets)

Compute mean absolute percentage error loss.

Parameters:

predictions (jnp.ndarray): Predictions from the model. Required.
targets (jnp.ndarray): Targets for the model. Required.

Returns:

jnp.ndarray: Reconstruction loss.

Source code in modules/losses.py
def mape_loss(predictions: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Compute mean absolute percentage error loss.
    Args:
        predictions (jnp.ndarray): Predictions from the model.
        targets (jnp.ndarray): Targets for the model.
    Returns:
        Reconstruction loss.
    """
    return jnp.abs((targets - predictions) / targets)

Models

AttnBlock

Bases: nn.Module

Attention block.

Attributes:

in_channels (int): number of input channels.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class AttnBlock(nn.Module):
    """Attention block.

    Attributes:
        in_channels (int): number of input channels.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass of the block.
        Args:
            x (jnp.ndarray): input tensor.
        Returns:
            output tensor with the same shape as the input.
        """
        h_ = x
        h_ = nn.GroupNorm(num_groups=32, epsilon=1e-6)(h_)
        # get query, key, value
        q = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )(h_)
        k = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )(h_)
        v = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )(h_)

        # compute attention
        q = rearrange(q, "B H W C -> B (H W) C")
        k = rearrange(k, "B H W C -> B (H W) C")
        w_ = jnp.einsum("bqc,bkc->bqk", q, k)
        w_ *= int(x.shape[-1]) ** (-0.5)
        w_ = nn.softmax(w_, axis=2)

        # attend to values
        v = rearrange(v, "B H W C -> B (H W) C")
        h_ = jnp.einsum("bkc,bqk->bqc", v, w_)
        h_ = rearrange(h_, "B (H W) C -> B H W C", H=x.shape[1], W=x.shape[2])

        h_ = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )(h_)
        h_ += x
        return h_

__call__(x)

Forward pass of the block.

Parameters:

x (jnp.ndarray): input tensor. Required.

Returns:

jnp.ndarray: output tensor with the same shape as the input.

Source code in modules/models.py
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Forward pass of the block.
    Args:
        x (jnp.ndarray): input tensor.
    Returns:
        output tensor with the same shape as the input.
    """
    h_ = x
    h_ = nn.GroupNorm(num_groups=32, epsilon=1e-6)(h_)
    # get query, key, value
    q = nn.Conv(
        self.in_channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )(h_)
    k = nn.Conv(
        self.in_channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )(h_)
    v = nn.Conv(
        self.in_channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )(h_)

    # compute attention
    q = rearrange(q, "B H W C -> B (H W) C")
    k = rearrange(k, "B H W C -> B (H W) C")
    w_ = jnp.einsum("bqc,bkc->bqk", q, k)
    w_ *= int(x.shape[-1]) ** (-0.5)
    w_ = nn.softmax(w_, axis=2)

    # attend to values
    v = rearrange(v, "B H W C -> B (H W) C")
    h_ = jnp.einsum("bkc,bqk->bqc", v, w_)
    h_ = rearrange(h_, "B (H W) C -> B H W C", H=x.shape[1], W=x.shape[2])

    h_ = nn.Conv(
        self.in_channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )(h_)
    h_ += x
    return h_
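
Like the other modules in this file, AttnBlock follows the standard flax init/apply pattern. A minimal sketch; the channel count must be divisible by 32 because of the GroupNorm:

import jax
import jax.numpy as jnp
from modules.models import AttnBlock

block = AttnBlock(in_channels=64)
x = jnp.ones((1, 16, 16, 64))                  # NHWC input
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)                     # same shape as x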

Decoder

Bases: nn.Module

Decoder of VQ-GAN to map an input batch of latent codes to images. Dimension transformations (original configuration):

32x32x256 --Conv2d--> 32x32x512
--MidBlock--> 32x32x512
--UpsamplingBlock--> 64x64x256
--UpsamplingBlock--> 128x128x128
--UpsamplingBlock--> 256x256x64
--UpsamplingBlock--> 256x256x32
--GroupNorm--> --nonlinear--> --Conv2d--> 256x256x3

Attributes:

config (VQGANConfig): the config of the model.
act_fn (Callable): activation function.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class Decoder(nn.Module):
    """
    Decoder of VQ-GAN to map input batch of latent space to images.
    Dimension Transformations originally:
    32x32x256 --Conv2d--> 32x32x512
    --MidBlock--> 32x32x512
    for loop:
        --UpsamplingBlock--> 64x64x256
        --UpsamplingBlock--> 128x128x128
        --UpsamplingBlock--> 256x256x64
        --UpsamplingBlock--> 256x256x32
    --GroupNorm-->
    --nonlinear-->
    --Conv2d-> 256x256x3

    Attributes:
        config (VQGANConfig): the config of the model.
        act_fn (Callable): activation function.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    config: VQGANConfig
    act_fn: Callable = nn.gelu
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, z: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        # time embedding
        temb_ch: int = 0
        temb: Optional[jnp.ndarray] = None

        # compute in_ch_mult and block_in at lowest res
        block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]

        # z_channel to block_in
        x = nn.Conv(
            block_in,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(z)

        # middle
        x = MidBlock(
            block_in,
            act_fn=self.act_fn,
            temb_channels=temb_ch,
            dropout_prob=self.config.dropout,
            dtype=self.dtype,
        )(x, temb, deterministic=deterministic)

        # upsampling
        # compute curr_res at lowest res
        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
        for i in reversed(range(self.config.num_resolutions)):
            x = UpsamplingBlock(self.config, curr_res, block_idx=i, dtype=self.dtype)(
                x, temb, deterministic=deterministic
            )

            # update resolution if not end
            curr_res = curr_res * 2 if i != self.config.num_resolutions - 1 else curr_res

        # end
        if self.config.give_pre_end:
            return x

        # final GroupNorm, nonlinearity and output convolution
        x = nn.GroupNorm(num_groups=32, dtype=self.dtype)(x)
        x = self.act_fn(x)
        x = nn.Conv(
            self.config.out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(x)

        return x
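
A sketch of decoding latents back to images with the default config (five entries in ch_mult, hence four upsampling steps from 16x16 to 256x256):

import jax
import jax.numpy as jnp
from modules.config import VQGANConfig
from modules.models import Decoder

config = VQGANConfig()
decoder = Decoder(config)
z = jnp.ones((1, 16, 16, config.z_channels))
variables = decoder.init(jax.random.PRNGKey(0), z, deterministic=True)
images = decoder.apply(variables, z, deterministic=True)   # (1, 256, 256, 3)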

Downsample

Bases: nn.Module

Downsample the input by a factor of 2.

Attributes:

in_channels (int): Number of input channels.
use_conv (bool): Whether to downsample with a strided convolution instead of average pooling.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class Downsample(nn.Module):
    """Downsample the input by a factor of 2.

    Attributes:
        in_channels (int): Number of input channels.
        use_conv (bool): Whether to downsample with a strided convolution instead of average pooling.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    in_channels: int
    use_conv: bool
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        if self.use_conv:
            pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
            x = jnp.pad(x, pad, mode="constant", constant_values=0)
            x = nn.Conv(
                features=self.in_channels,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding="VALID",
                dtype=self.dtype,
            )(x)
        else:
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        return x
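
Downsample halves the spatial dimensions either way; use_conv only chooses between the strided convolution and average pooling. A minimal sketch:

import jax
import jax.numpy as jnp
from modules.models import Downsample

down = Downsample(in_channels=32, use_conv=True)
x = jnp.ones((1, 8, 8, 32))
params = down.init(jax.random.PRNGKey(0), x)
y = down.apply(params, x)                      # (1, 4, 4, 32)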

DownsamplingBlock

Bases: nn.Module

Downsampling block for Encoder.

Attributes:

config (VQGANConfig): the config of the model.
curr_res (int): current resolution.
block_idx (int): current block index.
act_fn (Callable): activation function.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class DownsamplingBlock(nn.Module):
    """Downsampling block for Encoder.

    Attributes:
        config (VQGANConfig): the config of the model.
        curr_res (int): current resolution.
        block_idx (int): current block index.
        act_fn (Callable): activation function.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    config: VQGANConfig
    curr_res: int
    block_idx: int
    act_fn: Callable = nn.gelu
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # get input and output number of channels based on config and current block index.
        in_ch_mult: Tuple[int, ...] = (1,) + tuple(self.config.ch_mult)
        block_in: int = self.config.ch * in_ch_mult[self.block_idx]
        block_out: int = self.config.ch * self.config.ch_mult[self.block_idx]

        # temporal embedding channels
        self.temb_ch: int = 0

        # build blocks
        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks):
            assert block_in % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
            res_blocks.append(
                ResNetBlock(
                    block_in,
                    block_out,
                    act_fn=self.act_fn,
                    temb_channels=self.temb_ch,
                    dropout_prob=self.config.dropout,
                    dtype=self.dtype,
                )
            )
            # after channel downsample rest are identity blocks.
            block_in = block_out
            # check if we need to add attention block based on configs
            if self.curr_res in self.config.attn_resolutions:
                assert block_in % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.blocks = res_blocks
        self.attns = attn_blocks

        # downsample if not last block
        self.downsample = None
        if self.block_idx != self.config.num_resolutions - 1:
            self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)

    def __call__(
        self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
    ) -> jnp.ndarray:
        """Forward pass of the block.
        Args:
            x (jnp.ndarray): input tensor.
            temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
            deterministic (bool, optional): deterministic flag. Defaults to True.
        """
        assert temb is None, "DownsamplingBlock don't use temporal embedding"
        for i, res_block in enumerate(self.blocks):
            x = res_block(x, temb, deterministic=deterministic)

            if self.attns:
                x = self.attns[i](x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x

__call__(x, temb=None, deterministic=True)

Forward pass of the block.

Parameters:

x (jnp.ndarray): input tensor. Required.
temb (Optional[jnp.ndarray]): temporal embedding. Defaults to None.
deterministic (bool): deterministic flag. Defaults to True.
Source code in modules/models.py
def __call__(
    self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
) -> jnp.ndarray:
    """Forward pass of the block.
    Args:
        x (jnp.ndarray): input tensor.
        temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
        deterministic (bool, optional): deterministic flag. Defaults to True.
    """
    assert temb is None, "DownsamplingBlock don't use temporal embedding"
    for i, res_block in enumerate(self.blocks):
        x = res_block(x, temb, deterministic=deterministic)

        if self.attns:
            x = self.attns[i](x)

    if self.downsample is not None:
        x = self.downsample(x)

    return x

Encoder

Bases: nn.Module

Encoder of VQ-GAN to map an input batch of images to the latent space. Dimension transformations (original configuration):

256x256x3 --Conv2d--> 256x256x32
--DownsamplingBlock--> 128x128x64
--DownsamplingBlock--> 64x64x128
--DownsamplingBlock--> 32x32x256
--DownsamplingBlock--> 32x32x512
--MidBlock--> 32x32x512
--GroupNorm--> --nonlinear--> --Conv2d--> 32x32x256

Attributes:

config (VQGANConfig): the config of the model.
act_fn (Callable): activation function.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class Encoder(nn.Module):
    """
    Encoder of VQ-GAN to map input batch of images to latent space.
    Dimension Transformations originally:
    256x256x3 --Conv2d--> 256x256x32
    for loop:
        --DownsamplingBlock--> 128x128x64
        --DownsamplingBlock--> 64x64x128
        --DownsamplingBlock--> 32x32x256
        --DownsamplingBlock--> 32x32x512

    --MidBlock--> 32x32x512
    --GroupNorm-->
    --nonlinear-->
    --Conv2d-> 32x32x256

    Attributes:
        config (VQGANConfig): the config of the model.
        act_fn (Callable): activation function.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    config: VQGANConfig
    act_fn: Callable = nn.gelu
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        # time embedding
        temb_ch: int = 0
        temb: Optional[jnp.ndarray] = None

        # downsampling
        x = nn.Conv(
            self.config.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(x)
        curr_res = self.config.resolution
        for i in range(self.config.num_resolutions):
            x = DownsamplingBlock(self.config, curr_res, block_idx=i, dtype=self.dtype)(
                x, temb, deterministic=deterministic
            )

            # update resolution if not bottleneck
            curr_res = curr_res // 2 if i != self.config.num_resolutions - 1 else curr_res
        # middle
        mid_channels = self.config.ch * self.config.ch_mult[-1]
        x = MidBlock(
            mid_channels,
            act_fn=self.act_fn,
            temb_channels=temb_ch,
            dropout_prob=self.config.dropout,
            dtype=self.dtype,
        )(x, temb, deterministic=deterministic)

        # final GroupNorm, nonlinearity and output convolution
        x = nn.GroupNorm(num_groups=32, dtype=self.dtype)(x)
        x = self.act_fn(x)
        x = nn.Conv(
            2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(x)

        return x
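
Encoding mirrors the Decoder sketch above: with the default config, a batch of 256x256 RGB images is mapped to 16x16 latents with z_channels channels:

import jax
import jax.numpy as jnp
from modules.config import VQGANConfig
from modules.models import Encoder

config = VQGANConfig()
encoder = Encoder(config)
x = jnp.ones((1, config.resolution, config.resolution, 3))
variables = encoder.init(jax.random.PRNGKey(0), x, deterministic=True)
z = encoder.apply(variables, x, deterministic=True)   # (1, 16, 16, config.z_channels)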

MidBlock

Bases: nn.Module

Mid block for Encoder and Decoder.

Attributes:

in_channels (int): number of input channels.
act_fn (Callable): activation function.
temb_channels (int): number of channels for temporal embedding.
dropout_prob (float): dropout probability.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class MidBlock(nn.Module):
    """Mid block for Encoder and Decoder.

    Attributes:
        in_channels (int): number of input channels.
        act_fn (Callable): activation function.
        temb_channels (int): number of channels for temporal embedding.
        dropout_prob (float): dropout probability.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    in_channels: int
    act_fn: Callable
    temb_channels: int
    dropout_prob: float
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
    ) -> jnp.ndarray:
        """BxWxHxC --ResNet--> BxWxHxC --Attn--> BxWxHxC --ResNet--> BxWxHxC

        Args:
            x (jnp.ndarray): input tensor.
            temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
            deterministic (bool, optional): deterministic flag. Defaults to True.
        """
        assert self.in_channels % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
        x = ResNetBlock(
            self.in_channels,
            self.in_channels,
            act_fn=self.act_fn,
            temb_channels=self.temb_channels,
            dropout_prob=self.dropout_prob,
            dtype=self.dtype,
        )(x, temb, deterministic=deterministic)
        x = AttnBlock(self.in_channels, dtype=self.dtype)(x)
        x = ResNetBlock(
            self.in_channels,
            self.in_channels,
            act_fn=self.act_fn,
            temb_channels=self.temb_channels,
            dropout_prob=self.dropout_prob,
            dtype=self.dtype,
        )(x, temb, deterministic=deterministic)

        return x

__call__(x, temb=None, deterministic=True)

BxWxHxC --ResNet--> BxWxHxC --Attn--> BxWxHxC --ResNet--> BxWxHxC

Parameters:

x (jnp.ndarray): input tensor. Required.
temb (Optional[jnp.ndarray]): temporal embedding. Defaults to None.
deterministic (bool): deterministic flag. Defaults to True.
Source code in modules/models.py
@nn.compact
def __call__(
    self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
) -> jnp.ndarray:
    """BxWxHxC --ResNet--> BxWxHxC --Attn--> BxWxHxC --ResNet--> BxWxHxC

    Args:
        x (jnp.ndarray): input tensor.
        temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
        deterministic (bool, optional): deterministic flag. Defaults to True.
    """
    assert self.in_channels % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
    x = ResNetBlock(
        self.in_channels,
        self.in_channels,
        act_fn=self.act_fn,
        temb_channels=self.temb_channels,
        dropout_prob=self.dropout_prob,
        dtype=self.dtype,
    )(x, temb, deterministic=deterministic)
    x = AttnBlock(self.in_channels, dtype=self.dtype)(x)
    x = ResNetBlock(
        self.in_channels,
        self.in_channels,
        act_fn=self.act_fn,
        temb_channels=self.temb_channels,
        dropout_prob=self.dropout_prob,
        dtype=self.dtype,
    )(x, temb, deterministic=deterministic)

    return x

ResNetBlock

Bases: nn.Module

ResNet block with optional bottleneck.

Attributes:

in_channels (int): number of input channels.
out_channels (Optional[int]): number of output channels. If None, the output channels will be the same as the input channels.
act_fn (Callable): activation function.
use_conv_shortcut (bool): whether to use a convolutional shortcut.
temb_channels (int): number of channels in the temporal embedding.
dropout_prob (float): dropout probability.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class ResNetBlock(nn.Module):
    """ResNet block with optional bottleneck.

    Attributes:
        in_channels (int): number of input channels.
        out_channels (Optional[int]): number of output channels.
            If None, the output channels will be the same as the input channels.
        act_fn (Callable): activation function.
        use_conv_shortcut (bool): whether to use a convolutional shortcut.
        temb_channels (int): number of channels in the temporal embedding.
        dropout_prob (float): dropout probability.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    in_channels: int
    out_channels: Optional[int] = None
    act_fn: Callable = nn.gelu
    use_conv_shortcut: bool = False
    temb_channels: int = 512
    dropout_prob: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # if out_channels is not specified, the residual block keeps the input channel count
        self.out_channels_: int = (
            self.in_channels if self.out_channels is None else self.out_channels
        )

        #  First block
        self.block1 = nn.Sequential(
            [
                nn.GroupNorm(num_groups=32, epsilon=1e-6),
                self.act_fn,
                nn.Conv(
                    self.out_channels_,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    padding=((1, 1), (1, 1)),
                    dtype=self.dtype,
                ),
            ]
        )

        # Project temporal embedding to be added to hidden states
        if self.temb_channels:
            self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)

        # Second block
        self.block2_pre = nn.Sequential([nn.GroupNorm(num_groups=32, epsilon=1e-6), self.act_fn])
        self.block2_drop = nn.Dropout(self.dropout_prob)
        self.bloc2_conv = nn.Conv(
            self.out_channels_,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # if the channels of output x do not match the residual channels, project the residual
        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                # Reduce by conv(3,3) only channels with learned spatial features
                self.conv_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    padding=((1, 1), (1, 1)),
                    dtype=self.dtype,
                )
            else:
                # Reduce by conv(1,1) only channels
                self.nin_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    padding="VALID",
                    dtype=self.dtype,
                )

    def __call__(
        self,
        x: jnp.ndarray,
        temb: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Forward pass of the block.
        Args:
            x (jnp.ndarray): input tensor.
            temb (Optional[int], optional): temporal embedding. Defaults to None.
            deterministic (bool, optional): deterministic flag. Defaults to True.
        Returns:
            output tensor with the out_channels dimension as the last dimension (C).
        """
        residual = x
        x = self.block1(x)
        if temb is not None:
            # transform temporal embedding to match hidden states [BxT] -> [BxC] -> [Bx1x1xC]
            x = x + self.temb_proj(self.act_fn(temb))[:, None, None, :]

        x = self.block2_pre(x)
        x = self.block2_drop(x, deterministic=deterministic)
        x = self.bloc2_conv(x)
        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return x + residual

__call__(x, temb=None, deterministic=True)

Forward pass of the block.

Parameters:

x (jnp.ndarray): input tensor. Required.
temb (Optional[jnp.ndarray]): temporal embedding. Defaults to None.
deterministic (bool): deterministic flag. Defaults to True.

Returns:

jnp.ndarray: output tensor with the out_channels dimension as the last dimension (C).

Source code in modules/models.py
def __call__(
    self,
    x: jnp.ndarray,
    temb: Optional[jnp.ndarray] = None,
    deterministic: bool = True,
) -> jnp.ndarray:
    """Forward pass of the block.
    Args:
        x (jnp.ndarray): input tensor.
        temb (Optional[int], optional): temporal embedding. Defaults to None.
        deterministic (bool, optional): deterministic flag. Defaults to True.
    Returns:
        output tensor with the out_channels dimension as the last dimension (C).
    """
    residual = x
    x = self.block1(x)
    if temb is not None:
        # transform temporal embedding to match hidden states [BxT] -> [BxC] -> [Bx1x1xC]
        x = x + self.temb_proj(self.act_fn(temb))[:, None, None, :]

    x = self.block2_pre(x)
    x = self.block2_drop(x, deterministic=deterministic)
    x = self.bloc2_conv(x)
    if self.in_channels != self.out_channels_:
        if self.use_conv_shortcut:
            residual = self.conv_shortcut(residual)
        else:
            residual = self.nin_shortcut(residual)

    return x + residual
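
When in_channels and out_channels differ, the residual is projected automatically, so the block can also change the channel count. A minimal sketch without a temporal embedding:

import jax
import jax.numpy as jnp
from modules.models import ResNetBlock

block = ResNetBlock(in_channels=64, out_channels=128, temb_channels=0)
x = jnp.ones((1, 8, 8, 64))
variables = block.init(jax.random.PRNGKey(0), x)
y = block.apply(variables, x)                  # (1, 8, 8, 128)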

Upsample

Bases: nn.Module

Upsample the input by a factor of 2.

Attributes:

in_channels (int): Number of input channels.
use_conv (bool): Whether to apply a convolution after the nearest-neighbor upsampling.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class Upsample(nn.Module):
    """Upsample the input by a factor of 2.

    Attributes:
        in_channels (int): Number of input channels.
        use_conv (bool): Whether to apply a convolution after the nearest-neighbor upsampling.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    in_channels: int
    use_conv: bool
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        batch, height, width, channels = x.shape
        x = jax.image.resize(
            x,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        if self.use_conv:
            x = nn.Conv(
                features=self.in_channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )(x)
        return x

UpsamplingBlock

Bases: nn.Module

Upsampling block for Decoder.

Attributes:

config (VQGANConfig): the config of the model.
curr_res (int): current resolution.
block_idx (int): current block index.
act_fn (Callable): activation function.
dtype (jnp.dtype): the dtype of the computation (default: float32).

Source code in modules/models.py
class UpsamplingBlock(nn.Module):
    """Upsampling block for Decoder.

    Attributes:
        config (VQGANConfig): the config of the model.
        curr_res (int): current resolution.
        block_idx (int): current block index.
        act_fn (Callable): activation function.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    config: VQGANConfig
    curr_res: int
    block_idx: int
    act_fn: Callable = nn.gelu
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # get input number of channels based on config and current block index,
        # looking at self.config.ch_mult in reverse.
        block_in: int = self.config.ch
        if self.block_idx == self.config.num_resolutions - 1:
            block_in *= self.config.ch_mult[-1]
        else:
            block_in *= self.config.ch_mult[self.block_idx + 1]

        # get output number of channels based on config and current block index
        block_out: int = self.config.ch * self.config.ch_mult[self.block_idx]
        # temporal embedding channels; UpsamplingBlock doesn't use a temporal embedding
        self.temb_ch: int = 0

        # build blocks
        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks + 1):
            assert block_in % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
            res_blocks.append(
                ResNetBlock(
                    block_in,
                    block_out,
                    act_fn=self.act_fn,
                    temb_channels=self.temb_ch,
                    dropout_prob=self.config.dropout,
                    dtype=self.dtype,
                )
            )
            # after channel resize rest are identity blocks.
            block_in = block_out
            # check if we need to add attention block based on configs
            if self.curr_res in self.config.attn_resolutions:
                assert block_in % 32 == 0, "block_in must be divisible by 32 for GroupNorm"
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.blocks = res_blocks
        self.attns = attn_blocks

        # upsample if not first block
        self.upsample = None
        if self.block_idx != 0:
            self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)

    def __call__(
        self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
    ) -> jnp.ndarray:
        """Forward pass of the block.
        Args:
            x (jnp.ndarray): input tensor.
            temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
            deterministic (bool, optional): deterministic flag. Defaults to True.
        """
        assert temb is None, "UpsamplingBlock don't use temporal embedding"
        for i, res_block in enumerate(self.blocks):
            x = res_block(x, temb, deterministic=deterministic)

            if self.attns:
                x = self.attns[i](x)

        if self.upsample is not None:
            x = self.upsample(x)

        return x

__call__(x, temb=None, deterministic=True)

Forward pass of the block.

Parameters:

x (jnp.ndarray): input tensor. Required.
temb (Optional[jnp.ndarray]): temporal embedding. Defaults to None.
deterministic (bool): deterministic flag. Defaults to True.
Source code in modules/models.py
def __call__(
    self, x: jnp.ndarray, temb: Optional[jnp.ndarray] = None, deterministic: bool = True
) -> jnp.ndarray:
    """Forward pass of the block.
    Args:
        x (jnp.ndarray): input tensor.
        temb (Optional[jnp.ndarray], optional): temporal embedding. Defaults to None.
        deterministic (bool, optional): deterministic flag. Defaults to True.
    """
    assert temb is None, "UpsamplingBlock don't use temporal embedding"
    for i, res_block in enumerate(self.blocks):
        x = res_block(x, temb, deterministic=deterministic)

        if self.attns:
            x = self.attns[i](x)

    if self.upsample is not None:
        x = self.upsample(x)

    return x

Training

GenerateCallback

Callback that generates and logs images during training.

Source code in modules/training.py
class GenerateCallback:
    """Callback that generates and logs images during training."""

    def __init__(self, input_imgs: Any, rng: Union[Any, jnp.ndarray], every_n_epochs: int = 1):
        """Initialize the callback.

        Args:
            input_imgs (Any): Images to reconstruct during training.
            rng (Union[Any, jnp.ndarray]): PRNG key used for the model's dropout and gumbel sampling.
            every_n_epochs (int, optional):
                Only save those images every N epochs (otherwise tensorboard gets quite large).
                Defaults to 1.
        """
        super().__init__()
        self.input_imgs = input_imgs
        self.rng = rng
        self.every_n_epochs = every_n_epochs

    def log_generations(
        self,
        model: FlaxPreTrainedModel,
        model_params: train_state.TrainState,
        logger_tb: tf.summary.SummaryWriter,
        epoch: int,
    ):
        if epoch % self.every_n_epochs == 0:
            logger.info("Logging images to tensorboard at epoch %d", epoch)
            self.rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(self.rng, num=3)
            reconst_imgs = model(
                self.input_imgs,
                params=model_params.params,
                dropout_rng=dropout_apply_rng,
                gumble_rng=gumble_apply_rng,
                train=False,
            )[0]
            reconst_imgs = jax.device_get(reconst_imgs)

            # Plot and add to tensorboard
            imgs = np.stack([self.input_imgs, reconst_imgs], axis=1).reshape(
                -1, *self.input_imgs.shape[1:]
            )
            imgs = np.stack([utils.post_processing(img, resize=64) for img in imgs], axis=0)
            img_to_log = utils.make_img_grid(imgs, nrows=2)
            with logger_tb.as_default():
                tf.summary.image("Reconstructions", [img_to_log], step=epoch)

__init__(input_imgs, rng, every_n_epochs=1)

Initialize the callback.

Parameters:

Name Type Description Default
input_imgs Any

Images to reconstruct during training

required
rng Union[Any, jnp.ndarray]

Random number generator used for dropout and Gumbel sampling.

required
every_n_epochs int

Only save those images every N epochs (otherwise tensorboard gets quite large). Defaults to 1.

1
Source code in modules/training.py
def __init__(self, input_imgs: Any, rng: Union[Any, jnp.ndarray], every_n_epochs: int = 1):
    """Initialize the callback.

    Args:
        input_imgs (Any): Images to reconstruct during training.
        rng (Union[Any, jnp.ndarray]): Random number generator used for dropout and Gumbel sampling.
        every_n_epochs (int, optional):
            Only save those images every N epochs (otherwise tensorboard gets quite large).
            Defaults to 1.
    """
    super().__init__()
    self.input_imgs = input_imgs
    self.rng = rng
    self.every_n_epochs = every_n_epochs
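
A minimal usage sketch. The batch shape, PRNG seed, and the trainer attributes referenced in the comment are illustrative assumptions, not part of the API:

import jax
import jax.numpy as jnp

# Hypothetical batch of eight images to track across training (assumed NHWC layout).
sample_imgs = jnp.zeros((8, 224, 224, 3))
rng = jax.random.PRNGKey(0)

callback = GenerateCallback(sample_imgs, rng, every_n_epochs=5)
# Inside a training loop, with model/state/logger coming from a TrainerModule:
# callback.log_generations(trainer.model, trainer.state, trainer.logger, epoch=epoch_idx)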

TrainStateDisc

Bases: train_state.TrainState

Train state for discriminator.

Attributes:

Name Type Description
apply_fn Callable

The function that applies the model.

step int

The current step.

params FrozenDict

The model parameters.

batch_stats FrozenDict

The batch statistics. Defaults to None.

tx optax.GradientTransformation

The optimizer. Defaults to None.

opt_state optax.OptState

The optimizer state. Defaults to None.

Source code in modules/training.py
class TrainStateDisc(train_state.TrainState):
    """Train state for discriminator.
    Attributes:
        apply_fn (Callable): The function that applies the model.
        step (int): The current step.
        params (FrozenDict): The model parameters.
        batch_stats (FrozenDict): The batch statistics. Defaults to None.
        tx (optax.GradientTransformation): The optimizer. Defaults to None.
        opt_state (optax.OptState): The optimizer state. Defaults to None.
    """

    batch_stats: Optional[FrozenDict] = None
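
A sketch of constructing this state directly, assuming the discriminator exposes its variables as a dict with "params" and "batch_stats" keys (as TrainerVQGan below does) and an illustrative optax optimizer:

import optax

disc_variables = model_disc.params  # assumed {"params": ..., "batch_stats": ...}
state_disc = TrainStateDisc.create(
    apply_fn=model_disc.__call__,
    params=disc_variables["params"],
    batch_stats=disc_variables["batch_stats"],
    tx=optax.adam(1e-4),  # placeholder optimizer choice
)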

TrainerModule

Helper functions for training.

Source code in modules/training.py
class TrainerModule:
    """Helper functions for training."""

    def __init__(self, module_config: config.TrainConfig, model_class: FlaxPreTrainedModel):
        """Module for summarizing all common training functionalities.

        Args:
            module_config (TrainConfig): Configuration for training
                with all hyperparameters and train module parameters.
            model_class (FlaxPreTrainedModel): Model class to be trained.
        """
        super().__init__()
        self.module_config = module_config
        self.eval_key = module_config.monitor
        self.main_rng = jax.random.PRNGKey(self.module_config.seed)
        self.generate_callback: Optional[GenerateCallback] = None
        self.generator_callback_rng: Union[Any, jnp.ndarray]
        self.main_rng, self.generator_callback_rng = jax.random.split(self.main_rng)
        # Set model name
        self.model_name = self.module_config.model_name
        self.model_class = model_class
        self.model = model_class(
            self.module_config.model_hparams,
            input_shape=(1,) + self.module_config.input_shape,
            seed=self.module_config.seed,
            dtype=self.module_config.dtype,
        )

        # Set training parameters
        self.state = train_state.TrainState(
            step=0,
            apply_fn=self.model.__call__,
            params=self.model.params,
            tx=None,
            opt_state=None,
        )
        self.start_step = 0

        # Prepare logging
        self.log_dir: str = os.path.join(self.module_config.log_dir, f"{self.model_name}/")
        self.logger: tf.summary.SummaryWriter = tf.summary.create_file_writer(self.log_dir)
        self.save_dir: str = os.path.join(self.module_config.save_dir, f"{self.model_name}/")
        # Create jitted training and eval functions
        self.create_functions()

    def create_functions(self):
        """To be implemented in sub-classes."""
        raise NotImplementedError

    def init_optimizer(self):
        """Initialize optimizer and scheduler.
        By default, we decrease the learning rate with cosine annealing.
        """
        optimizer: optax.GradientTransformation = self.module_config.optimizer
        self.create_train_state(optimizer)

    def create_train_state(self, optimizer: optax.GradientTransformation):
        """Initialize training state."""
        self.state = train_state.TrainState.create(
            apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
        )

    def train_model(self, train_loader: utils.DataLoader, val_loader: utils.DataLoader):
        """Train model for defined number of epochs.
        Args:
            train_loader (utils.DataLoader): Training data loader.
            val_loader (utils.DataLoader): Validation data loader.
        """
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer()
        # Track best eval metric
        logger.info("Starting training 💃")
        best_eval = None
        with self.logger.as_default():
            for epoch_idx in range(1 + self.start_step, self.module_config.num_epochs + 1):
                logger.info("Epoch: %d", epoch_idx)
                train_metrics = self.train_epoch(train_loader, epoch=epoch_idx)
                train_metrics_str = "Training metrics:"
                for key in train_metrics:
                    tf.summary.scalar(f"train/{key}", train_metrics[key], step=epoch_idx)
                    train_metrics_str += f" {key}: {train_metrics[key]:.4f},"
                logger.info(train_metrics_str)
                if epoch_idx % self.module_config.check_val_every_n_epoch == 0:
                    eval_metrics = self.eval_model(val_loader)
                    eval_metrics_str = "Evaluation metrics:"
                    for key in eval_metrics:
                        tf.summary.scalar(f"val/{key}", eval_metrics[key], step=epoch_idx)
                        eval_metrics_str += f" {key}: {eval_metrics[key]:.4f},"
                    logger.info(eval_metrics_str)
                    if best_eval is None or eval_metrics[self.eval_key] > best_eval:
                        best_eval = eval_metrics[self.eval_key]
                        self.save_model(epoch_idx)

                if self.generate_callback is None:
                    for batch in train_loader():
                        if len(batch) > 8:
                            batch = batch[:8]
                        self.generate_callback = GenerateCallback(
                            batch,
                            self.generator_callback_rng,
                            every_n_epochs=self.module_config.log_img_every_n_epoch,
                        )
                        del self.generator_callback_rng
                        break

                if self.generate_callback is not None:
                    self.generate_callback.log_generations(
                        self.model, self.state, self.logger, epoch=epoch_idx
                    )

                self.logger.flush()
        logger.info("Finished training ✅ with best eval metric: %f 😎", best_eval)

    def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
        """Train model for one epoch, and log avg metrics.
        Args:
            data_loader (utils.DataLoader): Data loader to train on.
            epoch (int): Current epoch.
        Returns:
            Dictionary with all metrics.
        """
        metrics: Dict[str, float] = defaultdict(float)
        for batch in tqdm(data_loader(), desc="Training", leave=False):
            # ensure that the model has up-to-date parameters
            self.model.params = self.state.params
            batch_metrics: Dict[str, float]
            self.state, self.main_rng, batch_metrics = self.train_step(  # type: ignore
                state=self.state,
                batch=batch,
                rng=self.main_rng,
                distributed=self.module_config.distributed,
            )
            for key in batch_metrics:
                metrics[key] += batch_metrics[key]

        count = len(data_loader)
        metrics = {key: metrics[key] / count for key in metrics}
        return metrics

    @staticmethod
    def train_step(
        state: train_state.TrainState,
        batch: Any,
        rng: Union[Any, jnp.ndarray],
        distributed: bool = False,
        *args,
        **kwargs,
    ) -> Any:
        """Train model on a single batch.

        Args:
            state (TrainState): Current training state.
            batch (Any): Batch of data.
            rng (Union[Any, jnp.ndarray]): Random number generator.
            distributed (bool, optional): Whether to use distributed training. Defaults to False.
        Returns:
            Updated training states, rng and metrics.
        """
        raise NotImplementedError

    @staticmethod
    def eval_step(
        state: train_state.TrainState,
        batch: Any,
        rng: Union[Any, jnp.ndarray],
        *args,
        **kwargs,
    ) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
        """Evaluate model on a single batch.
        Args:
            state (TrainState): Current training state.
            batch (Any): Batch of data.
            rng (Union[Any, jnp.ndarray]): Random number generator.
        Returns:
            New rng and metrics.

        """
        raise NotImplementedError

    def eval_model(self, data_loader: utils.DataLoader) -> Dict[str, float]:
        """Test model on all images of a data loader and return avg metrics.
        Args:
            data_loader (utils.DataLoader): Data loader to evaluate on.
        Returns:
            Dictionary with all metrics.
        """
        metrics: Dict[str, float] = defaultdict(float)
        count = 0
        for batch in tqdm(data_loader(), desc="Evaluating", leave=False):
            batch_metrics: Dict[str, float]
            self.main_rng, batch_metrics = self.eval_step(  # type: ignore
                state=self.state, batch=batch, rng=self.main_rng
            )
            batch_size = (batch[0] if isinstance(batch, (tuple, list)) else batch).shape[0]
            count += batch_size
            for key in batch_metrics:
                metrics[key] += batch_metrics[key] * batch_size
        metrics = {key: metrics[key] / count for key in metrics}
        return metrics

    def save_model(self, step: Optional[int] = None):
        """Save current model.
        Args:
            step (int, optional): Current step. Defaults to None.
        """
        step = step if step is not None else self.module_config.num_epochs
        logger.info("Saving model 📠 for step %d", step)
        checkpoints.save_checkpoint(
            ckpt_dir=self.save_dir,
            target={"state": self.state, "step": step, "rng": self.main_rng},
            step=step,
            overwrite=True,
        )

    def push_model_to_hub(
        self, repo_id: str, commit_message: str = "Saving weights and logs"
    ) -> None:
        """Push model to huggingface hub.
        Args:
            repo_id (str): Repository id to push to Hugging Face.
            commit_message (str, optional): Commit message. Defaults to "Saving weights and logs".
        """
        logger.info("Pushing model to Hugging Face Hub 🚀")
        self.model.push_to_hub(
            repo_id=repo_id,
            commit_message=commit_message,
        )

    def load_model(self):
        """Load model."""
        logger.info("Loading model 🧠")
        load_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir, target=None)
        state_dict = load_dict["state"]
        self.state = train_state.TrainState(
            apply_fn=self.state.apply_fn,
            params=state_dict["params"],
            step=state_dict["step"],
            tx=self.state.tx if self.state.tx else self.module_config.optimizer,
            opt_state=state_dict["opt_state"],
        )
        self.main_rng = load_dict["rng"]
        self.model.params = self.state.params
        if load_dict["step"] is not None:
            self.start_step = load_dict["step"]
            if self.start_step >= self.module_config.num_epochs:
                logger.info("Model is already trained 🎉")

    def checkpoint_exists(self) -> bool:
        """Check whether a pretrained model exist.
        Returns:
            True if model exists, False otherwise.
        """
        return os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0

__init__(module_config, model_class)

Module that bundles all common training functionality.

Parameters:

Name Type Description Default
module_config TrainConfig

Configuration for training with all hyperparameters and train module parameters.

required
model_class FlaxPreTrainedModel

Model class to be trained.

required
Source code in modules/training.py
def __init__(self, module_config: config.TrainConfig, model_class: FlaxPreTrainedModel):
    """Module for summarizing all common training functionalities.

    Args:
        module_config (TrainConfig): Configuration for training
            with all hyperparameters and train module parameters.
        model_class (FlaxPreTrainedModel): Model class to be trained.
    """
    super().__init__()
    self.module_config = module_config
    self.eval_key = module_config.monitor
    self.main_rng = jax.random.PRNGKey(self.module_config.seed)
    self.generate_callback: Optional[GenerateCallback] = None
    self.generator_callback_rng: Union[Any, jnp.ndarray]
    self.main_rng, self.generator_callback_rng = jax.random.split(self.main_rng)
    # Set model name
    self.model_name = self.module_config.model_name
    self.model_class = model_class
    self.model = model_class(
        self.module_config.model_hparams,
        input_shape=(1,) + self.module_config.input_shape,
        seed=self.module_config.seed,
        dtype=self.module_config.dtype,
    )

    # Set training parameters
    self.state = train_state.TrainState(
        step=0,
        apply_fn=self.model.__call__,
        params=self.model.params,
        tx=None,
        opt_state=None,
    )
    self.start_step = 0

    # Prepare logging
    self.log_dir: str = os.path.join(self.module_config.log_dir, f"{self.model_name}/")
    self.logger: tf.summary.SummaryWriter = tf.summary.create_file_writer(self.log_dir)
    self.save_dir: str = os.path.join(self.module_config.save_dir, f"{self.model_name}/")
    # Create jitted training and eval functions
    self.create_functions()

checkpoint_exists()

Check whether a pretrained model exists.

Returns:

Type Description
bool

True if model exists, False otherwise.

Source code in modules/training.py
def checkpoint_exists(self) -> bool:
    """Check whether a pretrained model exist.
    Returns:
        True if model exists, False otherwise.
    """
    return os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0
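
A common resume pattern built on this check (sketch; trainer stands for any TrainerModule subclass instance and the loaders are assumed to exist):

if trainer.checkpoint_exists():
    trainer.load_model()  # restores state, rng and start_step
trainer.train_model(train_loader, val_loader)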

create_functions()

To be implemented in sub-classes.

Source code in modules/training.py
def create_functions(self):
    """To be implemented in sub-classes."""
    raise NotImplementedError
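
For orientation, a minimal sketch of what a subclass might wire up here, reusing the module-level jax/jnp imports. The placeholder l2 objective and the model call convention are assumptions for illustration, not the library's training objective:

class TrainerToy(TrainerModule):
    """Sketch subclass: installs a jitted train_step."""

    def create_functions(self):
        def train_step(state, batch, rng, distributed=False):
            new_rng, dropout_rng = jax.random.split(rng)

            def loss_fn(params):
                # Assumed call convention mirroring the model wrapper above.
                recon = state.apply_fn(batch, params=params, dropout_rng=dropout_rng, train=True)[0]
                return jnp.mean((recon - batch) ** 2)  # placeholder l2 loss

            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            if distributed:
                grads = jax.lax.pmean(grads, axis_name="batch")
            return state.apply_gradients(grads=grads), new_rng, {"total_loss": loss}

        self.train_step = jax.jit(train_step, static_argnames=("distributed",))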

create_train_state(optimizer)

Initialize training state.

Source code in modules/training.py
def create_train_state(self, optimizer: optax.GradientTransformation):
    """Initialize training state."""
    self.state = train_state.TrainState.create(
        apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
    )

eval_model(data_loader)

Test model on all images of a data loader and return avg metrics.

Parameters:

Name Type Description Default
data_loader utils.DataLoader

Data loader to evaluate on.

required

Returns:

Type Description
Dict[str, float]

Dictionary with all metrics.

Source code in modules/training.py
def eval_model(self, data_loader: utils.DataLoader) -> Dict[str, float]:
    """Test model on all images of a data loader and return avg metrics.
    Args:
        data_loader (utils.DataLoader): Data loader to evaluate on.
    Returns:
        Dictionary with all metrics.
    """
    metrics: Dict[str, float] = defaultdict(float)
    count = 0
    for batch in tqdm(data_loader(), desc="Evaluating", leave=False):
        batch_metrics: Dict[str, float]
        self.main_rng, batch_metrics = self.eval_step(  # type: ignore
            state=self.state, batch=batch, rng=self.main_rng
        )
        batch_size = (batch[0] if isinstance(batch, (tuple, list)) else batch).shape[0]
        count += batch_size
        for key in batch_metrics:
            metrics[key] += batch_metrics[key] * batch_size
    metrics = {key: metrics[key] / count for key in metrics}
    return metrics

eval_step(state, batch, rng, *args, **kwargs) staticmethod

Evaluate model on a single batch.

Parameters:

Name Type Description Default
state TrainState

Current training state.

required
batch Any

Batch of data.

required
rng Union[Any, jnp.ndarray]

Random number generator.

required

Returns:

Type Description
Tuple[Union[Any, jnp.ndarray], Dict[str, float]]

New rng and metrics.

Source code in modules/training.py
@staticmethod
def eval_step(
    state: train_state.TrainState,
    batch: Any,
    rng: Union[Any, jnp.ndarray],
    *args,
    **kwargs,
) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
    """Evaluate model on a single batch.
    Args:
        state (TrainState): Current training state.
        batch (Any): Batch of data.
        rng (Union[Any, jnp.ndarray]): Random number generator.
    Returns:
        New rng and metrics.

    """
    raise NotImplementedError

init_optimizer()

Initialize optimizer and scheduler. By default, we decrease the learning rate with cosine annealing.

Source code in modules/training.py
def init_optimizer(self):
    """Initialize optimizer and scheduler.
    By default, we decrease the learning rate with cosine annealing.
    """
    optimizer: optax.GradientTransformation = self.module_config.optimizer
    self.create_train_state(optimizer)
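
For reference, a config-side sketch of an optimizer matching the cosine-annealing note above; the learning rate and step count are illustrative:

import optax

schedule = optax.cosine_decay_schedule(init_value=1e-4, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=schedule)  # supplied as module_config.optimizer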

load_model()

Load model.

Source code in modules/training.py
def load_model(self):
    """Load model."""
    logger.info("Loading model 🧠")
    load_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir, target=None)
    state_dict = load_dict["state"]
    self.state = train_state.TrainState(
        apply_fn=self.state.apply_fn,
        params=state_dict["params"],
        step=state_dict["step"],
        tx=self.state.tx if self.state.tx else self.module_config.optimizer,
        opt_state=state_dict["opt_state"],
    )
    self.main_rng = load_dict["rng"]
    self.model.params = self.state.params
    if load_dict["step"] is not None:
        self.start_step = load_dict["step"]
        if self.start_step >= self.module_config.num_epochs:
            logger.info("Model is already trained 🎉")

push_model_to_hub(repo_id, commit_message='Saving weights and logs')

Push the model to the Hugging Face Hub.

Parameters:

Name Type Description Default
repo_id str

Repository id to push to Hugging Face.

required
commit_message str

Commit message. Defaults to "Saving weights and logs".

'Saving weights and logs'
Source code in modules/training.py
def push_model_to_hub(
    self, repo_id: str, commit_message: str = "Saving weights and logs"
) -> None:
    """Push model to huggingface hub.
    Args:
        repo_id (str): Repository id to push to Hugging Face.
        commit_message (str, optional): Commit message. Defaults to "Saving weights and logs".
    """
    logger.info("Pushing model to Hugging Face Hub 🚀")
    self.model.push_to_hub(
        repo_id=repo_id,
        commit_message=commit_message,
    )
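
Usage is a one-liner; the repository id below is a placeholder:

trainer.push_model_to_hub("your-username/vqgan-jax-voc")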

save_model(step=None)

Save current model.

Parameters:

Name Type Description Default
step int

Current step. Defaults to None.

None
Source code in modules/training.py
def save_model(self, step: Optional[int] = None):
    """Save current model.
    Args:
        step (int, optional): Current step. Defaults to None.
    """
    step = step if step is not None else self.module_config.num_epochs
    logger.info("Saving model 📠 for step %d", step)
    checkpoints.save_checkpoint(
        ckpt_dir=self.save_dir,
        target={"state": self.state, "step": step, "rng": self.main_rng},
        step=step,
        overwrite=True,
    )

train_epoch(data_loader, epoch)

Train model for one epoch, and log avg metrics.

Parameters:

Name Type Description Default
data_loader utils.DataLoader

Data loader to train on.

required
epoch int

Current epoch.

required

Returns:

Type Description
Dict[str, float]

Dictionary with all metrics.

Source code in modules/training.py
def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
    """Train model for one epoch, and log avg metrics.
    Args:
        data_loader (utils.DataLoader): Data loader to train on.
        epoch (int): Current epoch.
    Returns:
        Dictionary with all metrics.
    """
    metrics: Dict[str, float] = defaultdict(float)
    for batch in tqdm(data_loader(), desc="Training", leave=False):
        # ensure that the model has up-to-date parameters
        self.model.params = self.state.params
        batch_metrics: Dict[str, float]
        self.state, self.main_rng, batch_metrics = self.train_step(  # type: ignore
            state=self.state,
            batch=batch,
            rng=self.main_rng,
            distributed=self.module_config.distributed,
        )
        for key in batch_metrics:
            metrics[key] += batch_metrics[key]

    count = len(data_loader)
    metrics = {key: metrics[key] / count for key in metrics}
    return metrics

train_model(train_loader, val_loader)

Train model for defined number of epochs.

Parameters:

Name Type Description Default
train_loader utils.DataLoader

Training data loader.

required
val_loader utils.DataLoader

Validation data loader.

required
Source code in modules/training.py
def train_model(self, train_loader: utils.DataLoader, val_loader: utils.DataLoader):
    """Train model for defined number of epochs.
    Args:
        train_loader (utils.DataLoader): Training data loader.
        val_loader (utils.DataLoader): Validation data loader.
    """
    # We first need to create optimizer and the scheduler for the given number of epochs
    self.init_optimizer()
    # Track best eval metric
    logger.info("Starting training 💃")
    best_eval = None
    with self.logger.as_default():
        for epoch_idx in range(1 + self.start_step, self.module_config.num_epochs + 1):
            logger.info("Epoch: %d", epoch_idx)
            train_metrics = self.train_epoch(train_loader, epoch=epoch_idx)
            train_metrics_str = "Training metrics:"
            for key in train_metrics:
                tf.summary.scalar(f"train/{key}", train_metrics[key], step=epoch_idx)
                train_metrics_str += f" {key}: {train_metrics[key]:.4f},"
            logger.info(train_metrics_str)
            if epoch_idx % self.module_config.check_val_every_n_epoch == 0:
                eval_metrics = self.eval_model(val_loader)
                eval_metrics_str = "Evaluation metrics:"
                for key in eval_metrics:
                    tf.summary.scalar(f"val/{key}", eval_metrics[key], step=epoch_idx)
                    eval_metrics_str += f" {key}: {eval_metrics[key]:.4f},"
                logger.info(eval_metrics_str)
                if best_eval is None or eval_metrics[self.eval_key] > best_eval:
                    best_eval = eval_metrics[self.eval_key]
                    self.save_model(epoch_idx)

            if self.generate_callback is None:
                for batch in train_loader():
                    if len(batch) > 8:
                        batch = batch[:8]
                    self.generate_callback = GenerateCallback(
                        batch,
                        self.generator_callback_rng,
                        every_n_epochs=self.module_config.log_img_every_n_epoch,
                    )
                    del self.generator_callback_rng
                    break

            if self.generate_callback is not None:
                self.generate_callback.log_generations(
                    self.model, self.state, self.logger, epoch=epoch_idx
                )

            self.logger.flush()
    logger.info("Finished training ✅ with best eval metric: %f 😎", best_eval)

train_step(state, batch, rng, distributed=False, *args, **kwargs) staticmethod

Train model on a single batch.

Parameters:

Name Type Description Default
state TrainState

Current training state.

required
batch Any

Batch of data.

required
rng Union[Any, jnp.ndarray]

Random number generator.

required
distributed bool

Whether to use distributed training. Defaults to False.

False

Returns:

Type Description
Any

Updated training states, rng and metrics.

Source code in modules/training.py
@staticmethod
def train_step(
    state: train_state.TrainState,
    batch: Any,
    rng: Union[Any, jnp.ndarray],
    distributed: bool = False,
    *args,
    **kwargs,
) -> Any:
    """Train model on a single batch.

    Args:
        state (TrainState): Current training state.
        batch (Any): Batch of data.
        rng (Union[Any, jnp.ndarray]): Random number generator.
        distributed (bool, optional): Whether to use distributed training. Defaults to False.
    Returns:
        Updated training states, rng and metrics.
    """
    raise NotImplementedError

TrainerVQGan

Bases: TrainerModule

Helper functions for training VQGAN.

Source code in modules/training.py
class TrainerVQGan(TrainerModule):
    """Helper functions for training VQGAN."""

    def __init__(self, module_config: config.TrainConfig):
        # Initialize parent class
        self.module_config = module_config
        self.model_name = self.module_config.model_name
        if self.module_config.recon_loss == "l2":
            self.recon_loss_fn = losses.l2_loss
        elif self.module_config.recon_loss == "l1":
            self.recon_loss_fn = losses.l1_loss
        elif self.module_config.recon_loss == "combo":
            self.recon_loss_fn = losses.combo_loss
        elif self.module_config.recon_loss == "mape":
            self.recon_loss_fn = losses.mape_loss
        else:
            logger.warning(
                f"""Reconstruction loss function {self.module_config.recon_loss} not supported.
                Falling back to the default l1 loss."""
            )
            self.recon_loss_fn = losses.l1_loss
        # Train state for discriminator
        self.model_disc: FlaxPreTrainedModel = vqgan.VQGanDiscriminator(
            self.module_config.disc_hparams
        )
        if self.module_config.disc_loss == "vanilla":
            self.disc_loss_fn = losses.disc_loss_vanilla
        elif self.module_config.disc_loss == "hinge":
            self.disc_loss_fn = losses.disc_loss_hinge
        else:
            logger.warning(
                f"""Discriminator loss function {self.module_config.disc_loss} not supported.
                Falling back to the default hinge loss."""
            )
            self.disc_loss_fn = losses.disc_loss_hinge
        self.state_disc = TrainStateDisc(
            step=0,
            apply_fn=self.model_disc.__call__,
            params=self.model_disc.params["params"],
            batch_stats=self.model_disc.params["batch_stats"],
            tx=None,
            opt_state=None,
        )
        # Log dir discriminator loss
        self.save_dir_disc: str = os.path.join(
            self.module_config.save_dir, f"{self.model_name}_disc/"
        )
        self.temp_scheduler: Optional[Callable] = self.module_config.temp_scheduler

        super().__init__(module_config=module_config, model_class=vqgan.VQModel)

    def temperature_scheduling(self, epoch: int) -> float:
        """Temperature scheduling.
        Args:
            epoch (int): Current epoch.
        Returns:
            Temperature.
        """
        if self.temp_scheduler is not None:
            return self.model.update_temperature(self.temp_scheduler(epoch))
        else:
            return self.module_config.model_hparams.gumb_temp

    def init_optimizer(self):
        """Initialize optimizer and scheduler also for discriminator.
        By default, we decrease the learning rate with cosine annealing.
        """
        optimizer: optax.GradientTransformation = self.module_config.optimizer
        optimizer_disc: optax.GradientTransformation = self.module_config.optimizer_disc
        self.create_train_stat_full(optimizer, optimizer_disc)

    def create_train_stat_full(
        self,
        optimizer: optax.GradientTransformation,
        optimizer_disc: optax.GradientTransformation,
    ):
        """Initialize training state.
        Args:
            optimizer (optax.GradientTransformation): Optimizer for generator.
            optimizer_disc (optax.GradientTransformation): Optimizer for discriminator.
        """
        self.state = train_state.TrainState.create(
            apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
        )
        self.state_disc = TrainStateDisc.create(
            apply_fn=self.state_disc.apply_fn,
            params=self.state_disc.params,
            batch_stats=self.state_disc.batch_stats,
            tx=optimizer_disc,
        )

    def save_model(self, step: Optional[int] = None):
        """Save current model.
        Args:
            step (int, optional): Current step. Defaults to None.
        """
        step = step if step is not None else self.module_config.num_epochs
        super().save_model(step=step)
        checkpoints.save_checkpoint(
            ckpt_dir=self.save_dir_disc, target=self.state_disc, step=step, overwrite=True
        )

    def load_model(self):
        """Load model."""
        super().load_model()
        state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir_disc, target=None)
        self.state_disc = TrainStateDisc(
            apply_fn=self.state_disc.apply_fn,
            params=state_dict["params"],
            batch_stats=state_dict["batch_stats"],
            step=state_dict["step"],
            tx=self.state_disc.tx if self.state_disc.tx else self.module_config.optimizer_disc,
            opt_state=state_dict["opt_state"],
        )
        self.model_disc.params["params"] = self.state_disc.params
        self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

    def checkpoint_exists(self) -> bool:
        """Check whether a pretrained model exist.
        Returns:
            True if model and discriminator exists, False otherwise.
        """
        main_model: bool = os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0
        disc_model: bool = os.path.exists(self.save_dir_disc) and len(os.listdir(self.save_dir_disc)) > 0
        return main_model and disc_model

    def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
        """Train model for one epoch, and log avg metrics.
        Args:
            data_loader (utils.DataLoader): Data loader to train on.
            epoch (int): Current epoch.
        Returns:
            Dictionary with all metrics.
        """
        metrics: Dict[str, float] = defaultdict(float)
        metrics_disc = defaultdict(float)
        metrics_disc["step"] = 0.0
        new_temp: float = self.temperature_scheduling(epoch - 1)
        for batch in tqdm(data_loader(), desc="Training", leave=False):
            batch_metrics: Dict[str, float]
            train_outs = self.train_step(  # type: ignore
                state=self.state,
                disc_state=self.state_disc,
                batch=batch,
                rng=self.main_rng,
                optimizer_idx=0,
                disc_use=self.module_config.disc_start > epoch,
                distributed=self.module_config.distributed,
            )
            self.state, self.state_disc, self.main_rng, batch_metrics = train_outs
            # Update metrics
            for key, value in batch_metrics.items():
                metrics[key] += value

            if self.module_config.disc_start > epoch:
                batch_metrics_disc: Dict[str, float]
                train_outs = self.train_step(  # type: ignore
                    state=self.state,
                    disc_state=self.state_disc,
                    batch=batch,
                    rng=self.main_rng,
                    optimizer_idx=1,
                    disc_use=True,
                    distributed=self.module_config.distributed,
                )
                self.state, self.state_disc, self.main_rng, batch_metrics_disc = train_outs
                # Update metrics discriminator
                for key, value in batch_metrics_disc.items():
                    metrics_disc[key] += value
                metrics_disc["step"] += 1.0

            # ensure that model have actual parameters
            self.model.params = self.state.params
            self.model_disc.params["params"] = self.state_disc.params
            self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

        count = len(data_loader)
        metrics = {key: metrics[key] / count for key in metrics}
        metrics["temp"] = new_temp
        count_disc = metrics_disc["step"]
        del metrics_disc["step"]
        metrics_disc_resized: Dict[str, float] = {
            key: metrics_disc[key] / count_disc for key in metrics_disc
        }
        # merge metrics
        for key, value in metrics_disc_resized.items():
            metrics[key] = value

        return metrics

    def create_functions(self):
        """Create training and eval functions."""
        recon_loss_fn: Callable = self.recon_loss_fn
        disc_loss_fn: Callable = self.disc_loss_fn

        def calculate_loss_autoencoder(
            params: FrozenDict[str, Any],
            batch: jnp.ndarray,
            train: bool,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            disc_variables: TrainStateDisc,
        ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
            """Function to calculate the loss autoencoder for a batch of images."""
            new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
            outs = self.model(
                batch,
                params=params,
                dropout_rng=dropout_apply_rng,
                gumble_rng=gumble_apply_rng,
                train=train,
            )
            x_recon, z_q, codebook_loss, indices = outs
            # for now we use the plain reconstruction loss; later it can be combined with a perceptual loss
            rec_loss = recon_loss_fn(x_recon, batch)
            nll_loss = jnp.mean(rec_loss)

            # Generator loss (autoencode)
            outs = self.model_disc(
                x_recon,
                params=disc_variables.params,
                batch_stats=disc_variables.batch_stats,
                train=train,
            )
            logits_fake, new_model_state = outs if train else (outs, None)
            # The original VQGAN generator loss is g_loss = -jnp.mean(logits_fake).
            # Here a hinge-style variant is used instead: it is non-negative and
            # pushes the discriminator logits for reconstructions towards 1.
            g_loss = jnp.mean(jnp.maximum(1.0 - logits_fake, 0.0))
            # jax.lax.cond keeps the disc_use switch traceable under jit
            disc_factor = jax.lax.cond(
                disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
            )
            loss = (
                nll_loss
                + disc_factor * g_loss
                + self.module_config.codebook_weight * codebook_loss.mean()
            )

            metrics = {
                "total_loss": loss,
                "quant_loss": jnp.mean(codebook_loss),
                "nll_loss": nll_loss,
                "rec_loss": jnp.mean(rec_loss),
                "g_loss": g_loss,
            }

            return loss, (metrics, new_rng, new_model_state)

        def calculate_loss_disc(
            params: FrozenDict[str, Any],
            batch: jnp.ndarray,
            train: bool,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            batch_stats: FrozenDict[str, Any],
            model_params: Optional[FrozenDict[str, Any]],
        ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
            """Function to calculate the loss discriminator for a batch of images."""
            new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
            outs = self.model(
                batch,
                params=model_params,
                dropout_rng=dropout_apply_rng,
                gumble_rng=gumble_apply_rng,
                train=train,
            )
            x_recon, z_q, codebook_loss, indices = outs

            # Discriminator loss
            outs = self.model_disc(batch, params=params, batch_stats=batch_stats, train=train)
            logits_real, new_model_state = outs if train else (outs, None)
            outs = self.model_disc(x_recon, params=params, batch_stats=batch_stats, train=train)
            logits_fake, new_model_state = outs if train else (outs, None)
            disc_factor = jax.lax.cond(
                disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
            )
            loss = disc_factor * disc_loss_fn(logits_real, logits_fake)
            metrics = {
                "disc_loss": loss,
                "logits_real": logits_real.mean(),
                "logits_fake": logits_fake.mean(),
            }

            return loss, (metrics, new_rng, new_model_state)

        def train_step_autoencoder(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            distributed: bool = False,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train step for autoencoder."""
            loss_fn = partial(
                calculate_loss_autoencoder,
                batch=batch,
                train=True,
                rng=rng,
                disc_use=disc_use,
                disc_variables=disc_state,
            )
            (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
                loss_fn, has_aux=True
            )(state.params)
            # if distributed training, average grads
            if distributed:
                grads = jax.lax.pmean(grads, axis_name="batch")
            # Update parameters
            state = state.apply_gradients(grads=grads)
            return state, disc_state, new_rng, metrics

        def train_step_disc(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            disc_use: bool,
            distributed: bool = False,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train step for discriminator."""
            loss_fn = partial(
                calculate_loss_disc,
                batch=batch,
                train=True,
                rng=rng,
                disc_use=disc_use,
                batch_stats=disc_state.batch_stats,
                model_params=state.params,
            )
            (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
                loss_fn, has_aux=True
            )(disc_state.params)
            # if distributed training, average grads
            if distributed:
                grads = jax.lax.pmean(grads, axis_name="batch")
            # Update parameters, batch statistics
            disc_state = disc_state.apply_gradients(
                grads=grads, batch_stats=new_model_state["batch_stats"]
            )
            return state, disc_state, new_rng, metrics

        def train_step(
            state: train_state.TrainState,
            disc_state: TrainStateDisc,
            batch: jnp.ndarray,
            rng: Union[Any, jnp.ndarray],
            optimizer_idx: int,
            disc_use: bool,
            distributed: bool,
        ) -> Tuple[
            train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
        ]:
            """Train model on a single batch."""
            # calculate loss
            if optimizer_idx == 0:
                outs = train_step_autoencoder(state, disc_state, batch, rng, disc_use, distributed)
            else:
                outs = train_step_disc(state, disc_state, batch, rng, disc_use, distributed)
            # outs = jax.lax.cond(
            #  optimizer_idx == 0,
            #    lambda _: train_step_autoencoder(state,
            #                                     disc_state,
            #                                     batch,
            #                                     rng,
            #                                     disc_use,
            #                                     distributed),
            #    lambda _: train_step_disc(state,
            #                              disc_state,
            #                              batch,
            #                              rng,
            #                              disc_use,
            #                              distributed),
            #    None)
            state, disc_state, new_rng, metrics = outs
            return state, disc_state, new_rng, metrics

        def eval_step(
            state: train_state.TrainState,
            batch: jnp.ndarray,
            disc_state: TrainStateDisc,
            rng: Union[Any, jnp.ndarray],
        ) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
            """Evaluate model on a single batch."""
            _, (metrics, new_rng, _) = calculate_loss_autoencoder(
                state.params,
                batch=batch,
                train=False,
                rng=rng,
                disc_use=False,
                disc_variables=disc_state,
            )
            return new_rng, metrics

        # pmap or jit for efficiency
        if self.module_config.distributed:
            self.train_step = jax.pmap(  # type: ignore
                train_step, axis_name="batch", static_broadcasted_argnums=(4, 5, 6)
            )
            self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
        else:
            self.train_step = jax.jit(train_step, static_argnums=(4, 5, 6))  # type: ignore
            self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
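
Putting the pieces together, a sketch of a full training run. The config dict and the DataLoader instances are assumed to be built elsewhere:

# train_config_dict holds the TrainConfig fields (hyperparameters elided here).
train_config = config.TrainConfig(**train_config_dict)
trainer = TrainerVQGan(train_config)
if trainer.checkpoint_exists():
    trainer.load_model()  # resume generator and discriminator states
trainer.train_model(train_loader, val_loader)
trainer.save_model()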

checkpoint_exists()

Check whether a pretrained model exists.

Returns:

Type Description
bool

True if both model and discriminator exist, False otherwise.

Source code in modules/training.py
def checkpoint_exists(self) -> bool:
    """Check whether a pretrained model exist.
    Returns:
        True if model and discriminator exists, False otherwise.
    """
    main_model: bool = os.path.exists(self.save_dir) and len(os.listdir(self.log_dir)) > 0
    disc_model: bool = os.path.exists(self.save_dir_disc) and len(os.listdir(self.save_dir_disc)) > 0
    return main_model and disc_model

create_functions()

Create training and eval functions.

Source code in modules/training.py
def create_functions(self):
    """Create training and eval functions."""
    recon_loss_fn: Callable = self.recon_loss_fn
    disc_loss_fn: Callable = self.disc_loss_fn

    def calculate_loss_autoencoder(
        params: FrozenDict[str, Any],
        batch: jnp.ndarray,
        train: bool,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        disc_variables: TrainStateDisc,
    ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
        """Function to calculate the loss autoencoder for a batch of images."""
        new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
        outs = self.model(
            batch,
            params=params,
            dropout_rng=dropout_apply_rng,
            gumble_rng=gumble_apply_rng,
            train=train,
        )
        x_recon, z_q, codebook_loss, indices = outs
        # for now we use the plain reconstruction loss; later it can be combined with a perceptual loss
        rec_loss = recon_loss_fn(x_recon, batch)
        nll_loss = jnp.mean(rec_loss)

        # Generator loss (autoencode)
        outs = self.model_disc(
            x_recon,
            params=disc_variables.params,
            batch_stats=disc_variables.batch_stats,
            train=train,
        )
        logits_fake, new_model_state = outs if train else (outs, None)
        # The original VQGAN generator loss is g_loss = -jnp.mean(logits_fake).
        # Here a hinge-style variant is used instead: it is non-negative and
        # pushes the discriminator logits for reconstructions towards 1.
        g_loss = jnp.mean(jnp.maximum(1.0 - logits_fake, 0.0))
        # jax.lax.cond keeps the disc_use switch traceable under jit
        disc_factor = jax.lax.cond(
            disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
        )
        loss = (
            nll_loss
            + disc_factor * g_loss
            + self.module_config.codebook_weight * codebook_loss.mean()
        )

        metrics = {
            "total_loss": loss,
            "quant_loss": jnp.mean(codebook_loss),
            "nll_loss": nll_loss,
            "rec_loss": jnp.mean(rec_loss),
            "g_loss": g_loss,
        }

        return loss, (metrics, new_rng, new_model_state)

    def calculate_loss_disc(
        params: FrozenDict[str, Any],
        batch: jnp.ndarray,
        train: bool,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        batch_stats: FrozenDict[str, Any],
        model_params: Optional[FrozenDict[str, Any]],
    ) -> Tuple[Any, Tuple[Dict[str, Any], Union[Any, jnp.ndarray], Any]]:
        """Function to calculate the loss discriminator for a batch of images."""
        new_rng, gumble_apply_rng, dropout_apply_rng = jax.random.split(rng, num=3)
        outs = self.model(
            batch,
            params=model_params,
            dropout_rng=dropout_apply_rng,
            gumble_rng=gumble_apply_rng,
            train=train,
        )
        x_recon, z_q, codebook_loss, indices = outs

        # Discriminator loss
        outs = self.model_disc(batch, params=params, batch_stats=batch_stats, train=train)
        logits_real, new_model_state = outs if train else (outs, None)
        outs = self.model_disc(x_recon, params=params, batch_stats=batch_stats, train=train)
        logits_fake, new_model_state = outs if train else (outs, None)
        disc_factor = jax.lax.cond(
            disc_use, lambda _: self.module_config.disc_weight, lambda _: 0.0, None
        )
        loss = disc_factor * disc_loss_fn(logits_real, logits_fake)
        metrics = {
            "disc_loss": loss,
            "logits_real": logits_real.mean(),
            "logits_fake": logits_fake.mean(),
        }

        return loss, (metrics, new_rng, new_model_state)

    def train_step_autoencoder(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        distributed: bool = False,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train step for autoencoder."""
        loss_fn = partial(
            calculate_loss_autoencoder,
            batch=batch,
            train=True,
            rng=rng,
            disc_use=disc_use,
            disc_variables=disc_state,
        )
        (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(state.params)
        # if distributed training, average grads
        if distributed:
            grads = jax.lax.pmean(grads, axis_name="batch")
        # Update parameters
        state = state.apply_gradients(grads=grads)
        return state, disc_state, new_rng, metrics

    def train_step_disc(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        disc_use: bool,
        distributed: bool = False,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train step for discriminator."""
        loss_fn = partial(
            calculate_loss_disc,
            batch=batch,
            train=True,
            rng=rng,
            disc_use=disc_use,
            batch_stats=disc_state.batch_stats,
            model_params=state.params,
        )
        (_, (metrics, new_rng, new_model_state)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(disc_state.params)
        # if distributed training, average grads
        if distributed:
            grads = jax.lax.pmean(grads, axis_name="batch")
        # Update parameters, batch statistics
        disc_state = disc_state.apply_gradients(
            grads=grads, batch_stats=new_model_state["batch_stats"]
        )
        return state, disc_state, new_rng, metrics

    def train_step(
        state: train_state.TrainState,
        disc_state: TrainStateDisc,
        batch: jnp.ndarray,
        rng: Union[Any, jnp.ndarray],
        optimizer_idx: int,
        disc_use: bool,
        distributed: bool,
    ) -> Tuple[
        train_state.TrainState, TrainStateDisc, Union[Any, jnp.ndarray], Dict[str, float]
    ]:
        """Train model on a single batch."""
        # calculate loss
        if optimizer_idx == 0:
            outs = train_step_autoencoder(state, disc_state, batch, rng, disc_use, distributed)
        else:
            outs = train_step_disc(state, disc_state, batch, rng, disc_use, distributed)
        # outs = jax.lax.cond(
        #  optimizer_idx == 0,
        #    lambda _: train_step_autoencoder(state,
        #                                     disc_state,
        #                                     batch,
        #                                     rng,
        #                                     disc_use,
        #                                     distributed),
        #    lambda _: train_step_disc(state,
        #                              disc_state,
        #                              batch,
        #                              rng,
        #                              disc_use,
        #                              distributed),
        #    None)
        state, disc_state, new_rng, metrics = outs
        return state, disc_state, new_rng, metrics

    def eval_step(
        state: train_state.TrainState,
        batch: jnp.ndarray,
        disc_state: TrainStateDisc,
        rng: Union[Any, jnp.ndarray],
    ) -> Tuple[Union[Any, jnp.ndarray], Dict[str, float]]:
        """Evaluate model on a single batch."""
        _, (metrics, new_rng, _) = calculate_loss_autoencoder(
            state.params,
            batch=batch,
            train=False,
            rng=rng,
            disc_use=False,
            disc_variables=disc_state,
        )
        return new_rng, metrics

    # pmap or jit for efficiency
    if self.module_config.distributed:
        self.train_step = jax.pmap(  # type: ignore
            train_step, axis_name="batch", static_broadcasted_argnums=(4, 5, 6)
        )
        self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
    else:
        self.train_step = jax.jit(train_step, static_argnums=(4, 5, 6))  # type: ignore
        self.eval_step = jax.jit(partial(eval_step, disc_state=self.state_disc))  # type: ignore
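The indices in `static_argnums`/`static_broadcasted_argnums` refer to `optimizer_idx`, `disc_use`, and `distributed`: plain Python values that select a trace rather than array inputs. A toy sketch (names here are illustrative, not from the library) of how a static argument behaves under `jax.jit`:

import jax
import jax.numpy as jnp

def toy_step(x, use_extra_term):
    # `use_extra_term` is a plain Python bool; marking it static lets us branch on it
    return x + 1.0 if use_extra_term else x

toy_step_jit = jax.jit(toy_step, static_argnums=(1,))
x = jnp.ones((4,))
toy_step_jit(x, True)   # compiles one version for use_extra_term=True
toy_step_jit(x, False)  # compiles a second version for False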

create_train_stat_full(optimizer, optimizer_disc)

Initialize training state.

Parameters:

Name Type Description Default
optimizer optax.GradientTransformation

Optimizer for generator.

required
optimizer_disc optax.GradientTransformation

Optimizer for discriminator.

required
Source code in modules/training.py
def create_train_stat_full(
    self,
    optimizer: optax.GradientTransformation,
    optimizer_disc: optax.GradientTransformation,
):
    """Initialize training state.
    Args:
        optimizer (optax.GradientTransformation): Optimizer for generator.
        optimizer_disc (optax.GradientTransformation): Optimizer for discriminator.
    """
    self.state = train_state.TrainState.create(
        apply_fn=self.state.apply_fn, params=self.state.params, tx=optimizer
    )
    self.state_disc = TrainStateDisc.create(
        apply_fn=self.state_disc.apply_fn,
        params=self.state_disc.params,
        batch_stats=self.state_disc.batch_stats,
        tx=optimizer_disc,
    )

init_optimizer()

Initialize the optimizer and scheduler for both the generator and the discriminator. By default, the learning rate is decreased with cosine annealing.
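The optimizers themselves come from `module_config`; as a rough sketch (the hyperparameter values below are assumptions, not the library's defaults), a cosine-annealed Adam pair could be built with optax like this:

import optax

# hypothetical hyperparameters
schedule = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=10_000)
optimizer = optax.adam(learning_rate=schedule)       # generator
optimizer_disc = optax.adam(learning_rate=schedule)  # discriminator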

Source code in modules/training.py
def init_optimizer(self):
    """Initialize optimizer and scheduler also for discriminator.
    By default, we decrease the learning rate with cosine annealing.
    """
    optimizer: optax.GradientTransformation = self.module_config.optimizer
    optimizer_disc: optax.GradientTransformation = self.module_config.optimizer_disc
    self.create_train_stat_full(optimizer, optimizer_disc)

load_model()

Load model.

Source code in modules/training.py
def load_model(self):
    """Load model."""
    super().load_model()
    state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.save_dir_disc, target=None)
    self.state_disc = TrainStateDisc(
        apply_fn=self.state_disc.apply_fn,
        params=state_dict["params"],
        batch_stats=state_dict["batch_stats"],
        step=state_dict["step"],
        tx=self.state_disc.tx if self.state_disc.tx else self.module_config.optimizer_disc,
        opt_state=state_dict["opt_state"],
    )
    self.model_disc.params["params"] = self.state_disc.params
    self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

save_model(step=None)

Save current model.

Parameters:

Name Type Description Default
step int

Current step. Defaults to None.

None
Source code in modules/training.py
def save_model(self, step: Optional[int] = None):
    """Save current model.
    Args:
        step (int, optional): Current step. Defaults to None.
    """
    step = step if step is not None else self.module_config.num_epochs
    super().save_model(step=step)
    checkpoints.save_checkpoint(
        ckpt_dir=self.save_dir_disc, target=self.state_disc, step=step, overwrite=True
    )

temperature_scheduling(epoch)

Temperature scheduling.

Parameters:

Name Type Description Default
epoch int

Current epoch.

required

Returns:

Type Description
float

Temperature.

Source code in modules/training.py
def temperature_scheduling(self, epoch: int) -> float:
    """Temperature scheduling.
    Args:
        epoch (int): Current epoch.
    Returns:
        Temperature.
    """
    if self.temp_scheduler is not None:
        return self.model.update_temperature(self.temp_scheduler(epoch))
    else:
        return self.module_config.model_hparams.gumb_temp
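`temp_scheduler` can be any callable mapping an epoch index to a temperature. A minimal sketch of one plausible choice (exponential decay toward a floor; the constants are made up, not the library's defaults):

import math

def make_temp_scheduler(temp_init: float = 1.0, temp_min: float = 1e-6, decay: float = 0.05):
    # tau(epoch) = max(temp_min, temp_init * exp(-decay * epoch))
    def scheduler(epoch: int) -> float:
        return max(temp_min, temp_init * math.exp(-decay * epoch))
    return scheduler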

train_epoch(data_loader, epoch)

Train the model for one epoch and return the averaged metrics.

Parameters:

Name Type Description Default
data_loader utils.DataLoader

Data loader to train on.

required
epoch int

Current epoch number.

required

Returns:

Type Description
Dict[str, float]

Dictionary with all metrics.

Source code in modules/training.py
def train_epoch(self, data_loader: utils.DataLoader, epoch: int) -> Dict[str, float]:
    """Train model for one epoch, and log avg metrics.
    Args:
        data_loader (utils.DataLoader): Data loader to train on.
    Returns:
        Dictionary with all metrics.
    """
    metrics: Dict[str, float] = defaultdict(float)
    metrics_disc = defaultdict(float)
    metrics_disc["step"] = 0.0
    new_temp: float = self.temperature_scheduling(epoch - 1)
    for batch in tqdm(data_loader(), desc="Training", leave=False):
        batch_metrics: Dict[str, float]
        train_outs = self.train_step(  # type: ignore
            state=self.state,
            disc_state=self.state_disc,
            batch=batch,
            rng=self.main_rng,
            optimizer_idx=0,
            disc_use=self.module_config.disc_start > epoch,
            distributed=self.module_config.distributed,
        )
        self.state, self.state_disc, self.main_rng, batch_metrics = train_outs
        # Update metrics
        for key, value in batch_metrics.items():
            metrics[key] += value

        if self.module_config.disc_start > epoch:
            batch_metrics_disc: Dict[str, float]
            train_outs = self.train_step(  # type: ignore
                state=self.state,
                disc_state=self.state_disc,
                batch=batch,
                rng=self.main_rng,
                optimizer_idx=1,
                disc_use=True,
                distributed=self.module_config.distributed,
            )
            self.state, self.state_disc, self.main_rng, batch_metrics_disc = train_outs
            # Update metrics discriminator
            for key, value in batch_metrics_disc.items():
                metrics_disc[key] += value
            metrics_disc["step"] += 1.0

        # ensure the model wrappers hold the latest parameters
        self.model.params = self.state.params
        self.model_disc.params["params"] = self.state_disc.params
        self.model_disc.params["batch_stats"] = self.state_disc.batch_stats

    count = len(data_loader)
    metrics = {key: metrics[key] / count for key in metrics}
    metrics["temp"] = new_temp
    count_disc = metrics_disc["step"]
    del metrics_disc["step"]
    metrics_disc_resized: Dict[str, float] = {
        key: metrics_disc[key] / count_disc for key in metrics_disc
    }
    # merge metrics
    for key, value in metrics_disc_resized.items():
        metrics[key] = value

    return metrics
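Putting the training methods together, an outer loop might look like the following sketch (`trainer`, `train_loader`, and `num_epochs` are hypothetical names for objects built as documented above):

trainer.init_optimizer()
for epoch in range(1, num_epochs + 1):
    metrics = trainer.train_epoch(train_loader, epoch)
    print(f"epoch {epoch}: {metrics}")
trainer.save_model()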

Utils

BaseDataset

Bases: ABC

Abstract base class for loading a dataset.

Source code in modules/utils.py
class BaseDataset(ABC):
    """Load the dataset. Abstract method."""

    def __init__(self, train: bool, dtype: jnp.dtype, config: DataConfig) -> None:
        """Set the dataset.
        Args:
            train: If the dataset is for training.
            dtype: The dtype to cast images to.
            config: The config for the dataset.
        """
        self.root: str = config.dataset_root
        self.use_transforms: bool = train
        if self.use_transforms and config.transform is None:
            raise ValueError("Transforms must be provided for training.")
        self.transforms = A.from_dict(config.transform) if config.transform is not None else None
        self.image_size: int = config.size
        self.dataset_name = config.dataset_name
        self.dtype = dtype
        self.params = config.train_params if train else config.test_params
        self.dataset: tf.data.Dataset = self.load_dataset(train)
        assert len(self.dataset) > 0, "Dataset is empty."

    @abstractmethod
    def load_dataset(self, train: bool) -> tf.data.Dataset:
        """Load the dataset.
        Args:
            train: If the dataset is for training.
        Returns:
            The dataset.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def _preprocess(self, image: tf.Tensor) -> tf.Tensor:
        """Preprocess the image.
        Args:
            image: The image to preprocess.
        Returns:
            The preprocessed image.
        """

        def aug_fn(image: tf.Tensor) -> tf.Tensor:
            data = {"image": image}
            aug_data = self.transforms(**data)
            aug_img = aug_data["image"]
            aug_img = tf.cast(aug_img, self.dtype) / 255.0
            aug_img = (aug_img - IMAGENET_STANDARD_MEAN) / IMAGENET_STANDARD_STD
            aug_img = tf.image.resize(aug_img, (self.image_size, self.image_size))
            return aug_img

        if self.use_transforms:
            image = tf.numpy_function(func=aug_fn, inp=[image], Tout=self.dtype)
        else:
            image = tf.cast(image, self.dtype) / 255.0
            image = (image - IMAGENET_STANDARD_MEAN) / IMAGENET_STANDARD_STD
            image = tf.image.resize(image, (self.image_size, self.image_size))
        return image

    def get_dataset(self) -> tf.data.Dataset:
        """Return the dataset.
        Returns:
            The dataset.
        """
        dataset = self.dataset.map(self._preprocess)
        dataset = dataset.shuffle(self.params.batch_size * 16) if self.params.shuffle else dataset
        return dataset

__init__(train, dtype, config)

Set the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required
dtype jnp.dtype

The dtype to cast images to.

required
config DataConfig

The config for the dataset.

required
Source code in modules/utils.py
def __init__(self, train: bool, dtype: jnp.dtype, config: DataConfig) -> None:
    """Set the dataset.
    Args:
        train: If the dataset is for training.
        dtype: The dtype to cast images to.
        config: The config for the dataset.
    """
    self.root: str = config.dataset_root
    self.use_transforms: bool = train
    if self.use_transforms and config.transform is None:
        raise ValueError("Transforms must be provided for training.")
    self.transforms = A.from_dict(config.transform) if config.transform is not None else None
    self.image_size: int = config.size
    self.dataset_name = config.dataset_name
    self.dtype = dtype
    self.params = config.train_params if train else config.test_params
    self.dataset: tf.data.Dataset = self.load_dataset(train)
    assert len(self.dataset) > 0, "Dataset is empty."

__len__()

Return the length of the dataset.

Source code in modules/utils.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)

get_dataset()

Return the dataset.

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
def get_dataset(self) -> tf.data.Dataset:
    """Return the dataset.
    Returns:
        The dataset.
    """
    dataset = self.dataset.map(self._preprocess)
    dataset = dataset.shuffle(self.params.batch_size * 16) if self.params.shuffle else dataset
    return dataset

load_dataset(train) abstractmethod

Load the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
@abstractmethod
def load_dataset(self, train: bool) -> tf.data.Dataset:
    """Load the dataset.
    Args:
        train: If the dataset is for training.
    Returns:
        The dataset.
    """
    raise NotImplementedError

DataLoader

Dataloader similar to the one in PyTorch.

Source code in modules/utils.py
class DataLoader:
    """Dataloader similar as in pytorch."""

    def __init__(self, dataset: BaseDataset, distributed: bool) -> None:
        """Create a data loader.
        Args:
            dataset (BaseDataset): The dataset to load.
            distributed (bool): If the data is distributed.
        """
        self.dataset_placeholder = dataset
        self.dist = distributed

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset_placeholder) // self.dataset_placeholder.params.batch_size

    def __call__(self, *args: Any, **kwds: Any) -> Iterable:
        """Return the dataset in dataloader style."""
        ds = self.dataset_placeholder.get_dataset()
        batch_size = self.dataset_placeholder.params.batch_size
        if self.dist:
            per_core_bs, remainder = divmod(batch_size, len(jax.devices()))
            assert remainder == 0
            ds = ds.batch(per_core_bs, drop_remainder=True).batch(
                len(jax.devices()), drop_remainder=True
            )
        else:
            ds = ds.batch(batch_size, drop_remainder=True)

        ds = tfds.as_numpy(ds.prefetch(tf.data.AUTOTUNE))
        ds = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
        return ds
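In the distributed branch the global batch is split so that each yielded element carries a leading device axis, i.e. shape (num_devices, per_core_bs, height, width, 3), which is the layout `jax.pmap` expects. A quick sanity check of that arithmetic:

import jax

global_batch_size = 64  # hypothetical value
per_core_bs, remainder = divmod(global_batch_size, len(jax.devices()))
assert remainder == 0, "the global batch size must divide evenly across devices"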

__call__(*args, **kwds)

Return the dataset in dataloader style.

Source code in modules/utils.py
def __call__(self, *args: Any, **kwds: Any) -> Iterable:
    """Return the dataset in dataloader style."""
    ds = self.dataset_placeholder.get_dataset()
    batch_size = self.dataset_placeholder.params.batch_size
    if self.dist:
        per_core_bs, remainder = divmod(batch_size, len(jax.devices()))
        assert remainder == 0
        ds = ds.batch(per_core_bs, drop_remainder=True).batch(
            len(jax.devices()), drop_remainder=True
        )
    else:
        ds = ds.batch(batch_size, drop_remainder=True)

    ds = tfds.as_numpy(ds.prefetch(tf.data.AUTOTUNE))
    ds = flax.jax_utils.prefetch_to_device(ds, 3) if self.dist else ds
    return ds

__init__(dataset, distributed)

Create a data loader.

Parameters:

Name Type Description Default
dataset BaseDataset

The dataset to load.

required
distributed bool

If the data is distributed.

required
Source code in modules/utils.py
def __init__(self, dataset: BaseDataset, distributed: bool) -> None:
    """Create a data loader.
    Args:
        dataset (BaseDataset): The dataset to load.
        distributed (bool): If the data is distributed.
    """
    self.dataset_placeholder = dataset
    self.dist = distributed

__len__()

Return the length of the dataset.

Source code in modules/utils.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset_placeholder) // self.dataset_placeholder.params.batch_size

DummyDataset

Bases: BaseDataset

Create dummy dataset.

Source code in modules/utils.py
class DummyDataset(BaseDataset):
    """Create dummy dataset."""

    def load_dataset(self, train: bool) -> tf.data.Dataset:
        """Load the dataset.
        Args:
            train: If the dataset is for training.
        Returns:
            The dataset.
        """
        dummy = (
            tf.random.normal(
                (self.params.batch_size * 4, self.image_size, self.image_size, 3),
                dtype=tf.float32,
            )
            * 255.0
        )  # 0-255
        ds = tf.data.Dataset.from_tensor_slices(dummy)
        self.dataset_name = "dummy"
        return ds.cache()
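A minimal end-to-end sketch of wiring a config, a dataset, and a loader together (the config values below are assumptions for illustration):

import jax.numpy as jnp

config = DataConfig(
    train_params={"batch_size": 8, "shuffle": True},
    test_params={"batch_size": 8, "shuffle": False},
    size=64,
)
dataset = DummyDataset(train=False, dtype=jnp.float32, config=config)
loader = DataLoader(dataset, distributed=False)
for batch in loader():
    print(batch.shape)  # (8, 64, 64, 3)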

load_dataset(train)

Load the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
def load_dataset(self, train: bool) -> tf.data.Dataset:
    """Load the dataset.
    Args:
        train: If the dataset is for training.
    Returns:
        The dataset.
    """
    dummy = (
        tf.random.normal(
            (self.params.batch_size * 4, self.image_size, self.image_size, 3),
            dtype=tf.float32,
        )
        * 255.0
    )  # 0-255
    ds = tf.data.Dataset.from_tensor_slices(dummy)
    self.dataset_name = "dummy"
    return ds.cache()

TensorflowDataset

Bases: BaseDataset

Tensorflow dataset.

Source code in modules/utils.py
class TensorflowDataset(BaseDataset):
    """Tensorflow dataset."""

    def load_dataset(self, train: bool) -> tf.data.Dataset:
        """Load the dataset.
        Args:
            train: If the dataset is for training.
        Returns:
            The dataset."""

        # If you hit a 'Too many open files' error, it can be resolved as proposed in this issue:
        # https://github.com/tensorflow/datasets/issues/1441#issuecomment-581660890
        # The code below resolves it:
        # import resource
        # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
        # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
        split = "train" if train else "validation"
        try:
            ds = tfds.load(
                name=self.dataset_name, split=split, as_supervised=True, data_dir=self.root
            ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
        except Exception as e:
            if split == "validation":
                ds = tfds.load(
                    name=self.dataset_name, split="test", as_supervised=True, data_dir=self.root
                ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
            else:
                raise e

        return ds.cache()
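Usage mirrors DummyDataset but pulls a named TFDS dataset; for example (the dataset name and paths here are assumptions):

import jax.numpy as jnp

config = DataConfig(
    train_params={"batch_size": 16, "shuffle": True},
    test_params={"batch_size": 16, "shuffle": False},
    dataset_name="imagenette",            # any TFDS image dataset with supervised keys
    dataset_root="~/tensorflow_datasets",
    size=256,
)
val_dataset = TensorflowDataset(train=False, dtype=jnp.float32, config=config)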

load_dataset(train)

Load the dataset.

Parameters:

Name Type Description Default
train bool

If the dataset is for training.

required

Returns:

Type Description
tf.data.Dataset

The dataset.

Source code in modules/utils.py
def load_dataset(self, train: bool) -> tf.data.Dataset:
    """Load the dataset.
    Args:
        train: If the dataset is for training.
    Returns:
        The dataset."""

    # If you hit a 'Too many open files' error, it can be resolved as proposed in this issue:
    # https://github.com/tensorflow/datasets/issues/1441#issuecomment-581660890
    # The code below resolves it:
    # import resource
    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
    split = "train" if train else "validation"
    try:
        ds = tfds.load(
            name=self.dataset_name, split=split, as_supervised=True, data_dir=self.root
        ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
    except Exception as e:
        if split == "validation":
            ds = tfds.load(
                name=self.dataset_name, split="test", as_supervised=True, data_dir=self.root
            ).map(tf.autograph.experimental.do_not_convert(lambda x, y: x))
        else:
            raise e

    return ds.cache()

VQGanFeatureExtractor

Bases: VQGanImageProcessor

Extract features for VQGan from images. Extends VQGanImageProcessor only with a __call__ method that runs preprocessing.

Source code in modules/utils.py
class VQGanFeatureExtractor(VQGanImageProcessor):
    """Extract features for VQGan from images.
    Extends VQGanImageProcessor only with call function to run preprocessing.
    """

    def __call__(self, *args: Any, **kwds: Any) -> BatchFeature:
        return self.preprocess(*args, **kwds)
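A usage sketch (the image path is hypothetical); the extractor returns a BatchFeature whose pixel_values are ready to feed to the model:

from PIL import Image

extractor = VQGanFeatureExtractor()
image = Image.open("example.jpg")    # hypothetical input image
inputs = extractor(image)
print(inputs["pixel_values"].shape)  # (1, 256, 256, 3)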

VQGanImageProcessor

Constructs a VQGan image processor.

Parameters:

Name Type Description Default
do_resize `bool`, *optional*, defaults to `True`

Whether to resize the image's (height, width) dimensions to the specified (size["height"], size["width"]). Can be overridden by the do_resize parameter in the preprocess method.

True
size `dict`, *optional*, defaults to `{"height": 256, "width": 256}`

Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.

None
resample `Image.Resampling`, *optional*, defaults to `Image.Resampling.BILINEAR`

Resampling filter to use if resizing the image. Can be overridden by the resample parameter in the preprocess method.

Image.Resampling.BILINEAR
do_rescale `bool`, *optional*, defaults to `True`

Whether to rescale the image by the specified scale rescale_factor. Can be overridden by the do_rescale parameter in the preprocess method.

True
rescale_factor `int` or `float`, *optional*, defaults to `1/255`

Scale factor to use if rescaling the image. Can be overridden by the rescale_factor parameter in the preprocess method.

1 / 255
do_normalize `bool`, *optional*, defaults to `True`

Whether to normalize the image. Can be overridden by the do_normalize parameter in the preprocess method.

True
image_mean `float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`

Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_mean parameter in the preprocess method.

None
image_std `float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`

Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_std parameter in the preprocess method.

None
dtype `jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`

The dtype of the computation.

jnp.float32
Source code in modules/utils.py
class VQGanImageProcessor:
    """
    Constructs a VQGan image processor.
    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified
            `(size["height"], size["width"])`. Can be overridden by the `do_resize` parameter in
            the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 256, "width": 256}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in
            the `preprocess` method.
        resample (`Image.Resampling`, *optional*, defaults to `Image.Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample`
            parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden
            by the `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor`
            parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in
            the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of
            the number of channels in the image. Can be overridden by the `image_mean` parameter
            in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats
            the length of the number of channels in the image. Can be overridden by the `image_std`
            parameter in the `preprocess` method.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The dtype of the computation.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: Image.Resampling = Image.Resampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 256, "width": 256}
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.dtype = dtype

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: Image.Resampling = Image.Resampling.BILINEAR,
        data_format: Optional[str] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.
        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size
                of the output image.
            resample:
                `Image.Resampling` filter to use when resizing the image e.g.
                `Image.Resampling.BILINEAR`.
            data_format (`str`, *optional*):
                The channel dimension format for the output image. If unset, the channel
                dimension format of the input image is used. Can be one of:
                - `"channels_first"`: image in (num_channels, height, width) format.
                - `"channels_last"`: image in (height, width, num_channels) format.
        Returns:
            The resized image.
        """
        if "height" not in size or "width" not in size:
            raise ValueError(
                "The `size` dictionary must contain the keys `height` and `width`. Got"
                f" {size.keys()}"
            )

        revert_format = False
        if data_format == "channels_first":
            revert_format = True
            image = np.transpose(image, (1, 2, 0))
        else:
            if image.shape[0] == 3:
                revert_format = True
                image = np.transpose(image, (1, 2, 0))

        pil_image = Image.fromarray(np.uint8(image))
        pil_image_resized = pil_image.resize((size["width"], size["height"]), resample=resample)
        image_np = np.array(pil_image_resized)
        assert image_np.shape == (size["height"], size["width"], image.shape[-1])
        image_np = np.transpose(image_np, (2, 0, 1)) if revert_format else image_np
        return image_np

    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        data_format: Optional[str] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor. image = image * scale.
        Args:
            image (`np.ndarray`):
                Image to resize.
            scale (`float`):
                The scaling factor to rescale pixel values by.
            data_format (`str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension
                format of the input image is used. Can be one of:
                - `"channels_first"`: image in (num_channels, height, width) format.
                - `"channels_last"`: image in (height, width, num_channels) format.
        Returns:
            The resized image.
        """
        return image * scale

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[str] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.
        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean to use for normalization.
            std (`float` or `List[float]`):
                Image standard deviation to use for normalization.
            data_format (`str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension
                format of the input image is used. Can be one of:
                - `"channels_first"`: image in (num_channels, height, width) format.
                - `"channels_last"`: image in (height, width, num_channels) format.
        Returns:
            The normalized image.
        """
        if isinstance(mean, list):
            assert len(mean) == min(image.shape)
        if isinstance(std, list):
            assert len(std) == min(image.shape)

        revert_format = False
        if data_format == "channels_first":
            revert_format = True
            image = np.transpose(image, (1, 2, 0))
        else:
            if image.shape[0] == 3:
                revert_format = True
                image = np.transpose(image, (1, 2, 0))

        image_normalized = (image - mean) / std
        image_normalized = (
            np.transpose(image_normalized, (2, 0, 1)) if revert_format else image_normalized
        )
        return image_normalized

    def preprocess(
        self,
        images: Union[Image.Image, np.ndarray, List[Image.Image], List[np.ndarray]],
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[Image.Resampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: str = "channels_last",
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.
        Args:
            images (`Image.Image`, `np.ndarray`, `List[Image.Image]`, `List[np.ndarray]`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": h, "width": w}` specifying the size of the
                output image after resizing.
            resample (`Image.Resampling` filter, *optional*, defaults to `self.resample`):
                `Image.Resampling` filter to use if resizing the image e.g.
                `Image.Resampling.BILINEAR`. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            data_format (`str`, *optional*, defaults to `"channels_last"`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"`: image in (num_channels, height, width) format.
                - `"channels_last"`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
        Returns:
            The preprocessed image(s).
        """
        assert data_format in ["channels_first", "channels_last"]
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        size = size if size is not None else self.size

        # normalize the input to a list of numpy arrays
        if isinstance(images, (list, tuple)):
            images = [np.array(image) for image in images]
        else:
            images = [np.array(images)]

        # sanity check: raw pixel values must be in the [0, 255] range
        for img in images:
            assert img.min() >= 0 and img.max() <= 255, "Image values must be in [0 - 255] range."

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std) for image in images
            ]

        if data_format == "channels_first":
            images = [jnp.transpose(image, (1, 2, 0)) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type="jax")
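A short sketch of calling preprocess directly with per-call overrides (the random array stands in for a real image):

import numpy as np

processor = VQGanImageProcessor()
img = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
out = processor.preprocess(
    img,
    size={"height": 128, "width": 128},  # override the default 256x256
    data_format="channels_first",
)
print(out["pixel_values"].shape)  # (1, 3, 128, 128)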

normalize(image, mean, std, data_format=None, **kwargs)

Normalize an image. image = (image - image_mean) / image_std.

Parameters:

Name Type Description Default
image `np.ndarray`

Image to normalize.

required
mean `float` or `List[float]`

Image mean to use for normalization.

required
std `float` or `List[float]`

Image standard deviation to use for normalization.

required
data_format `str`, *optional*

The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: - "channels_first": image in (num_channels, height, width) format. - "channels_last": image in (height, width, num_channels) format.

None

Returns:

Type Description
np.ndarray

The normalized image.

Source code in modules/utils.py
def normalize(
    self,
    image: np.ndarray,
    mean: Union[float, List[float]],
    std: Union[float, List[float]],
    data_format: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
    """
    Normalize an image. image = (image - image_mean) / image_std.
    Args:
        image (`np.ndarray`):
            Image to normalize.
        mean (`float` or `List[float]`):
            Image mean to use for normalization.
        std (`float` or `List[float]`):
            Image standard deviation to use for normalization.
        data_format (`str`, *optional*):
            The channel dimension format for the output image. If unset, the channel dimension
            format of the input image is used. Can be one of:
            - `"channels_first"`: image in (num_channels, height, width) format.
            - `"channels_last"`: image in (height, width, num_channels) format.
    Returns:
        The normalized image.
    """
    if isinstance(mean, list):
        assert len(mean) == min(image.shape)
    if isinstance(std, list):
        assert len(std) == min(image.shape)

    revert_format = False
    if data_format == "channels_first":
        revert_format = True
        image = np.transpose(image, (1, 2, 0))
    else:
        if image.shape[0] == 3:
            revert_format = True
            image = np.transpose(image, (1, 2, 0))

    image_normalized = (image - mean) / std
    image_normalized = (
        np.transpose(image_normalized, (2, 0, 1)) if revert_format else image_normalized
    )
    return image_normalized

preprocess(images, do_resize=None, size=None, resample=None, do_rescale=None, rescale_factor=None, do_normalize=None, image_mean=None, image_std=None, data_format='channels_last', **kwargs)

Preprocess an image or batch of images.

Parameters:

Name Type Description Default
images `Image.Image`, `np.ndarray`, `List[Image.Image]`, `List[np.ndarray]`

Image to preprocess.

required
do_resize `bool`, *optional*, defaults to `self.do_resize`

Whether to resize the image.

None
size `Dict[str, int]`, *optional*, defaults to `self.size`

Dictionary in the format {"height": h, "width": w} specifying the size of the output image after resizing.

None
resample `Image.Resampling` filter, *optional*, defaults to `self.resample`

Image.Resampling filter to use if resizing the image e.g. Image.Resampling.BILINEAR. Only has an effect if do_resize is set to True.

None
do_rescale `bool`, *optional*, defaults to `self.do_rescale`

Whether to rescale the image values between [0 - 1].

None
rescale_factor `float`, *optional*, defaults to `self.rescale_factor`

Rescale factor to rescale the image by if do_rescale is set to True.

None
do_normalize `bool`, *optional*, defaults to `self.do_normalize`

Whether to normalize the image.

None
image_mean `float` or `List[float]`, *optional*, defaults to `self.image_mean`

Image mean to use if do_normalize is set to True.

None
image_std `float` or `List[float]`, *optional*, defaults to `self.image_std`

Image standard deviation to use if do_normalize is set to True.

None
data_format `str`, *optional*, defaults to `"channels_last"`

The channel dimension format for the output image. Can be one of: - "channels_first": image in (num_channels, height, width) format. - "channels_last": image in (height, width, num_channels) format. - Unset: Use the channel dimension format of the input image.

'channels_last'

Returns:

Type Description
BatchFeature

The preprocessed image(s).

Source code in modules/utils.py
def preprocess(
    self,
    images: Union[Image.Image, np.ndarray, List[Image.Image], List[np.ndarray]],
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional[Image.Resampling] = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    data_format: str = "channels_last",
    **kwargs,
) -> BatchFeature:
    """
    Preprocess an image or batch of images.
    Args:
        images (`Image.Image`, `np.ndarray`, `List[Image.Image]`, `List[np.ndarray]`):
            Image to preprocess.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Dictionary in the format `{"height": h, "width": w}` specifying the size of the
            output image after resizing.
        resample (`Image.Resampling` filter, *optional*, defaults to `self.resample`):
            `Image.Resampling` filter to use if resizing the image e.g.
            `Image.Resampling.BILINEAR`. Only has an effect if `do_resize` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image values between [0 - 1].
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean to use if `do_normalize` is set to `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation to use if `do_normalize` is set to `True`.
        data_format (`str`, *optional*, defaults to `"channels_last"`):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"`: image in (num_channels, height, width) format.
            - `"channels_last"`: image in (height, width, num_channels) format.
            - Unset: Use the channel dimension format of the input image.
    Returns:
        The preprocessed image(s).
    """
    assert data_format in ["channels_first", "channels_last"]
    do_resize = do_resize if do_resize is not None else self.do_resize
    do_rescale = do_rescale if do_rescale is not None else self.do_rescale
    do_normalize = do_normalize if do_normalize is not None else self.do_normalize
    resample = resample if resample is not None else self.resample
    rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
    image_mean = image_mean if image_mean is not None else self.image_mean
    image_std = image_std if image_std is not None else self.image_std

    size = size if size is not None else self.size

    # normalize the input to a list of numpy arrays
    if isinstance(images, (list, tuple)):
        images = [np.array(image) for image in images]
    else:
        images = [np.array(images)]

    # sanity check: raw pixel values must be in the [0, 255] range
    for img in images:
        assert img.min() >= 0 and img.max() <= 255, "Image values must be in [0 - 255] range."

    if do_resize and size is None:
        raise ValueError("Size must be specified if do_resize is True.")

    if do_rescale and rescale_factor is None:
        raise ValueError("Rescale factor must be specified if do_rescale is True.")

    if do_resize:
        images = [self.resize(image=image, size=size, resample=resample) for image in images]

    if do_rescale:
        images = [self.rescale(image=image, scale=rescale_factor) for image in images]

    if do_normalize:
        images = [
            self.normalize(image=image, mean=image_mean, std=image_std) for image in images
        ]

    if data_format == "channels_first":
        images = [jnp.transpose(image, (1, 2, 0)) for image in images]

    data = {"pixel_values": images}
    return BatchFeature(data=data, tensor_type="jax")

rescale(image, scale, data_format=None, **kwargs)

Rescale an image by a scale factor. image = image * scale.

Parameters:

Name Type Description Default
image `np.ndarray`

Image to rescale.

required
scale `float`

The scaling factor to rescale pixel values by.

required
data_format `str`, *optional*

The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: - "channels_first": image in (num_channels, height, width) format. - "channels_last": image in (height, width, num_channels) format.

None

Returns:

Type Description
np.ndarray

The rescaled image.

Source code in modules/utils.py
def rescale(
    self,
    image: np.ndarray,
    scale: float,
    data_format: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
    """
    Rescale an image by a scale factor: image = image * scale.
    Args:
        image (`np.ndarray`):
            Image to rescale.
        scale (`float`):
            The scaling factor to rescale pixel values by.
        data_format (`str`, *optional*):
            The channel dimension format for the output image. If unset, the channel dimension
            format of the input image is used. Can be one of:
            - `"channels_first"`: image in (num_channels, height, width) format.
            - `"channels_last"`: image in (height, width, num_channels) format.
    Returns:
        The rescaled image.
    """
    return image * scale

resize(image, size, resample=Image.Resampling.BILINEAR, data_format=None, **kwargs)

Resize an image to (size["height"], size["width"]).

Parameters:

Name Type Description Default
image `np.ndarray`

Image to resize.

required
size `Dict[str, int]`

Dictionary in the format {"height": int, "width": int} specifying the size of the output image.

required
resample Image.Resampling

Image.Resampling filter to use when resizing the image e.g. Image.Resampling.BILINEAR.

Image.Resampling.BILINEAR
data_format `str`, *optional*

The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: - "channels_first": image in (num_channels, height, width) format. - "channels_last": image in (height, width, num_channels) format.

None

Returns:

Type Description
np.ndarray

The resized image.

Source code in modules/utils.py
def resize(
    self,
    image: np.ndarray,
    size: Dict[str, int],
    resample: Image.Resampling = Image.Resampling.BILINEAR,
    data_format: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
    """
    Resize an image to `(size["height"], size["width"])`.
    Args:
        image (`np.ndarray`):
            Image to resize.
        size (`Dict[str, int]`):
            Dictionary in the format `{"height": int, "width": int}` specifying the size
            of the output image.
        resample:
            `Image.Resampling` filter to use when resizing the image e.g.
            `Image.Resampling.BILINEAR`.
        data_format (`str`, *optional*):
            The channel dimension format for the output image. If unset, the channel
            dimension format of the input image is used. Can be one of:
            - `"channels_first"`: image in (num_channels, height, width) format.
            - `"channels_last"`: image in (height, width, num_channels) format.
    Returns:
        The resized image.
    """
    if "height" not in size or "width" not in size:
        raise ValueError(
            "The `size` dictionary must contain the keys `height` and `width`. Got"
            f" {size.keys()}"
        )

    revert_format = False
    if data_format == "channels_first":
        revert_format = True
        image = np.transpose(image, (1, 2, 0))
    else:
        if image.shape[0] == 3:
            revert_format = True
            image = np.transpose(image, (1, 2, 0))

    pil_image = Image.fromarray(np.uint8(image))
    pil_image_resized = pil_image.resize((size["width"], size["height"]), resample=resample)
    image_np = np.array(pil_image_resized)
    assert image_np.shape == (size["height"], size["width"], image.shape[-1])
    image_np = np.transpose(image_np, (2, 0, 1)) if revert_format else image_np
    return image_np

make_img_grid(images, nrows=4)

Make image grid from images.

Parameters:

Name Type Description Default
images np.ndarray

image list to make grid.

required
nrows int

number of rows. Defaults to 4.

4

Returns:

Type Description
np.ndarray

image grid object.

Source code in modules/utils.py
def make_img_grid(images: np.ndarray, nrows: int = 4) -> np.ndarray:
    """Make image grid from images.
    Args:
        images (np.ndarray): image list to make grid.
        nrows (int, optional): number of rows. Defaults to 4.
    Returns:
        image grid object.
    """
    nindex, height, width, intensity = images.shape
    ncols = nindex // nrows
    if nindex != nrows * ncols:
        images = images[: nrows * ncols]
        nindex = nrows * ncols

    # want result.shape = (height*nrows, width*ncols, intensity)
    result = (
        images.reshape(nrows, ncols, height, width, intensity)
        .swapaxes(1, 2)
        .reshape(height * nrows, width * ncols, intensity)
    )
    return result
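For example, sixteen 64x64 images tile into a 4x4 grid:

import numpy as np

images = np.random.rand(16, 64, 64, 3)
grid = make_img_grid(images, nrows=4)
print(grid.shape)  # (256, 256, 3)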

post_processing(image, resize=None)

Post-process an image: un-standardize it and multiply by 255, then clip values to [0, 255] and convert to uint8.

Parameters:

Name Type Description Default
image np.ndarray

image to post process.

required
resize int, optional

target width/height to resize the result to. Defaults to None.

None

Returns:

Type Description
np.ndarray

post processed image.

Source code in modules/utils.py
def post_processing(image: np.ndarray, resize: Optional[int] = None) -> np.ndarray:
    """Post processing for image.
    un standarize image and multiply by 255.
    Next clip values to [0, 255] and convert to uint8.
    Args:
        image (np.ndarray): image to post process.

    Returns:
        post processed image.
    """
    image = image * IMAGENET_STANDARD_STD + IMAGENET_STANDARD_MEAN
    image *= 255.0
    image = np.clip(image, 0.0, 255.0)
    image = image.astype(np.uint8)
    if resize:
        image = Image.fromarray(image)
        image = image.resize((resize, resize))
        image = np.array(image)
    return image
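Typical use is mapping a normalized model output back to a displayable image; the array below is a stand-in for such an output:

import numpy as np

reconstruction = np.random.randn(64, 64, 3)  # stand-in for a normalized HWC model output
recon_uint8 = post_processing(reconstruction, resize=128)
print(recon_uint8.shape, recon_uint8.dtype)  # (128, 128, 3) uint8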

set_seed(seed)

Set seed for random operations.

Source code in modules/utils.py
def set_seed(seed: int) -> None:
    """Set seed for random operations."""
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

VQGAN

GumbelQuantize

Bases: nn.Module

Gumbel-Softmax trick quantizer (Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016): z (continuous) -> z_q (discrete), where z.shape = (batch, height, width, channel).

quantization pipeline
  1. get encoder input (B,H,W,C)
  2. get logits(prob) of input (B,H,W,n_embed)
See

https://arxiv.org/abs/1611.01144

Attributes:

Name Type Description
config VQGANConfig

the config of the model.

dtype jnp.dtype

the dtype of the computation (default: float32).

Config Attributes

n_embed (int): number of embeddings.
embed_dim (int): dimension of embedding.
kl_weight (float): weight of the KL loss.

Source code in modules/vqgan.py
class GumbelQuantize(nn.Module):
    """Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016.
    z (continuous) -> z_q (discrete)
    z.shape = (batch, height, width, channel)
    quantization pipeline:
        1. get encoder input (B,H,W,C)
        2. get logits(prob) of input (B,H,W,n_embed)

    See:
        https://arxiv.org/abs/1611.01144

    Attributes:
        config (VQGANConfig): the config of the model.
        dtype (jnp.dtype): the dtype of the computation (default: float32).

    Config Attributes:
        n_embed (int) : number of embeddings.
        embed_dim (int): dimension of embedding.
        kl_weight (float): weight of kl loss.
    """

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Project the input to the embedding space
        self.proj = nn.Conv(
            self.config.n_embed,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        # Embeddings (codebook)
        init_embedding = jax.nn.initializers.uniform(
            scale=-1.0 / self.config.n_embed, dtype=self.dtype
        )
        self.embedding = nn.Embed(
            self.config.n_embed,
            self.config.embed_dim,
            embedding_init=init_embedding,
            dtype=self.dtype,
        )

    def __call__(self, z: jnp.ndarray) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
        # project z to get logits
        logits = self.proj(z)

        # given logits, sample from the Gumbel-Softmax distribution
        gumbel_rng = self.make_rng("gumbel")
        gumbels = jax.random.gumbel(gumbel_rng, logits.shape, dtype=self.dtype)
        indices_prob = nn.softmax((logits + gumbels) / self.config.gumb_temp, axis=-1)

        # dummy op to init the weights, so we can access them below
        self.embedding(jnp.ones((1, 1), dtype="i4"))

        # get quantized latent vectors
        emb_weights = self.variables["params"]["embedding"]["embedding"]
        z_q = jnp.einsum("bhwp,pd->bhwd", indices_prob, emb_weights).reshape(z.shape)

        # get indices [BxHxWxP] -> [BxH*W]
        indices = jnp.argmax(indices_prob, axis=-1).reshape(z.shape[0], -1)

        # compute the codebook_loss (q_loss)
        qy = nn.softmax(logits)
        q_loss = self.config.kl_weight * jnp.mean(
            jnp.sum(qy * jnp.log(qy * self.config.n_embed + 1e-10), axis=-1)
        )

        # here we return the embeddings, indices and logits (for loss)
        return z_q, q_loss, indices

    @staticmethod
    def get_codebook_entry(
        params: FrozenDict, indices: jnp.ndarray, shape: Optional[Tuple[int, ...]] = None
    ) -> jnp.ndarray:
        """Get the codebook entry for a given index.
        Input is expected to be of shape (batch, num_tokens)"""
        # indices are expected to be of shape (batch, num_tokens)
        # get quantized latent vectors
        B = indices.shape[0]
        indices = indices.reshape(
            -1,
        )
        emb_weights = params["embedding"]["embedding"]
        z_q = jnp.take(emb_weights, indices, axis=0).reshape(B, -1)
        # self.embedding(indices) can't be used here because self.embedding is not accessible from a static method

        if shape is not None:
            z_q = z_q.reshape(shape)

        return z_q
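The core of the quantizer is the Gumbel-Softmax relaxation itself; a standalone sketch of that sampling step (toy logits, temperature 1.0):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
logits = jnp.array([[2.0, 0.5, -1.0]])                  # (batch, n_embed)
gumbels = jax.random.gumbel(rng, logits.shape)
soft_one_hot = jax.nn.softmax((logits + gumbels) / 1.0)  # relaxed (differentiable) sample
hard_index = jnp.argmax(soft_one_hot, axis=-1)           # discrete code index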

get_codebook_entry(params, indices, shape=None) staticmethod

Get the codebook entry for a given index. Input is expected to be of shape (batch, num_tokens)

Source code in modules/vqgan.py
@staticmethod
def get_codebook_entry(
    params: FrozenDict, indices: jnp.ndarray, shape: Optional[Tuple[int, ...]] = None
) -> jnp.ndarray:
    """Get the codebook entry for a given index.
    Input is expected to be of shape (batch, num_tokens)"""
    # indices are expected to be of shape (batch, num_tokens)
    # get quantized latent vectors
    B = indices.shape[0]
    indices = indices.reshape(
        -1,
    )
    emb_weights = params["embedding"]["embedding"]
    z_q = jnp.take(emb_weights, indices, axis=0).reshape(B, -1)
    # z_q = self.embedding(indices) cannot be used here: self.embedding is not accessible from a staticmethod

    if shape is not None:
        z_q = z_q.reshape(shape)

    return z_q
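
Because get_codebook_entry is a staticmethod, no module instance is needed, only the quantizer's parameter sub-tree. A sketch reusing variables and indices from the example above:

# variables["params"] holds the "embedding" table the lookup reads
z_q = GumbelQuantize.get_codebook_entry(
    params=variables["params"],
    indices=indices,          # (batch, num_tokens)
    shape=(1, 16, 16, 64),    # optional reshape, e.g. to (B, H, W, embed_dim)
)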

NLayerDiscriminator

Bases: nn.Module

Defines a PatchGAN discriminator as in Pix2Pix

See

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

Attributes:

Name Type Description
ndf int

the number of filters in the last conv layer

n_layers int

the number of conv layers in the discriminator

output_dim int

the number of channels in the discriminator's output

dtype jnp.dtype

the dtype of the computation (default: float32)

Source code in modules/vqgan.py
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    See:
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    Attributes:
        ndf (int): the number of filters in the last conv layer
        n_layers (int): the number of conv layers in the discriminator
        output_dim (int): the number of channels in the discriminator's output
        dtype: the dtype of the computation (default: float32)
    """

    ndf: int
    n_layers: int
    output_dim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        # input is bx256x256x(nc) return bx30x30x1
        x = nn.Conv(
            self.ndf,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(x)
        x = nn.leaky_relu(x, negative_slope=0.2)
        # downsample
        for n in range(1, self.n_layers):
            nf_mult = min(2**n, 8)
            x = nn.Conv(
                self.ndf * nf_mult,
                kernel_size=(4, 4),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                use_bias=False,
                dtype=self.dtype,
            )(x)
            x = nn.BatchNorm(use_running_average=not train, dtype=self.dtype)(x)
            x = nn.leaky_relu(x, negative_slope=0.2)

        # last downsample
        nf_mult = min(2**n, 8)
        x = nn.Conv(
            self.ndf * nf_mult,
            kernel_size=(4, 4),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            use_bias=False,
            dtype=self.dtype,
        )(x)
        x = nn.BatchNorm(use_running_average=not train, dtype=self.dtype)(x)
        x = nn.leaky_relu(x, negative_slope=0.2)

        # output 1 channel prediction map
        logits = nn.Conv(
            self.output_dim,
            kernel_size=(4, 4),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )(x)
        return logits
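
A shape-check sketch for the discriminator (import path assumed; attribute values illustrative). With a 256x256 input and n_layers=3, the stride/padding arithmetic in __call__ produces a 30x30 grid of per-patch logits:

import jax
import jax.numpy as jnp

from modules.vqgan import NLayerDiscriminator  # assumed import path

disc = NLayerDiscriminator(ndf=64, n_layers=3, output_dim=1)
x = jnp.zeros((1, 256, 256, 3))
variables = disc.init(jax.random.PRNGKey(0), x, train=False)
logits = disc.apply(variables, x, train=False)
print(logits.shape)  # (1, 30, 30, 1): one logit per input patch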

VQGANPreTrainedModel

Bases: FlaxPreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Attributes:

Name Type Description
module_class nn.Module

a class derived from nn.Module that defines the model's core computation.

config_class PretrainedConfig

a class derived from PretrainedConfig

Source code in modules/vqgan.py
class VQGANPreTrainedModel(FlaxPreTrainedModel):
    """An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.

    Attributes:
        module_class (nn.Module): a class derived from nn.Module
            that defines the model's core computation.
        config_class (PretrainedConfig): a class derived from PretrainedConfig

    """

    module_class: nn.Module
    config_class: PretrainedConfig

    def __init__(
        self,
        config: PretrainedConfig = VQGANConfig(),
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Initialize the model.

        Args:
            config (PretrainedConfig, optional): the config of the model. Defaults to VQGANConfig.
            input_shape (Tuple, optional): the input shape of the model.
                Defaults to (1, 256, 256, 3).
            seed (int, optional): the seed of the model. Defaults to 0.
            dtype (jnp.dtype, optional): the dtype of the computation. Defaults to jnp.float32.
            _do_init (bool, optional): whether to initialize the model. Defaults to True.
        """
        self._missing_keys: Set[str] = set()
        if not isinstance(config, self.config_class):
            raise ValueError(f"config: {config} has to be an instance of {self.config_class}")
        if self.module_class is None:
            raise NotImplementedError("module_class should be defined in derived classes")
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        self.seed = seed
        self.dtype = dtype
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        rng: Union[Any, jnp.ndarray],
        input_shape: Tuple,
        params: Optional[FrozenDict[str, Any]] = None,
    ) -> FrozenDict[str, Any]:
        """Initialize the weights of the model. Get the params

        Args:
            rng (Union[Any,jnp.ndarray]): the random number generator.
            input_shape (Tuple): the input shape of the model.
            params (FrozenDict, optional): the params of the model. Defaults to None.

        Returns:
            initialized params of the model.
        """
        # initialize model
        input_x = jnp.zeros(input_shape, dtype=self.dtype)
        params_rng, dropout_rng, gumble_rng = jax.random.split(rng, num=3)
        rngs: Dict[str, Union[Any, jnp.ndarray]] = {
            "params": params_rng,
            "dropout": dropout_rng,
            "gumbel": gumble_rng,
        }

        random_params = self.module.init(rngs, input_x, True)["params"]

        # If params are provided, find uninitialized params and replace them with the provided ones
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # type: ignore
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def encode(
        self,
        pixel_values: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
        gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
        """Encode the input.
        Args:
            pixel_values (jnp.ndarray): the input to the encoder.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
            train (bool, optional): Training or inference mode. Defaults to False.
        """
        # Handle any PRNG if needed
        rngs: Dict[str, Union[Any, jnp.ndarray]] = (
            {"dropout": dropout_rng} if dropout_rng is not None else {}
        )
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
        return self.module.apply(
            {"params": params or self.params},
            pixel_values,
            not train,
            rngs=rngs,
            method=self.module.encode,
        )

    def decode(
        self,
        z: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
        gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
        train: bool = False,
    ) -> jnp.ndarray:
        """Decode the latent vector.

        Args:
            z (jnp.ndarray): the latent vector.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
            train (bool, optional): Training or inference mode. Defaults to False.

        Returns:
            the decoded image.
        """
        # Handle any PRNG if needed
        rngs: Dict[str, Union[Any, jnp.ndarray]] = (
            {"dropout": dropout_rng} if dropout_rng is not None else {}
        )
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
        return self.module.apply(
            {"params": params or self.params},
            z,
            not train,
            rngs=rngs,
            method=self.module.decode,
        )

    def decode_code(
        self,
        indices: jnp.ndarray,
        z_shape: Tuple[int, ...],
        params: Optional[FrozenDict] = None,
    ) -> jnp.ndarray:
        """Decode the indices.

        Args:
            indices (jnp.ndarray): the indices.
            z_shape (Tuple[int, ...]): the shape of the latent vector.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.

        Returns:
            the decoded image from indices.
        """
        return self.module.apply(
            {"params": params or self.params},
            indices,
            z_shape,
            method=self.module.decode_code,
        )

    def update_temperature(self, temperature: float, params: Optional[FrozenDict] = None) -> float:
        """Update the temperature of the model.
        Args:
            temperature (float): the temperature to update to.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        Returns:
            the updated temperature.
        """
        new_temperature = self.module.apply(
            {"params": params or self.params},
            temperature,
            method=self.module.update_temperature,
        )
        return new_temperature

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        dropout_rng: Optional[jnp.ndarray] = None,
        gumble_rng: Optional[jnp.ndarray] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
        """Encode and decode the input.

        Args:
            pixel_values (jnp.ndarray): the input to the encoder.
            params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
            dropout_rng (Optional[jnp.ndarray], optional): the dropout rng. Defaults to None.
            gumble_rng (Optional[jnp.ndarray], optional): the gumbel rng. Defaults to None.
                If gumble_rng is None, the default rng is used, producing deterministic results.
            train (bool, optional): Training or inference mode. Defaults to False.

        Returns:
                the reconstructed image,
                the quantized latent vector,
                the codebook loss,
                the indices of the codebook entries.
        """
        # Check dtype
        pixel_values = (
            pixel_values.astype(self.dtype) if pixel_values.dtype != self.dtype else pixel_values
        )
        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
        rngs["gumbel"] = gumble_rng if gumble_rng is not None else self.key
        return self.module.apply(
            {"params": params or self.params}, pixel_values, not train, rngs=rngs
        )

__call__(pixel_values, params=None, dropout_rng=None, gumble_rng=None, train=False)

Encode and decode the input.

Parameters:

Name Type Description Default
pixel_values jnp.ndarray

the input to the encoder.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Optional[jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Optional[jnp.ndarray]

the gumbel rng. Defaults to None. If gumble_rng is None, the default rng is used, producing deterministic results.

None
train bool

Training or inference mode. Defaults to False.

False

Returns:

Type Description
jnp.ndarray

the reconstructed image,

jnp.ndarray

the quantized latent vector,

float

the codebook loss,

jnp.ndarray

the indices of the codebook entries.

Source code in modules/vqgan.py
def __call__(
    self,
    pixel_values: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[jnp.ndarray] = None,
    gumble_rng: Optional[jnp.ndarray] = None,
    train: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
    """Encode and decode the input.

    Args:
        pixel_values (jnp.ndarray): the input to the encoder.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Optional[jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Optional[jnp.ndarray], optional): the gumbel rng. Defaults to None.
            If gumble_rng is None, the default rng is used, producing deterministic results.
        train (bool, optional): Training or inference mode. Defaults to False.

    Returns:
            the reconstructed image,
            the quantized latent vector,
            the codebook loss,
            the indices of the codebook entries.
    """
    # Check dtype
    pixel_values = (
        pixel_values.astype(self.dtype) if pixel_values.dtype != self.dtype else pixel_values
    )
    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else self.key
    return self.module.apply(
        {"params": params or self.params}, pixel_values, not train, rngs=rngs
    )
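
VQGANPreTrainedModel is abstract (module_class and config_class are filled in by subclasses), so the sketch below goes through the concrete VQModel documented further down this page; import paths are assumed. Omitting gumble_rng makes the pass deterministic, as noted in the docstring:

import jax
import jax.numpy as jnp

from modules.config import VQGANConfig  # assumed import path
from modules.vqgan import VQModel       # assumed import path

model = VQModel(config=VQGANConfig(), input_shape=(1, 256, 256, 3))
pixel_values = jnp.zeros((1, 256, 256, 3))

# deterministic pass: the gumbel stream falls back to self.key
x_recon, z_q, q_loss, indices = model(pixel_values)

# stochastic training pass: supply explicit rngs
dropout_rng, gumble_rng = jax.random.split(jax.random.PRNGKey(0))
x_recon, z_q, q_loss, indices = model(
    pixel_values, dropout_rng=dropout_rng, gumble_rng=gumble_rng, train=True
)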

__init__(config=VQGANConfig(), input_shape=(1, 256, 256, 3), seed=0, dtype=jnp.float32, _do_init=True, **kwargs)

Initialize the model.

Parameters:

Name Type Description Default
config PretrainedConfig

the config of the model. Defaults to VQGANConfig.

VQGANConfig()
input_shape Tuple

the input shape of the model. Defaults to (1, 256, 256, 3).

(1, 256, 256, 3)
seed int

the seed of the model. Defaults to 0.

0
dtype jnp.dtype

the dtype of the computation. Defaults to jnp.float32.

jnp.float32
_do_init bool

whether to initialize the model. Defaults to True.

True
Source code in modules/vqgan.py
def __init__(
    self,
    config: PretrainedConfig = VQGANConfig(),
    input_shape: Tuple = (1, 256, 256, 3),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
    **kwargs,
):
    """Initialize the model.

    Args:
        config (PretrainedConfig, optional): the config of the model. Defaults to VQGANConfig.
        input_shape (Tuple, optional): the input shape of the model.
            Defaults to (1, 256, 256, 3).
        seed (int, optional): the seed of the model. Defaults to 0.
        dtype (jnp.dtype, optional): the dtype of the computation. Defaults to jnp.float32.
        _do_init (bool, optional): whether to initialize the model. Defaults to True.
    """
    self._missing_keys: Set[str] = set()
    if not isinstance(config, self.config_class):
        raise ValueError(f"config: {config} has to be an instance of {self.config_class}")
    if self.module_class is None:
        raise NotImplementedError("module_class should be defined in derived classes")
    module = self.module_class(config=config, dtype=dtype, **kwargs)
    self.seed = seed
    self.dtype = dtype
    super().__init__(
        config,
        module,
        input_shape=input_shape,
        seed=seed,
        dtype=dtype,
        _do_init=_do_init,
    )

decode(z, params=None, dropout_rng=None, gumble_rng=None, train=False)

Decode the latent vector.

Parameters:

Name Type Description Default
z jnp.ndarray

the latent vector.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Union[Any, jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Union[Any, jnp.ndarray]

the gumbel rng. Defaults to None.

None
train bool

Training or inference mode. Defaults to False.

False

Returns:

Type Description
jnp.ndarray

the decoded image.

Source code in modules/vqgan.py
def decode(
    self,
    z: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
    gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
    train: bool = False,
) -> jnp.ndarray:
    """Decode the latent vector.

    Args:
        z (jnp.ndarray): the latent vector.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
        train (bool, optional): Training or inference mode. Defaults to False.

    Returns:
        the decoded image.
    """
    # Handle any PRNG if needed
    rngs: Dict[str, Union[Any, jnp.ndarray]] = (
        {"dropout": dropout_rng} if dropout_rng is not None else {}
    )
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
    return self.module.apply(
        {"params": params or self.params},
        z,
        not train,
        rngs=rngs,
        method=self.module.decode,
    )

decode_code(indices, z_shape, params=None)

Decode the indices.

Parameters:

Name Type Description Default
indices jnp.ndarray

the indices.

required
z_shape Tuple[int, ...]

the shape of the latent vector.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None

Returns:

Type Description
jnp.ndarray

the decoded image from indices.

Source code in modules/vqgan.py
def decode_code(
    self,
    indices: jnp.ndarray,
    z_shape: Tuple[int, ...],
    params: Optional[FrozenDict] = None,
) -> jnp.ndarray:
    """Decode the indices.

    Args:
        indices (jnp.ndarray): the indices.
        z_shape (Tuple[int, ...]): the shape of the latent vector.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.

    Returns:
        the decoded image from indices.
    """
    return self.module.apply(
        {"params": params or self.params},
        indices,
        z_shape,
        method=self.module.decode_code,
    )
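
A sketch of reconstructing images from stored codebook indices, reusing model from the VQModel sketch above (the 16x16 latent grid and 256-dim embedding are illustrative; use the shapes your config actually produces):

import jax.numpy as jnp

indices = jnp.zeros((1, 256), dtype=jnp.int32)  # (batch, num_tokens)
# z_shape is the latent grid the decoder expects: (B, H, W, embed_dim)
images = model.decode_code(indices, z_shape=(1, 16, 16, 256))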

encode(pixel_values, params=None, dropout_rng=None, gumble_rng=None, train=False)

Encode the input.

Parameters:

Name Type Description Default
pixel_values jnp.ndarray

the input to the encoder.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None
dropout_rng Union[Any, jnp.ndarray]

the dropout rng. Defaults to None.

None
gumble_rng Union[Any, jnp.ndarray]

the gumbel rng. Defaults to None.

None
train bool

Training or inference mode. Defaults to False.

False
Source code in modules/vqgan.py
def encode(
    self,
    pixel_values: jnp.ndarray,
    params: Optional[FrozenDict] = None,
    dropout_rng: Optional[Union[Any, jnp.ndarray]] = None,
    gumble_rng: Optional[Union[Any, jnp.ndarray]] = None,
    train: bool = False,
) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
    """Encode the input.
    Args:
        pixel_values (jnp.ndarray): the input to the encoder.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
        dropout_rng (Union[Any,jnp.ndarray], optional): the dropout rng. Defaults to None.
        gumble_rng (Union[Any,jnp.ndarray], optional): the gumbel rng. Defaults to None.
        train (bool, optional): Training or inference mode. Defaults to False.
    """
    # Handle any PRNG if needed
    rngs: Dict[str, Union[Any, jnp.ndarray]] = (
        {"dropout": dropout_rng} if dropout_rng is not None else {}
    )
    rngs["gumbel"] = gumble_rng if gumble_rng is not None else {}
    return self.module.apply(
        {"params": params or self.params},
        pixel_values,
        not train,
        rngs=rngs,
        method=self.module.encode,
    )
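
A sketch reusing model and pixel_values from the VQModel example above; encode returns the post-quantizer latents together with the codebook loss and token indices:

z_q, q_loss, indices = model.encode(pixel_values)
# z_q: quantized latents; q_loss: codebook loss; indices: (batch, num_tokens)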

init_weights(rng, input_shape, params=None)

Initialize the weights of the model and return the params.

Parameters:

Name Type Description Default
rng Union[Any, jnp.ndarray]

the random number generator.

required
input_shape Tuple

the input shape of the model.

required
params FrozenDict

the params of the model. Defaults to None.

None

Returns:

Type Description
FrozenDict[str, Any]

initialized params of the model.

Source code in modules/vqgan.py
def init_weights(
    self,
    rng: Union[Any, jnp.ndarray],
    input_shape: Tuple,
    params: Optional[FrozenDict[str, Any]] = None,
) -> FrozenDict[str, Any]:
    """Initialize the weights of the model. Get the params

    Args:
        rng (Union[Any,jnp.ndarray]): the random number generator.
        input_shape (Tuple): the input shape of the model.
        params (FrozenDict, optional): the params of the model. Defaults to None.

    Returns:
        initialized params of the model.
    """
    # initialize model
    input_x = jnp.zeros(input_shape, dtype=self.dtype)
    params_rng, dropout_rng, gumble_rng = jax.random.split(rng, num=3)
    rngs: Dict[str, Union[Any, jnp.ndarray]] = {
        "params": params_rng,
        "dropout": dropout_rng,
        "gumbel": gumble_rng,
    }

    random_params = self.module.init(rngs, input_x, True)["params"]

    # If params are provided, find uninitialized params and replace them with the provided ones
    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]  # type: ignore
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params

update_temperature(temperature, params=None)

Update the temperature of the model.

Parameters:

Name Type Description Default
temperature float

the temperature to update to.

required
params Optional[FrozenDict]

the params of the model. Defaults to None.

None

Returns:

Type Description
float

the updated temperature.

Source code in modules/vqgan.py
def update_temperature(self, temperature: float, params: Optional[FrozenDict] = None) -> float:
    """Update the temperature of the model.
    Args:
        temperature (float): the temperature to update to.
        params (Optional[FrozenDict], optional): the params of the model. Defaults to None.
    Returns:
        the updated temperature.
    """
    new_temperature = self.module.apply(
        {"params": params or self.params},
        temperature,
        method=self.module.update_temperature,
    )
    return new_temperature

VQGanDiscriminator

Bases: FlaxPreTrainedModel

VQGAN discriminator model.

Attributes:

Name Type Description
module_class nn.Module

the discriminator module class (NLayerDiscriminator).

Source code in modules/vqgan.py
class VQGanDiscriminator(FlaxPreTrainedModel):
    """VQGAN discriminator model.

    Attributes:
        module_class (nn.Module): the discriminator module class (NLayerDiscriminator).
    """

    module_class: nn.Module = NLayerDiscriminator  # type: ignore

    def __init__(
        self,
        config: DiscConfig = DiscConfig(),
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        self._missing_keys: Set[str] = set()
        module = self.module_class(
            ndf=config.ndf,
            n_layers=config.n_layers,
            output_dim=config.output_last_dim,
            dtype=dtype,
            **kwargs,
        )
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self,
        params_rng: jnp.ndarray,
        input_shape: Tuple,
        params: Optional[FrozenDict[str, Any]] = None,
    ) -> FrozenDict[str, Any]:
        # initialize model
        input_x = jnp.zeros(input_shape, dtype=self.dtype)
        random_params = self.module.init(params_rng, input_x, True)

        # If params are provided, find uninitialized params and replace them with the provided ones
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # type: ignore
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input: jnp.ndarray,
        params: Optional[FrozenDict] = None,
        batch_stats: Optional[FrozenDict] = None,
        train: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
        # Handle any PRNG if needed
        if batch_stats is not None and params is not None:
            dict_params = {"params": params, "batch_stats": batch_stats}
        elif batch_stats is not None:
            dict_params = {"params": self.params["params"], "batch_stats": batch_stats}
        elif params is not None:
            dict_params = {"params": params, "batch_stats": self.params["batch_stats"]}
        else:
            dict_params = {
                "params": self.params["params"],
                "batch_stats": self.params["batch_stats"],
            }
        return self.module.apply(
            dict_params, input, train=train, mutable=["batch_stats"] if train else False
        )
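
A usage sketch (import path assumed). The stored self.params of this model holds both the "params" and "batch_stats" collections; with train=True the BatchNorm statistics are mutable, so the updated state is returned alongside the logits:

import jax.numpy as jnp

from modules.vqgan import VQGanDiscriminator  # assumed import path

disc = VQGanDiscriminator()  # DiscConfig() defaults
images = jnp.zeros((1, 256, 256, 3))

logits = disc(images)                         # inference: batch_stats read only
logits, new_state = disc(images, train=True)  # training: batch_stats mutated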

VQModel

Bases: VQGANPreTrainedModel

VQ-VAE model from pre-trained VQGAN.

Source code in modules/vqgan.py
class VQModel(VQGANPreTrainedModel):
    """VQ-VAE model from pre-trained VQGAN."""

    module_class = VQModule  # type: ignore
    config_class = VQGANConfig
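
Since VQModel derives from FlaxPreTrainedModel, the usual save_pretrained/from_pretrained round trip applies. A sketch with a hypothetical checkpoint directory:

from modules.vqgan import VQModel  # assumed import path

model = VQModel()                      # fresh model with VQGANConfig() defaults
model.save_pretrained("./vqgan-ckpt")  # "./vqgan-ckpt" is a hypothetical path
model = VQModel.from_pretrained("./vqgan-ckpt")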

VQModule

Bases: nn.Module

VQ-VAE module.

See

https://arxiv.org/abs/1711.00937v2

Attributes:

Name Type Description
config VQGANConfig

the config of the model.

dtype jnp.dtype

the dtype of the computation (default: float32).

Source code in modules/vqgan.py
class VQModule(nn.Module):
    """VQ-VAE module.
    See:
        https://arxiv.org/abs/1711.00937v2

    Attributes:
        config (VQGANConfig): the config of the model.
        dtype (jnp.dtype): the dtype of the computation (default: float32).
    """

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Setup the VQ-VAE module."""
        # Set activation function
        act_fn: Callable = ACTFUN[self.config.act_name]  # type: ignore
        # Encoder
        self.encoder = Encoder(self.config, act_fn=act_fn, dtype=self.dtype)
        # Map last channel of encoder to embedding dim for VQ
        self.pre_quantizer = nn.Conv(
            self.config.embed_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        # Which quantizer to use
        self.quantizer: Union[VectorQuantizer, GumbelQuantize]
        if self.config.use_gumbel:
            self.quantizer = GumbelQuantize(self.config, dtype=self.dtype)
        else:
            self.quantizer = VectorQuantizer(self.config, dtype=self.dtype)
        # Map last channel of VQ to z channels dim
        self.post_quantizer = nn.Conv(
            self.config.z_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        # Decoder
        self.decoder = Decoder(self.config, act_fn=act_fn, dtype=self.dtype)

    def encode(
        self, x: jnp.ndarray, deterministic: bool = True
    ) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
        """Encode the input.
        Args:
            x (jnp.ndarray): the input to the encoder.
        Returns:
            the encoded input, the loss and the indices.
        """
        # Encoder
        z = self.encoder(x, deterministic=deterministic)
        # Pre-quantizer
        z = self.pre_quantizer(z)
        # Quantizer
        z_q, q_loss, indices = self.quantizer(z)
        # Post-quantizer
        z_q = self.post_quantizer(z_q)
        return z_q, q_loss, indices

    def decode(self, z_q: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """Decode the quantized latent vector.

        Args:
            z_q (jnp.ndarray): the quantized latent vector.
            deterministic (bool, optional): for Dropout. Defaults to True.

        Returns:
            the reconstructed image.
        """
        # Post-quantizer
        z_q = self.post_quantizer(z_q)
        # Decoder
        x_recon = self.decoder(z_q, deterministic=deterministic)
        return x_recon

    def decode_code(self, code: jnp.ndarray, z_shape: Tuple[int, ...]) -> jnp.ndarray:
        """Decode already created z_code"""
        params = self.variables["params"]["quantizer"]
        z_q = self.quantizer.get_codebook_entry(params=params, indices=code, shape=z_shape)
        x = self.decode(z_q, deterministic=True)
        return x

    def update_temperature(self, temperature: float) -> float:
        """Update the temperature of the Gumbel-Softmax distribution.

        Args:
            temperature (float): the new temperature of the Gumbel-Softmax distribution
        Returns:
            the new temperature of the Gumbel-Softmax distribution
        """
        self.quantizer.config.gumb_temp = temperature
        return self.quantizer.config.gumb_temp

    def __call__(
        self, x: jnp.ndarray, deterministic: bool = True
    ) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray]:
        z_q, q_loss, indices = self.encode(x, deterministic=deterministic)
        x_recon = self.decode(z_q, deterministic=deterministic)
        return x_recon, z_q, q_loss, indices

decode(z_q, deterministic=True)

Decode the quantized latent vector.

Parameters:

Name Type Description Default
z_q jnp.ndarray

the quantized latent vector.

required
deterministic bool

for Dropout. Defaults to True.

True

Returns:

Type Description
jnp.ndarray

the reconstructed image.

Source code in modules/vqgan.py
def decode(self, z_q: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
    """Decode the quantized latent vector.

    Args:
        z_q (jnp.ndarray): the quantized latent vector.
        deterministic (bool, optional): for Dropout. Defaults to True.

    Returns:
        the reconstructed image.
    """
    # Post-quantizer
    z_q = self.post_quantizer(z_q)
    # Decoder
    x_recon = self.decoder(z_q, deterministic=deterministic)
    return x_recon

decode_code(code, z_shape)

Decode an already created z_code.

Source code in modules/vqgan.py
def decode_code(self, code: jnp.ndarray, z_shape: Tuple[int, ...]) -> jnp.ndarray:
    """Decode already created z_code"""
    params = self.variables["params"]["quantizer"]
    z_q = self.quantizer.get_codebook_entry(params=params, indices=code, shape=z_shape)
    x = self.decode(z_q, deterministic=True)
    return x

encode(x, deterministic=True)

Encode the input.

Parameters:

Name Type Description Default
x jnp.ndarray

the input to the encoder.

required

Returns:

Type Description
Tuple[jnp.ndarray, float, jnp.ndarray]

the encoded input, the loss and the indices.

Source code in modules/vqgan.py
def encode(
    self, x: jnp.ndarray, deterministic: bool = True
) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
    """Encode the input.
    Args:
        x (jnp.ndarray): the input to the encoder.
    Returns:
        the encoded input, the loss and the indices.
    """
    # Encoder
    z = self.encoder(x, deterministic=deterministic)
    # Pre-quantizer
    z = self.pre_quantizer(z)
    # Quantizer
    z_q, q_loss, indices = self.quantizer(z)
    # Post-quantizer
    z_q = self.post_quantizer(z_q)
    return z_q, q_loss, indices

setup()

Setup the VQ-VAE module.

Source code in modules/vqgan.py
def setup(self):
    """Setup the VQ-VAE module."""
    # Set activation function
    act_fn: Callable = ACTFUN[self.config.act_name]  # type: ignore
    # Encoder
    self.encoder = Encoder(self.config, act_fn=act_fn, dtype=self.dtype)
    # Map last channel of encoder to embedding dim for VQ
    self.pre_quantizer = nn.Conv(
        self.config.embed_dim,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )
    # Which quantizer to use
    self.quantizer: Union[VectorQuantizer, GumbelQuantize]
    if self.config.use_gumbel:
        self.quantizer = GumbelQuantize(self.config, dtype=self.dtype)
    else:
        self.quantizer = VectorQuantizer(self.config, dtype=self.dtype)
    # Map last channel of VQ to z channels dim
    self.post_quantizer = nn.Conv(
        self.config.z_channels,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype=self.dtype,
    )
    # Decoder
    self.decoder = Decoder(self.config, act_fn=act_fn, dtype=self.dtype)

update_temperature(temperature)

Update the temperature of the Gumbel-Softmax distribution.

Parameters:

Name Type Description Default
temperature float

the new temperature of the Gumbel-Softmax distribution

required

Returns:

Type Description
float

the new temperature of the Gumbel-Softmax distribution

Source code in modules/vqgan.py
def update_temperature(self, temperature: float) -> float:
    """Update the temperature of the Gumbel-Softmax distribution.

    Args:
        temperature (float): the new temperature of the Gumbel-Softmax distribution
    Returns:
        the new temperature of the Gumbel-Softmax distribution
    """
    self.quantizer.config.gumb_temp = temperature
    return self.quantizer.config.gumb_temp
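
update_temperature is only meaningful when config.use_gumbel is True, and is typically driven by a schedule so that the Gumbel-Softmax relaxation hardens over training. A sketch of exponential annealing through the VQGANPreTrainedModel wrapper, reusing model from the VQModel sketches above (all constants illustrative):

import numpy as np

t_init, t_min, decay = 1.0, 1.0 / 16, 1e-4  # illustrative schedule constants
for step in range(10_000):
    temperature = max(t_min, t_init * float(np.exp(-decay * step)))
    model.update_temperature(temperature)
    # ... run one training step at this temperature ...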

VectorQuantizer

Bases: nn.Module

Discretization bottleneck part of the VQ-VAE. The module takes the latent vector z (the encoder output) and maps it to a discrete one-hot vector that is the index of the closest embedding vector e_j. z (continuous) -> z_q (discrete), z.shape = (batch, height, width, channel)

quantization pipeline
  1. get encoder input (B,H,W,C)
  2. flatten input to (B*H*W,C)
See

https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py

Attributes:

Name Type Description
config VQGANConfig

the config of the model.

dtype jnp.dtype

the dtype of the computation for embeddings (default: float32).

Config Attributes

n_embed (int): number of embeddings. embed_dim (int): dimension of each embedding vector. beta (float): weight of the commitment loss.

Source code in modules/vqgan.py
class VectorQuantizer(nn.Module):
    """
    Discretization bottleneck part of the VQ-VAE.
    The module takes the latent vector z (the encoder output)
    and maps it to a discrete one-hot vector
    that is the index of the closest embedding vector e_j.
    z (continuous) -> z_q (discrete)
    z.shape = (batch, height, width, channel)
    quantization pipeline:
        1. get encoder input (B,H,W,C)
        2. flatten input to (B*H*W,C)

    See:
        https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py

    Attributes:
        config (VQGANConfig): the config of the model.
        dtype (jnp.dtype): the dtype of the computation for embeddings (default: float32).

    Config Attributes:
        n_embed (int): number of embeddings.
        embed_dim (int): dimension of each embedding vector.
        beta (float): weight of the commitment loss.
    """

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        init_embedding = jax.nn.initializers.uniform(
            scale=-1.0 / self.config.n_embed, dtype=self.dtype
        )
        self.embedding = nn.Embed(
            self.config.n_embed,
            self.config.embed_dim,
            embedding_init=init_embedding,
            dtype=self.dtype,
        )

    def __call__(self, z: jnp.ndarray) -> Tuple[jnp.ndarray, float, jnp.ndarray]:
        # flatten z
        z_flatten = z.reshape(-1, self.config.embed_dim)

        # dummy op to init the weights, so we can access them below
        self.embedding(jnp.ones((1, 1), dtype="i4"))

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        emb_weights = self.variables["params"]["embedding"]["embedding"]
        distance = (
            jnp.sum(z_flatten**2, axis=1, keepdims=True)
            + jnp.sum(emb_weights**2, axis=1)
            - 2 * jnp.dot(z_flatten, emb_weights.T)
        )

        # get quantized latent vectors
        min_encoding_indices = jnp.argmin(distance, axis=1)
        z_q = self.embedding(min_encoding_indices).reshape(z.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)

        # compute the codebook_loss (q_loss)
        q_loss = self.config.beta * jnp.mean((jax.lax.stop_gradient(z_q) - z) ** 2) + jnp.mean(
            (z_q - jax.lax.stop_gradient(z)) ** 2
        )

        # here we return the embeddings and indices
        return z_q, q_loss, min_encoding_indices

    @staticmethod
    def get_codebook_entry(
        params: FrozenDict, indices: jnp.ndarray, shape: Optional[Tuple[int, ...]] = None
    ) -> jnp.ndarray:
        """Get the codebook entry for a given index.
        Input is expected to be of shape (batch, num_tokens)"""
        # indices are expected to be of shape (batch, num_tokens)
        # get quantized latent vectors
        B = indices.shape[0]
        indices = indices.reshape(
            -1,
        )
        emb_weights = params["embedding"]["embedding"]
        z_q = jnp.take(emb_weights, indices, axis=0).reshape(B, -1)
        # z_q = self.embedding(indices) cannot be used here: self.embedding is not accessible from a staticmethod

        if shape is not None:
            z_q = z_q.reshape(shape)

        return z_q
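
A minimal usage sketch mirroring the GumbelQuantize example (import paths and VQGANConfig kwargs are assumptions). No extra RNG stream is needed here: quantization is a hard nearest-neighbour lookup based on the expanded squared distance (z - e)^2 = z^2 + e^2 - 2*e*z noted in the comments above:

import jax
import jax.numpy as jnp

from modules.config import VQGANConfig     # assumed import path
from modules.vqgan import VectorQuantizer  # assumed import path

config = VQGANConfig(n_embed=512, embed_dim=64)  # assumed constructor kwargs
quantizer = VectorQuantizer(config)

z = jnp.zeros((1, 16, 16, 64))  # channel dim must equal embed_dim
variables = quantizer.init(jax.random.PRNGKey(0), z)
z_q, q_loss, indices = quantizer.apply(variables, z)
# z_q: (1, 16, 16, 64) nearest codebook vectors; indices: (1, 256)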

get_codebook_entry(params, indices, shape=None) staticmethod

Get the codebook entry for a given index. Input is expected to be of shape (batch, num_tokens)

Source code in modules/vqgan.py
@staticmethod
def get_codebook_entry(
    params: FrozenDict, indices: jnp.ndarray, shape: Optional[Tuple[int, ...]] = None
) -> jnp.ndarray:
    """Get the codebook entry for a given index.
    Input is expected to be of shape (batch, num_tokens)"""
    # indices are expected to be of shape (batch, num_tokens)
    # get quantized latent vectors
    B = indices.shape[0]
    indices = indices.reshape(
        -1,
    )
    emb_weights = params["embedding"]["embedding"]
    z_q = jnp.take(emb_weights, indices, axis=0).reshape(B, -1)
    # z_q = self.embedding(indices) cannot be used here: self.embedding is not accessible from a staticmethod

    if shape is not None:
        z_q = z_q.reshape(shape)

    return z_q