# Reconstruct every image in `input_dir` through a trained VQ-VAE, apply the
# thermal augmentation followed by grain noise, and save the result under the
# same file name in `output_dir`.
#
# Names used here but defined elsewhere in the file: VQVAE, checkpoint_path,
# device, input_dir, output_dir, load_images_from_folder,
# apply_thermal_augmentation, apply_grain_noise, save_image, transforms,
# Image, tqdm.

model = VQVAE(img_channels=3, hidden_channels=128, embedding_dim=64, num_embeddings=512)
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted
# checkpoints. On torch >= 1.13, consider weights_only=True for a plain state_dict.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Input tensors are moved to `device` below, so the model must live there too.
model.to(device)
# Inference mode for any BatchNorm/Dropout-style layers the model may contain.
model.eval()

os.makedirs(output_dir, exist_ok=True)
transform = transforms.ToTensor()
image_paths = load_images_from_folder(input_dir)

for img_path in tqdm(image_paths, desc="Processing VQ+Thermal(+Grain)"):
    img_name = os.path.basename(img_path)
    # Close the source file promptly instead of waiting for garbage collection.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    # ToTensor yields a (C, H, W) float tensor in [0, 1]; unsqueeze adds the
    # batch axis the model is called with.
    tensor = transform(image).unsqueeze(0).to(device)

    # Pure inference: skip autograd bookkeeping so no graph accumulates and
    # `recon` comes out detached for the post-processing helpers.
    with torch.no_grad():
        # The model returns a pair; only the reconstruction is used here
        # (the second element is presumably the VQ loss - confirm).
        recon, _ = model(tensor)

    # Drop the batch axis, bring to CPU, and clamp to the [0, 1] range.
    recon = recon.squeeze(0).cpu().clamp(0, 1)

    # assumes both helpers take and return a (C, H, W) float tensor in
    # [0, 1] - TODO confirm against their definitions.
    thermal = apply_thermal_augmentation(recon)
    thermal = apply_grain_noise(thermal)

    save_image(thermal, os.path.join(output_dir, img_name))