|
| | cfg = config |
| str | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | logger = setup_logger("vqvae_trainer", config["log_dir"]) |
| | model = VQVAE().to(self.device) |
| | optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) |
| | criterion = torch.nn.MSELoss() |
Definition at line 8 of file train.py.
◆ __init__()
| train.VQVAETrainer.__init__(self, config) |
◆ get_dataloader()
| train.VQVAETrainer.get_dataloader(self) |
|
◆ train()
| train.VQVAETrainer.train(self) |
|
◆ cfg
| train.VQVAETrainer.cfg = config |
◆ criterion
| train.VQVAETrainer.criterion = torch.nn.MSELoss() |
◆ device
| str train.VQVAETrainer.device = "cuda" if torch.cuda.is_available() else "cpu" |
◆ logger
| train.VQVAETrainer.logger = setup_logger("vqvae_trainer", config["log_dir"]) |
◆ model
| train.VQVAETrainer.model = VQVAE().to(self.device) |
◆ optimizer
| train.VQVAETrainer.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) |
The documentation for this class was generated from the following file: train.py