PhysAugNet 1.0.1
VQ-VAE powered augmentation for metal defect segmentation
infer.py
import os
import torch
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
from .vqvae import VQVAE
from physaug.utils.logger import setup_logger
from physaug.utils.config import load_config


def reconstruct_folder(input_dir, output_dir, model_path, image_size=(128, 128), config_path="configs/default.yaml"):
    """Run every image in input_dir through a trained VQ-VAE and save the reconstructions to output_dir."""
    cfg = load_config(config_path)
    logger = setup_logger("infer", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the trained VQ-VAE weights and switch to evaluation mode.
    model = VQVAE().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transform = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    files = [f for f in os.listdir(input_dir) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
    os.makedirs(output_dir, exist_ok=True)

    for fname in files:
        img = Image.open(os.path.join(input_dir, fname)).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            recon, _, _ = model(x)
        # Clamp to the valid [0, 1] range before writing the reconstruction to disk.
        save_image(recon.clamp(0, 1), os.path.join(output_dir, fname))

    logger.info(f"Reconstructed images saved to {output_dir}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--image_size", type=int, nargs=2, default=[128, 128])
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    args = parser.parse_args()
    reconstruct_folder(args.input_dir, args.output_dir, args.model_path, tuple(args.image_size), args.config)
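For reference, a minimal usage sketch of reconstruct_folder called from Python. The import path physaug.infer and all file paths below are placeholders chosen for illustration, not taken from this listing:

# Hypothetical invocation; adjust the module path and file paths to your checkout.
from physaug.infer import reconstruct_folder  # assumed package location

reconstruct_folder(
    input_dir="data/defects/raw",        # folder of PNG/JPEG defect images (placeholder)
    output_dir="data/defects/recon",     # reconstructions are written here (placeholder)
    model_path="checkpoints/vqvae.pt",   # trained VQ-VAE checkpoint (placeholder)
    image_size=(128, 128),
    config_path="configs/default.yaml",
)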