19 transform = transforms.Compose([transforms.Resize(tuple(self.
cfg[
"vqvae"][
"image_size"])), transforms.ToTensor()])
20 dataset = datasets.ImageFolder(root=self.
cfg[
"dataset_dir"], transform=transform)
21 return DataLoader(dataset, batch_size=self.
cfg[
"vqvae"][
"batch_size"], shuffle=
True, num_workers=self.
cfg[
"vqvae"][
"num_workers"])
# NOTE(review): this fragment comes from a numbered listing with gaps (listing
# lines 27, 31 and 33-35 are absent). `dataloader`, `running_loss` and
# `recon_loss` are bound in code that is not visible here - confirm against the
# full file before relying on these comments.
# Outer loop: one iteration per epoch; the count is cfg["vqvae"]["num_epochs"].
26 for epoch
in range(self.
cfg[
"vqvae"][
"num_epochs"]):
# Inner loop: every item yielded by `dataloader` is unpacked as a pair, and the
# second element is assigned to `_` (not used in the visible code).
28 for imgs, _
in dataloader:
# Move the batch to self.device.
29 imgs = imgs.to(self.
device)
# The model call is unpacked into three values; the third is assigned to `_`.
30 recon, vq_loss, _ = self.
model(imgs)
# Total loss is recon_loss + vq_loss. NOTE(review): recon_loss is presumably
# computed from `recon` on the missing listing line 31 - confirm.
32 loss = recon_loss + vq_loss
# NOTE(review): listing lines 33-35 are missing here; presumably the backward
# pass / optimizer update - confirm.
# Add this batch's loss.item() to the running total.
36 running_loss += loss.item()
# Running total divided by len(dataloader).
37 avg_loss = running_loss / len(dataloader)
# Log the 1-based epoch number, the configured epoch count and the average loss.
38 self.
logger.info(f
"Epoch {epoch+1}/{self.cfg['vqvae']['num_epochs']}: Loss={avg_loss:.4f}")
# True whenever the 1-based epoch number is a multiple of
# cfg["vqvae"]["save_interval"].
39 if (epoch + 1) % self.
cfg[
"vqvae"][
"save_interval"] == 0:
# Checkpoint path: cfg["vqvae"]["checkpoint_dir"] joined with "/" and a file
# name containing the 1-based epoch number.
40 ckpt_path = f
"{self.cfg['vqvae']['checkpoint_dir']}/vqvae_{epoch+1}.pth"
# Only the model's state_dict() is written to ckpt_path.
41 torch.save(self.
model.state_dict(), ckpt_path)
# Log the path that was just written.
42 self.
logger.info(f
"Saved checkpoint: {ckpt_path}")