PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
train_vqvae.py
Go to the documentation of this file.
1import torch
2from physaug.vqvae.vqvae import VQVAE
3from physaug.utils.config import load_config
4from physaug.utils.io import get_dataloaders
5from physaug.utils.logger import setup_logger
6
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over ``loader`` and return the mean per-batch loss.

    For each batch the model is called as ``recon, vq_loss, _ = model(imgs)``
    and optimized on ``criterion(recon, imgs) + vq_loss``.

    Args:
        model: The VQ-VAE being trained; it is put into train mode.
        loader: Iterable of batches; only the first element of each is used.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Reconstruction loss, called as ``criterion(recon, imgs)``.
        device: Device the images are moved to before the forward pass.

    Returns:
        The average of the per-batch total losses for this epoch.

    Raises:
        ValueError: If ``loader`` yields no batches (an average is undefined).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in loader:
        # NOTE(review): assumes each batch is an (images, ...) sequence whose
        # first element is the image tensor; confirm against get_dataloaders.
        imgs = batch[0].to(device)
        recon, vq_loss, _ = model(imgs)
        loss = criterion(recon, imgs) + vq_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Count batches ourselves rather than relying on len(loader), which
        # is unavailable for loaders over iterable datasets.
        num_batches += 1
    if num_batches == 0:
        raise ValueError(
            "train_loader yielded no batches; cannot compute an epoch loss"
        )
    return running_loss / num_batches


def main(config_path="configs/default.yaml"):
    """Train a VQ-VAE using the settings in a config file.

    Reads ``log_dir`` and the ``vqvae`` section (``learning_rate``,
    ``num_epochs``, ``save_interval``, ``checkpoint_dir``) from the config
    loaded by ``load_config``. Trains a default-constructed ``VQVAE`` with Adam
    on reconstruction MSE plus the model's VQ loss, logs the mean per-batch
    loss after every epoch, and saves the model ``state_dict`` every
    ``save_interval`` epochs as ``vqvae_<epoch>.pth`` in ``checkpoint_dir``.

    Args:
        config_path: Path of the config file handed to ``load_config``.

    Raises:
        ValueError: If the training dataloader yields no batches.
    """
    import os

    cfg = load_config(config_path)
    vq_cfg = cfg["vqvae"]
    logger = setup_logger("train_vqvae", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the training loader is used; the second returned loader is ignored.
    train_loader, _ = get_dataloaders(cfg)
    model = VQVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=vq_cfg["learning_rate"])
    # MSELoss holds no state, so one instance serves every batch.
    criterion = torch.nn.MSELoss()

    num_epochs = vq_cfg["num_epochs"]
    save_interval = vq_cfg["save_interval"]
    checkpoint_dir = vq_cfg["checkpoint_dir"]
    # torch.save does not create parent directories; make sure it exists.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        logger.info(f"Epoch {epoch+1}/{num_epochs}: Loss={avg_loss:.4f}")
        # A non-positive interval disables checkpointing instead of raising
        # ZeroDivisionError in the modulo.
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"vqvae_{epoch+1}.pth")
            torch.save(model.state_dict(), ckpt_path)
            logger.info(f"Saved checkpoint: {ckpt_path}")
31
if __name__ == "__main__":
    # Command-line entry point: accept an optional --config path and train.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default="configs/default.yaml",
    )
    main(cli.parse_args().config)
Definition main.py:1