PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
augment_combined.py
Go to the documentation of this file.
1import os
2import torch
3from physaug.vqvae.vqvae import VQVAE
4from physaug.augment.thermal import apply_thermal_augmentation
5from physaug.augment.grain import add_grain
6from physaug.utils.io import load_image_folder, save_image
7from physaug.utils.logger import setup_logger
8from physaug.utils.config import load_config
9
def main(
    input_dir,
    output_dir,
    checkpoint,
    config_path="configs/default.yaml",
    device=None,
):
    """Reconstruct a folder of images with a VQ-VAE, then add thermal and grain augmentation.

    Each image loaded from ``input_dir`` is passed through the ``VQVAE`` whose
    weights are read from ``checkpoint``.

    The reconstruction is moved to the CPU, then passed through
    ``apply_thermal_augmentation`` followed by ``add_grain``. The result is
    written to ``output_dir`` under the name returned by ``load_image_folder``.

    Args:
        input_dir: Folder of input images, loaded via ``load_image_folder``
            at ``cfg["vqvae"]["image_size"]``.
        output_dir: Destination folder; created if it does not exist.
        checkpoint: Path to a file that ``torch.load`` turns into a state
            dict for ``VQVAE``.
        config_path: Config file providing ``log_dir`` and
            ``vqvae.image_size``.
        device: Torch device to run the model on. When ``None`` (default),
            uses ``"cuda"`` if available, otherwise ``"cpu"``.
    """
    cfg = load_config(config_path)
    logger = setup_logger("augment_combined", cfg["log_dir"])
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only pass trusted checkpoints. On torch versions that support it,
    # consider torch.load(..., weights_only=True) for a plain state dict.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # eval mode: disables training-only behaviour in layers that have it

    os.makedirs(output_dir, exist_ok=True)
    images, names = load_image_folder(input_dir, cfg["vqvae"]["image_size"])

    # Count in the loop rather than calling len()/truthiness on `names`,
    # whose concrete type is not visible here.
    saved = 0
    for img, name in zip(images, names):
        # unsqueeze(0) adds a leading batch dimension; presumably img is a
        # single (C, H, W) tensor. TODO confirm against load_image_folder.
        batch = img.unsqueeze(0).to(device)
        with torch.no_grad():
            # The model returns a 3-tuple; only the reconstruction is used.
            recon, _, _ = model(batch)
        recon = recon.squeeze(0).cpu()
        aug = add_grain(apply_thermal_augmentation(recon))
        save_image(aug, os.path.join(output_dir, name))
        saved += 1

    if saved == 0:
        logger.warning("No images found in %s; nothing was written", input_dir)
    logger.info(f"Combined augmentations saved to {output_dir}")
28
if __name__ == "__main__":
    # Command-line entry point: parse arguments and forward them to main().
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Reconstruct every image in a folder with a VQ-VAE, then apply "
            "thermal and grain augmentation and save the results."
        )
    )
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Folder containing the input images.",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Folder to write augmented images to (created if missing).",
    )
    parser.add_argument(
        "--vqvae_ckpt",
        required=True,
        help="Path to the trained VQ-VAE state-dict checkpoint.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/default.yaml",
        help="Config file path (default: %(default)s).",
    )
    args = parser.parse_args()
    main(args.input_dir, args.output_dir, args.vqvae_ckpt, args.config)
Definition main.py:1