PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
Loading...
Searching...
No Matches
infer_video.py
Go to the documentation of this file.
1import os
2import cv2
3import torch
4from torchvision import transforms
5from torchvision.utils import save_image
6from physaug.vqvae.vqvae import VQVAE
7from physaug.utils.logger import setup_logger
8from physaug.utils.config import load_config
9
def infer_video(video_path, output_path, checkpoint, config_path="configs/default.yaml", *, fallback_fps=30.0):
    """Reconstruct every frame of a video with a trained VQ-VAE and write the result.

    Each frame read by OpenCV is converted BGR -> RGB, resized to
    ``cfg["vqvae"]["image_size"]``, passed through the VQ-VAE, and the
    reconstruction is converted back to BGR uint8 and written to
    ``output_path`` using the ``mp4v`` FourCC.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        output_path: Path of the reconstructed output video.
        checkpoint: Path to a file loaded with ``torch.load`` and passed to
            ``VQVAE.load_state_dict``.
        config_path: Config file read by ``load_config``. Must provide
            ``log_dir`` and ``vqvae.image_size``.
        fallback_fps: Frame rate used for the output when the input does not
            report a positive FPS (``CAP_PROP_FPS`` can be 0 or NaN).

    Returns:
        None. If the input video or the output writer cannot be opened, an
        error is logged and the function returns early without raising.
    """
    cfg = load_config(config_path)
    logger = setup_logger("infer_video", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VQVAE().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
    # checkpoints, or pass weights_only=True on torch versions that support it.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    image_size = cfg["vqvae"]["image_size"]
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Failed to open video: {video_path}")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    # `not (fps > 0)` also catches NaN, which some containers report.
    if not (fps > 0):
        fps = fallback_fps

    # cv2.VideoWriter takes (width, height). image_size is assumed to be
    # [H, W] (TODO confirm against the config). The writer silently drops
    # frames whose size differs from frame_size.
    frame_size = (image_size[1], image_size[0])
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    if not out.isOpened():
        # Without this check, every write() is a silent no-op.
        logger.error(f"Failed to open video writer: {output_path}")
        cap.release()
        return

    frames_written = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            tensor = transform(frame).unsqueeze(0).to(device)
            with torch.no_grad():
                recon, _, _ = model(tensor)
            # Round, then clamp to [0, 255] before the uint8 cast. Decoder output
            # is not guaranteed to be in [0, 1], and casting out-of-range floats
            # to uint8 wraps. This mirrors torchvision.utils.save_image.
            recon = (
                recon.squeeze(0)
                .mul(255)
                .add(0.5)
                .clamp(0, 255)
                .to(torch.uint8)
                .permute(1, 2, 0)
                .contiguous()
                .cpu()
                .numpy()
            )
            out.write(cv2.cvtColor(recon, cv2.COLOR_RGB2BGR))
            frames_written += 1
    finally:
        # Release both handles even if a frame fails, so the output is finalized.
        cap.release()
        out.release()

    if frames_written == 0:
        logger.warning(f"No frames were read from {video_path}")
    logger.info(f"Reconstructed video saved to {output_path}")
45
if __name__ == "__main__":
    # Command-line entry point: parse paths and hand them to infer_video.
    import argparse

    cli = argparse.ArgumentParser(description="Reconstruct video using VQ-VAE")
    # The three mandatory path arguments, registered in their help-listing order.
    required_flags = (
        ("--video_path", "Input video file"),
        ("--output_path", "Output video file"),
        ("--checkpoint", "Path to VQ-VAE checkpoint"),
    )
    for flag, description in required_flags:
        cli.add_argument(flag, required=True, help=description)
    cli.add_argument("--config", type=str, default="configs/default.yaml")

    opts = cli.parse_args()
    infer_video(
        opts.video_path,
        opts.output_path,
        opts.checkpoint,
        opts.config,
    )