PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
train.py
Go to the documentation of this file.
1import torch
2from torch.utils.data import DataLoader
3from torchvision import transforms, datasets
4from .vqvae import VQVAE
5from physaug.utils.logger import setup_logger
6from physaug.utils.io import save_image
7
def __init__(self, config):
    """Initialise the trainer: device, logger, VQ-VAE model, optimiser and loss.

    Args:
        config: Nested configuration mapping. This method reads
            ``config["log_dir"]`` and ``config["vqvae"]["learning_rate"]``;
            the whole mapping is kept on ``self.cfg`` for the other methods.
    """
    self.cfg = config

    # Train on the GPU when torch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        self.device = "cuda"
    else:
        self.device = "cpu"

    self.logger = setup_logger("vqvae_trainer", config["log_dir"])

    network = VQVAE()
    self.model = network.to(self.device)

    # The configured learning rate may arrive as a string (presumably a YAML
    # value such as ``1e-4`` that the loader did not parse as a float), so
    # coerce it explicitly before handing it to the optimiser.
    learning_rate = float(config["vqvae"]["learning_rate"])
    self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=learning_rate)

    # Reconstruction objective: mean squared error, applied in ``train``
    # between the model's reconstruction and the input batch.
    self.criterion = torch.nn.MSELoss()
17
def get_dataloader(self):
    """Build the shuffled training ``DataLoader`` over an image-folder dataset.

    Reads ``self.cfg["dataset_dir"]`` (root passed to ``ImageFolder``) and,
    from ``self.cfg["vqvae"]``, the keys ``image_size``, ``batch_size`` and
    ``num_workers``.

    ``image_size`` may be either an ``(height, width)`` sequence or a single
    int. An int is expanded to a square ``(size, size)`` output.

    Returns:
        A ``DataLoader`` yielding ``(images, targets)`` batches, where images
        are resized and converted with ``transforms.ToTensor()``.
    """
    vq_cfg = self.cfg["vqvae"]

    size = vq_cfg["image_size"]
    # ``tuple(int)`` raises TypeError. Passing a bare int to ``Resize`` would
    # only match the shorter edge, giving non-square images of varying shape
    # that cannot be batched. Expand an int to an explicit square size.
    if isinstance(size, int):
        size = (size, size)

    transform = transforms.Compose([
        transforms.Resize(tuple(size)),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=self.cfg["dataset_dir"], transform=transform)
    return DataLoader(
        dataset,
        batch_size=vq_cfg["batch_size"],
        shuffle=True,
        num_workers=vq_cfg["num_workers"],
    )
22
def train(self):
    """Run the VQ-VAE training loop for ``cfg["vqvae"]["num_epochs"]`` epochs.

    Each batch is optimised on the sum of the reconstruction loss
    (``self.criterion(recon, imgs)``) and the ``vq_loss`` term returned by the
    model. After every epoch the mean batch loss is logged.

    Every ``save_interval`` epochs, the model's ``state_dict`` is written to
    ``<checkpoint_dir>/vqvae_<epoch>.pth``. The checkpoint directory is
    created if it does not exist. A ``save_interval`` of 0 disables periodic
    checkpointing instead of raising ``ZeroDivisionError``.
    """
    import os  # function-scope import: only needed for checkpoint paths

    dataloader = self.get_dataloader()
    self.model.train()

    vq_cfg = self.cfg["vqvae"]
    num_epochs = vq_cfg["num_epochs"]
    save_interval = vq_cfg["save_interval"]

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for imgs, _ in dataloader:
            imgs = imgs.to(self.device)
            recon, vq_loss, _ = self.model(imgs)
            recon_loss = self.criterion(recon, imgs)
            loss = recon_loss + vq_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Count batches instead of using len(dataloader). This avoids
        # ZeroDivisionError on an empty loader and also works for loaders
        # that do not define __len__.
        avg_loss = running_loss / max(num_batches, 1)
        self.logger.info(f"Epoch {epoch+1}/{num_epochs}: Loss={avg_loss:.4f}")

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            ckpt_dir = vq_cfg["checkpoint_dir"]
            # torch.save does not create missing parent directories.
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            ckpt_path = os.path.join(ckpt_dir, f"vqvae_{epoch+1}.pth")
            torch.save(self.model.state_dict(), ckpt_path)
            self.logger.info(f"Saved checkpoint: {ckpt_path}")
get_dataloader(self)
Definition train.py:18
__init__(self, config)
Definition train.py:9
Definition train.py:1