def main(input_dir, output_dir, checkpoint, config_path="configs/default.yaml"):
    """Reconstruct each image in *input_dir* with a VQ-VAE, augment it, and save it.

    Every image returned by ``load_image_folder`` is passed through the VQ-VAE
    whose weights are loaded from *checkpoint*. The reconstruction (first
    element of the model's 3-tuple output) is fed to
    ``apply_thermal_augmentation``, and the result is written to *output_dir*
    under the image's original name.

    Args:
        input_dir: Folder of source images, read via ``load_image_folder``.
        output_dir: Destination folder; created if it does not already exist.
        checkpoint: Path to a file loadable with ``torch.load`` whose content
            is passed to ``VQVAE.load_state_dict``.
        config_path: Config file passed to ``load_config``; must provide the
            ``log_dir`` and ``vqvae.image_size`` entries.
    """
    cfg = load_config(config_path)
    logger = setup_logger("augment_combined", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    # NOTE(review): torch.load unpickles arbitrary Python objects. Since only a
    # state_dict is expected here, consider passing weights_only=True (torch >= 1.13)
    # if checkpoints can come from untrusted sources.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    # This script only runs inference. eval() puts mode-dependent layers
    # (dropout, batch-norm, or any training-only codebook updates, if VQVAE has
    # them) into inference behaviour.
    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    # presumably the loader resizes images to cfg["vqvae"]["image_size"] — TODO confirm
    images, names = load_image_folder(input_dir, cfg["vqvae"]["image_size"])

    # No gradients are needed. Skipping autograd avoids building a graph per
    # image and keeps tensors that require grad out of augmentation/saving.
    with torch.no_grad():
        for img, name in zip(images, names):
            batch = img.unsqueeze(0).to(device)  # add a batch dimension of size 1
            recon, _, _ = model(batch)
            recon = recon.squeeze(0).cpu()  # drop the batch dimension again
            aug = apply_thermal_augmentation(recon)
            save_image(aug, os.path.join(output_dir, name))

    logger.info(f"Combined augmentations saved to {output_dir}")
37 main(args.input_dir, args.output_dir, args.vqvae_ckpt, args.config)