PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
augment_combined.py
Go to the documentation of this file.
1import os
2import argparse
3import torch
4from torchvision.utils import save_image
5from torchvision import transforms
6from PIL import Image
7from tqdm import tqdm
8
9from physaug.vqvae.vqvae import VQVAE
10from physaug.augment.thermal import apply_thermal_augmentation
11from physaug.augment.grain import apply_grain_noise
12from physaug.utils.io import load_images_from_folder
13from physaug.utils.logger import get_logger
14
15
def load_vqvae_model(checkpoint_path, device):
    """Build a VQVAE, load its weights from a checkpoint, and prepare it for inference.

    The network is constructed with fixed hyperparameters (3 image channels,
    128 hidden channels, 64-dimensional embeddings, 512 codebook entries), so
    the checkpoint must have been produced by a model with the same layout.

    Args:
        checkpoint_path: Path to a file that ``torch.load`` can read and whose
            content is accepted by ``load_state_dict``.
        device: Device used both as ``map_location`` when reading the
            checkpoint and as the target of ``model.to``.

    Returns:
        The VQVAE instance, moved to ``device`` and switched to eval mode.
    """
    net = VQVAE(
        img_channels=3,
        hidden_channels=128,
        embedding_dim=64,
        num_embeddings=512,
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (or pass weights_only=True on PyTorch versions that
    # support it).
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)

    net.to(device)
    net.eval()
    return net
22
23
def augment_images(model, input_dir, output_dir, device, apply_grain):
    """Reconstruct every image with the VQ-VAE, then apply thermal (and optional grain) augmentation.

    For each path returned by ``load_images_from_folder(input_dir)`` the image
    is opened as RGB, passed through ``model`` without gradient tracking,
    clamped to ``[0, 1]``, sent through ``apply_thermal_augmentation`` and,
    when ``apply_grain`` is truthy, through ``apply_grain_noise``. The result
    is written to ``output_dir`` under the input file's base name, so inputs
    that share a base name overwrite each other.

    Args:
        model: Callable returning a 2-tuple whose first element is the
            reconstruction tensor (the second element is ignored).
        input_dir: Folder handed to ``load_images_from_folder``.
        output_dir: Destination folder; created if it does not exist.
        device: Device the input tensor is moved to before the forward pass.
        apply_grain: Whether to run ``apply_grain_noise`` after the thermal step.
    """
    os.makedirs(output_dir, exist_ok=True)
    transform = transforms.ToTensor()
    image_paths = load_images_from_folder(input_dir)

    for img_path in tqdm(image_paths, desc="Processing VQ+Thermal(+Grain)"):
        img_name = os.path.basename(img_path)

        # Open inside a context manager so the underlying file handle is
        # released deterministically on every iteration; convert() returns an
        # independent in-memory image that stays valid after the file closes.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        # Add a leading batch dimension for the forward pass.
        tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            recon, _ = model(tensor)

        # Drop the batch dimension, move to CPU and keep values in [0, 1].
        recon = recon.squeeze(0).cpu().clamp(0, 1)

        # Thermal augmentation is always applied.
        thermal = apply_thermal_augmentation(recon)

        # Grain noise is optional and applied on top of the thermal result.
        if apply_grain:
            thermal = apply_grain_noise(thermal)

        save_image(thermal, os.path.join(output_dir, img_name))
48
49
if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the VQ-VAE checkpoint
    # and run the combined augmentation over the input folder.
    parser = argparse.ArgumentParser(description="VQ-VAE + Thermal + Grain Augmentation")
    parser.add_argument('--input_dir', type=str, required=True, help='Input folder')
    parser.add_argument('--output_dir', type=str, required=True, help='Output folder')
    parser.add_argument('--vqvae_ckpt', type=str, required=True, help='Path to trained VQ-VAE checkpoint')
    parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
    parser.add_argument('--apply_grain', action='store_true', help='Apply grain noise after thermal')
    args = parser.parse_args()

    logger = get_logger("augment_combined")
    logger.info(f"Input: {args.input_dir}")
    logger.info(f"Output: {args.output_dir}")
    logger.info(f"VQ-VAE Checkpoint: {args.vqvae_ckpt}")

    # Honour the requested device. Fall back to CPU only when a CUDA device
    # was requested but CUDA is unavailable, and say so instead of silently
    # switching; other device strings are passed through untouched.
    # NOTE(review): assumes get_logger returns a logging.Logger-like object
    # exposing .warning() — confirm.
    device = torch.device(args.device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU.")
        device = torch.device('cpu')

    model = load_vqvae_model(args.vqvae_ckpt, device)

    augment_images(model, args.input_dir, args.output_dir, device, args.apply_grain)

    logger.info("✅ Combined augmentation completed.")
augment_images(model, input_dir, output_dir, device, apply_grain)
load_vqvae_model(checkpoint_path, device)