PhysAugNet 1.0.1
VQ-VAE powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
train_vqvae.py
Go to the documentation of this file.
1import argparse
2import yaml
3import torch
4import os
5from physaug.vqvae.train import VQVAETrainer
6from physaug.vqvae.vqvae import VQVAE
7from physaug.utils.logger import setup_logging
8from physaug.utils.config import load_config
9from physaug.utils.io import get_dataloaders
10
11
def parse_args(argv=None):
    """Parse the command-line options of the VQ-VAE training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse_args()`` call behaves as a normal CLI entry.

    Returns:
        argparse.Namespace: carries ``config`` (``None`` when the flag is
        omitted), ``resume`` (``None`` when omitted) and ``device``.
    """
    parser = argparse.ArgumentParser(description="Train VQ-VAE on metal defect images")

    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
    # The default device is decided once, at parser-construction time:
    # "cuda" when torch reports CUDA as available, otherwise "cpu".
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    return parser.parse_args(argv)
20
21
def main():
    """Run VQ-VAE training end to end from the command line.

    Reads the CLI options, loads the config, creates the log and
    checkpoint directories, sets up logging, builds the data loaders and
    the model, and hands everything to ``VQVAETrainer`` before calling
    its ``train()`` method.
    """
    cli = parse_args()
    config = load_config(cli.config)

    # Both output directories must exist before logging is configured.
    for dir_key in ("log_dir", "ckpt_dir"):
        os.makedirs(config[dir_key], exist_ok=True)

    log, tb_writer = setup_logging(config)
    log.info("Starting VQ-VAE training")

    loaders = get_dataloaders(config)
    train_data, val_data = loaders

    # Three input channels in "rgb" mode, a single channel for any other mode.
    if config["mode"] == "rgb":
        channels = 3
    else:
        channels = 1
    net = VQVAE(in_channels=channels)
    net.to(cli.device)

    trainer_kwargs = dict(
        model=net,
        train_loader=train_data,
        val_loader=val_data,
        cfg=config,
        writer=tb_writer,
        logger=log,
        resume_path=cli.resume,
        device=cli.device,
    )
    VQVAETrainer(**trainer_kwargs).train()
48
49
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Definition main.py:1