import os

from physaug.utils.config import load_config
from physaug.utils.logger import setup_logger

from .vqvae import VQVAE
# Combined pipeline: reconstruct every image in ``input_dir`` with a pretrained
# VQVAE, apply ``apply_thermal_augmentation`` to each reconstruction and save the
# result to ``output_dir`` under the image's original name.
# NOTE(review): config_path, checkpoint, input_dir and output_dir are not
# defined in this chunk; presumably they are supplied by an enclosing function
# or CLI parser — confirm.
cfg = load_config(config_path)
logger = setup_logger("combined", cfg["log_dir"])
device = "cuda" if torch.cuda.is_available() else "cpu"

model = VQVAE().to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True where the installed torch supports it).
model.load_state_dict(torch.load(checkpoint, map_location=device))
# Inference only: put the module in eval mode (affects layers such as dropout
# or batch-norm, if VQVAE contains any).
model.eval()

images, names = load_image_folder(input_dir, cfg["vqvae"]["image_size"])
# save_image does not create missing directories itself.
os.makedirs(output_dir, exist_ok=True)

# No gradients are needed here; skipping autograd bookkeeping saves memory
# and time for every image.
with torch.no_grad():
    for img, name in zip(images, names):
        # Add a leading dim of size 1 (presumably the batch axis) for the model;
        # the model returns a 3-tuple of which only the first element is used.
        recon, _, _ = model(img.unsqueeze(0).to(device))
        recon = recon.squeeze(0).cpu()
        aug = apply_thermal_augmentation(recon)
        save_image(aug, f"{output_dir}/{name}")

logger.info(f"Combined augmentations saved to {output_dir}")