def main(config_path="configs/default.yaml"):
    """Train a VQ-VAE using settings loaded from a config file.

    Reads ``cfg["log_dir"]`` for the logger and, from ``cfg["vqvae"]``, the
    keys ``learning_rate``, ``num_epochs``, ``save_interval`` and
    ``checkpoint_dir``. Each epoch minimizes MSE reconstruction loss plus the
    model's VQ loss with Adam and logs the mean per-batch loss. Every
    ``save_interval`` epochs it writes the model's ``state_dict`` to
    ``<checkpoint_dir>/vqvae_<epoch>.pth``.

    Args:
        config_path: Path of the config file handed to ``load_config``.
    """
    import os  # local import: the module's import block is not part of this block

    cfg = load_config(config_path)
    vq_cfg = cfg["vqvae"]
    logger = setup_logger("train_vqvae", cfg["log_dir"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the first (training) loader is used; the second one is discarded.
    train_loader, _ = get_dataloaders(cfg)
    model = VQVAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=vq_cfg["learning_rate"])
    criterion = torch.nn.MSELoss()  # stateless loss: build once, not per batch

    num_epochs = vq_cfg["num_epochs"]
    save_interval = vq_cfg["save_interval"]
    checkpoint_dir = vq_cfg["checkpoint_dir"]
    # Create the checkpoint directory up front so torch.save cannot fail on a
    # missing directory after hours of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        # Reset every epoch so the logged value is this epoch's mean loss
        # (previously it was never initialized at all).
        running_loss = 0.0
        for batch in train_loader:
            # Assumes the image tensor is the first element of each batch
            # (e.g. an (images, labels) pair) - TODO confirm against get_dataloaders.
            imgs = batch[0].to(device)
            recon, vq_loss, _ = model(imgs)
            loss = criterion(recon, imgs) + vq_loss

            # Backpropagate and update weights; without these steps the
            # optimizer is never used and the model never learns.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # max(..., 1) guards against ZeroDivisionError on an empty loader.
        avg_loss = running_loss / max(len(train_loader), 1)
        logger.info(f"Epoch {epoch+1}/{num_epochs}: Loss={avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            ckpt_path = f"{checkpoint_dir}/vqvae_{epoch+1}.pth"
            torch.save(model.state_dict(), ckpt_path)
            logger.info(f"Saved checkpoint: {ckpt_path}")