PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
gen_vqvae.py
Go to the documentation of this file.
1import os
2import torch
3from physaug.vqvae.infer import reconstruct_folder
4from physaug.utils.logger import setup_logger
5
def main(input_dir, output_dir, checkpoint, config_path="configs/default.yaml"):
    """Run VQ-VAE reconstruction over a folder of images and log the result.

    Args:
        input_dir: Folder of input images, forwarded to ``reconstruct_folder``.
        output_dir: Folder the reconstructions are written to, forwarded to
            ``reconstruct_folder`` and reported in the log message.
        checkpoint: Model checkpoint, forwarded to ``reconstruct_folder``.
        config_path: Config file to load. It must provide ``log_dir`` (used
            for the logger) and ``vqvae.image_size`` (forwarded to
            ``reconstruct_folder``).

    Raises:
        KeyError: If the loaded config lacks ``log_dir`` or
            ``vqvae.image_size``.
    """
    # ``load_config`` was called here but never imported, so the script died
    # with NameError before doing any work. Imported locally to keep this
    # fix self-contained.
    # NOTE(review): module path assumed from the package layout alongside
    # physaug.utils.logger; confirm it matches the repository.
    from physaug.utils.config import load_config

    cfg = load_config(config_path)
    logger = setup_logger("gen_vqvae", cfg["log_dir"])
    reconstruct_folder(input_dir, output_dir, checkpoint, cfg["vqvae"]["image_size"])
    logger.info(f"Reconstructed images saved to {output_dir}")
11
if __name__ == "__main__":
    import argparse

    # CLI entry point: three mandatory paths plus an optional config file.
    cli = argparse.ArgumentParser()
    for flag in ("--input_dir", "--output_dir", "--checkpoint"):
        cli.add_argument(flag, required=True)
    cli.add_argument("--config", type=str, default="configs/default.yaml")
    opts = cli.parse_args()

    main(
        opts.input_dir,
        opts.output_dir,
        opts.checkpoint,
        opts.config,
    )
Definition main.py:1