def infer_video(video_path, output_path, checkpoint, config_path="configs/default.yaml"):
    """Reconstruct a video frame by frame with a VQ-VAE and save the result.

    Each frame read from ``video_path`` is converted to RGB, resized to the
    configured ``vqvae.image_size``, passed through the model, and the
    reconstruction is written to ``output_path`` as an ``mp4v``-encoded video
    with the source frame rate.

    Args:
        video_path: Path of the input video to read.
        output_path: Path of the reconstructed video to write.
        checkpoint: Path of the saved model ``state_dict`` to load.
        config_path: Path of the configuration file; it must provide
            ``log_dir`` and ``vqvae.image_size``.

    Returns:
        None. If the input video cannot be opened, an error is logged and the
        function returns without writing anything.
    """
    cfg = load_config(config_path)
    logger = setup_logger("infer_video", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    # Inference only: switch layers such as dropout/batch-norm to eval mode.
    model.eval()

    image_size = cfg["vqvae"]["image_size"]
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Failed to open video: {video_path}")
        cap.release()
        return

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV expects (width, height); assumes image_size is
        # (height, width) -- TODO confirm against the config schema.
        frame_size = (image_size[1], image_size[0])
        out = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size
        )

        # No gradients are needed for reconstruction; this avoids building
        # an autograd graph for every frame.
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    # End of stream (or a read failure): stop processing.
                    break

                # OpenCV decodes to BGR; the transform pipeline expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                tensor = transform(frame).unsqueeze(0).to(device)

                recon, _, _ = model(tensor)
                # Clamp before the uint8 cast so out-of-range values
                # saturate instead of wrapping around.
                recon = (
                    recon.squeeze(0)
                    .cpu()
                    .clamp(0, 1)
                    .mul(255)
                    .byte()
                    .permute(1, 2, 0)
                    .numpy()
                )
                recon = cv2.cvtColor(recon, cv2.COLOR_RGB2BGR)
                out.write(recon)
    finally:
        # Always release the handles so the output container is finalized
        # even when a frame fails mid-stream.
        cap.release()
        if out is not None:
            out.release()

    logger.info(f"Reconstructed video saved to {output_path}")