PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
combined.py
Go to the documentation of this file.
1from .thermal import apply_thermal_augmentation
2from .grain import add_grain
3from torchvision.utils import save_image
4from ..utils.io import load_image_folder
5
def apply_combined_augmentation(input_dir, output_dir, checkpoint, config_path="configs/default.yaml"):
    """Reconstruct images with a trained VQVAE, then apply thermal and grain augmentations.

    Every image returned by ``load_image_folder(input_dir, ...)`` is passed through
    the VQVAE. The reconstruction is augmented with ``apply_thermal_augmentation``
    followed by ``add_grain``, and the result is saved into ``output_dir`` under
    the name reported by ``load_image_folder``.

    Args:
        input_dir: Directory of source images, read via ``load_image_folder``.
        output_dir: Directory the augmented images are written to. It is created
            if it does not already exist.
        checkpoint: Path to a VQVAE ``state_dict`` file, loaded with ``torch.load``.
        config_path: Config file passed to ``load_config``. It must provide
            ``log_dir`` and ``vqvae.image_size``.

    Returns:
        None. Results are written to disk and a summary line is logged.
    """
    import os

    import torch

    from physaug.utils.config import load_config
    from physaug.utils.logger import setup_logger

    from .vqvae import VQVAE

    cfg = load_config(config_path)
    logger = setup_logger("combined", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # trusted sources (consider weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    # save_image opens the target file directly and does not create parent
    # directories, so make sure the output directory exists up front.
    os.makedirs(output_dir, exist_ok=True)

    images, names = load_image_folder(input_dir, cfg["vqvae"]["image_size"])
    for img, name in zip(images, names):
        # Add a leading batch dimension for the model; presumably img is a single
        # (C, H, W) tensor. TODO confirm against load_image_folder.
        batch = img.unsqueeze(0).to(device)
        with torch.no_grad():
            # The model returns a 3-tuple; only the reconstruction is used here.
            recon, _, _ = model(batch)
        recon = recon.squeeze(0).cpu()
        aug = apply_thermal_augmentation(recon)
        aug = add_grain(aug)
        # Assumes `name` is a bare filename (an absolute name would override output_dir).
        save_image(aug, os.path.join(output_dir, name))
    logger.info("Combined augmentations saved to %s", output_dir)
apply_combined_augmentation(input_dir, output_dir, checkpoint, config_path="configs/default.yaml")
Definition combined.py:6