def reconstruct_folder(input_dir, output_dir, model_path, image_size=(128, 128),
                       config_path="configs/default.yaml"):
    """Reconstruct every image in *input_dir* with a trained VQVAE and save the results.

    Each PNG/JPEG file directly inside *input_dir* (not searched recursively)
    is opened, converted to RGB, resized to *image_size* and passed through
    the model. The reconstruction, clamped to [0, 1], is written to
    *output_dir* under the same filename.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory to write reconstructions into; created if missing.
        model_path: Path to a state dict that is loaded into ``VQVAE()``.
        image_size: Size passed to ``transforms.Resize``.
        config_path: Config file passed to ``load_config``; its ``"log_dir"``
            entry is used to set up the logger.
    """
    cfg = load_config(config_path)
    logger = setup_logger("infer", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files (consider weights_only=True on torch versions that support it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Switch to inference mode; affects layers such as dropout/batch-norm if present.
    model.eval()

    transform = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    # Match real extensions only (leading dot), and sort because os.listdir
    # returns entries in arbitrary order.
    files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    os.makedirs(output_dir, exist_ok=True)

    # Gradients are not needed for reconstruction; skipping them saves memory.
    with torch.no_grad():
        for fname in files:
            # Context manager closes the underlying file handle promptly.
            with Image.open(os.path.join(input_dir, fname)) as img:
                x = transform(img.convert("RGB")).unsqueeze(0).to(device)
            # The model is expected to return a 3-tuple whose first item is the reconstruction.
            recon, _, _ = model(x)
            save_image(recon.clamp(0, 1), os.path.join(output_dir, fname))

    logger.info(f"Reconstructed images saved to {output_dir}")
if __name__ == "__main__":
    # Command-line entry point. The unguarded call used names that are not
    # defined at module level (NameError on import), so the arguments now
    # come from the CLI.
    import argparse

    parser = argparse.ArgumentParser(
        description="Reconstruct a folder of images with a trained VQVAE."
    )
    parser.add_argument("input_dir", help="directory of PNG/JPEG images")
    parser.add_argument("output_dir", help="directory to write reconstructions into")
    parser.add_argument("model_path", help="path to the VQVAE state dict")
    parser.add_argument("--image-size", type=int, nargs=2, default=(128, 128),
                        metavar=("H", "W"), help="resize target (default: 128 128)")
    parser.add_argument("--config", default="configs/default.yaml",
                        help="config file path (default: configs/default.yaml)")
    args = parser.parse_args()
    reconstruct_folder(args.input_dir, args.output_dir, args.model_path,
                       image_size=tuple(args.image_size), config_path=args.config)