PhysAugNet 1.0.1
VQ-VAE powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
vqvae.py
Go to the documentation of this file.
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
class VectorQuantizer(nn.Module):
    """Snap each spatial feature vector to its nearest codebook embedding.

    The codebook is an ``nn.Embedding`` of ``num_embeddings`` vectors of size
    ``embedding_dim``, initialised uniformly in
    ``[-1/num_embeddings, 1/num_embeddings]``.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension of the tensor passed to ``forward``.
        commitment_cost: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_embeddings=512, embedding_dim=64, commitment_cost=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, x):
        """Quantize a channels-first 4-D feature map.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == embedding_dim``.

        Returns:
            A tuple ``(quantized, loss, indices)`` where ``quantized`` has the
            same shape as ``x`` (values come from the codebook, gradients flow
            straight through to ``x``), ``loss`` is the scalar codebook +
            commitment loss, and ``indices`` is a ``(B, H, W)`` integer tensor
            of the selected codebook entries.
        """
        batch, channels, height, width = x.shape

        # Channels-last so that every row is the feature vector of one pixel.
        flat_x = x.permute(0, 2, 3, 1).reshape(-1, channels)

        # Euclidean distance from every pixel vector to every codebook vector.
        distances = torch.cdist(flat_x, self.embedding.weight)
        encoding_indices = torch.argmin(distances, dim=1)

        # Direct lookup is equivalent to a one-hot matmul but avoids the dense
        # (N, num_embeddings) temporary and always matches the codebook dtype.
        quantized = self.embedding(encoding_indices)

        # The rows are in (B, H, W, C) order, so restore that layout first and
        # only then move channels back to dim 1. Viewing the rows directly as
        # (B, C, H, W) would scramble channel and spatial positions.
        quantized = quantized.view(batch, height, width, channels)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        # Commitment term moves the encoder output toward the (frozen) codes;
        # codebook term moves the codes toward the (frozen) encoder output.
        commitment_loss = F.mse_loss(quantized.detach(), x)
        codebook_loss = F.mse_loss(quantized, x.detach())
        loss = commitment_loss * self.commitment_cost + codebook_loss

        # Straight-through estimator: forward value is the quantized tensor,
        # backward gradient is passed to x unchanged.
        quantized = x + (quantized - x).detach()

        return quantized, loss, encoding_indices.view(batch, height, width)
22
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: encoder -> vector quantizer -> decoder.

    Args:
        in_channels: Number of channels of the input (and reconstructed) image.
        embedding_dim: Channel width of the latent map and size of each
            codebook vector.
        num_embeddings: Number of entries in the quantizer's codebook.
    """

    def __init__(self, in_channels=3, embedding_dim=64, num_embeddings=512):
        super().__init__()

        # Two stride-2 convolutions: in_channels -> 32 -> embedding_dim.
        encoder_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, embedding_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.vq = VectorQuantizer(num_embeddings, embedding_dim)

        # Two stride-2 transposed convolutions mirror the encoder, followed by
        # a 3x3 projection to in_channels; the final sigmoid bounds the output
        # to the (0, 1) range.
        decoder_layers = [
            nn.ConvTranspose2d(embedding_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode, quantize and reconstruct ``x``.

        Args:
            x: Input batch passed directly to the convolutional encoder.

        Returns:
            A tuple ``(recon, vq_loss, quantized)``: the decoder output, the
            loss reported by the quantizer, and the quantized latent map that
            was fed to the decoder. The code indices returned by the quantizer
            are discarded.
        """
        latents = self.encoder(x)
        codes, codebook_loss, _indices = self.vq(latents)
        reconstruction = self.decoder(codes)
        return reconstruction, codebook_loss, codes
forward(self, x)
Definition vqvae.py:37
__init__(self, in_channels=3, embedding_dim=64, num_embeddings=512)
Definition vqvae.py:24
__init__(self, num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
Definition vqvae.py:6
forward(self, x)
Definition vqvae.py:12