PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
gen_vqvae.py
Go to the documentation of this file.
1import argparse
2import os
3from physaug.vqvae.infer import reconstruct_folder
4from physaug.utils.io import load_image_folder, save_image
5from tqdm import tqdm
6import torch
7
8
def parse_args(argv=None):
    """Parse command-line options for VQ-VAE reconstruction.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` reads ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before; passing a list
            allows programmatic use and testing without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``checkpoint``,
        ``mode`` (``"rgb"`` or ``"gray"``, default ``"rgb"``) and ``device``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--mode`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser(description="Generate VQ-VAE reconstructions")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Where to save reconstructed images")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to VQ-VAE model checkpoint")
    parser.add_argument("--mode", type=str, choices=["rgb", "gray"], default="rgb", help="Image mode")
    # Default device is chosen when the parser is built: CUDA if torch reports it available, else CPU.
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on")
    return parser.parse_args(argv)
17
18
def main():
    """Reconstruct every image in ``--input_dir`` with a VQ-VAE and save the results.

    Steps:
      1. read CLI options via ``parse_args``;
      2. create ``--output_dir`` if it does not exist yet;
      3. load the input images and their file names in ``--mode``;
      4. run ``reconstruct_folder`` with ``--checkpoint`` on ``--device``;
      5. write each reconstruction into ``--output_dir`` under its source file name.
    """
    opts = parse_args()
    out_dir = opts.output_dir
    os.makedirs(out_dir, exist_ok=True)

    images, filenames = load_image_folder(opts.input_dir, mode=opts.mode)
    reconstructions = reconstruct_folder(images, opts.checkpoint, device=opts.device)

    # zip pairs outputs with names positionally and stops at the shorter sequence;
    # the progress-bar length is taken from the number of loaded images.
    # NOTE(review): assumes reconstruct_folder yields one output per input, in input order — confirm.
    pairs = zip(reconstructions, filenames)
    for recon, fname in tqdm(pairs, total=len(images)):
        save_image(recon, os.path.join(out_dir, fname), mode=opts.mode)


if __name__ == "__main__":
    main()
parse_args()
Definition gen_vqvae.py:9
Definition main.py:1