PhysAugNet 1.0.1
VQ-VAE-powered augmentation for metal defect segmentation
train.VQVAETrainer Class Reference

Public Member Functions

 __init__ (self, config)
 get_dataloader (self)
 train (self)

Public Attributes

 cfg = config
str device = "cuda" if torch.cuda.is_available() else "cpu"
 logger = setup_logger("vqvae_trainer", config["log_dir"])
 model = VQVAE().to(self.device)
 optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
 criterion = torch.nn.MSELoss()

Detailed Description

Definition at line 8 of file train.py.

Constructor & Destructor Documentation

◆ __init__()

train.VQVAETrainer.__init__ (self, config)

Definition at line 9 of file train.py.

Member Function Documentation

◆ get_dataloader()

train.VQVAETrainer.get_dataloader (self)

Definition at line 18 of file train.py.
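This page does not reproduce the body of get_dataloader, but its name suggests that it wraps the training data in a PyTorch DataLoader. The minimal sketch below shows only the shuffled mini-batching that such a loader performs, using plain lists. iterate_minibatches, batch_size and seed are hypothetical names and are not part of PhysAugNet.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_minibatches(
    samples: Sequence[T], batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[List[T]]:
    """Yield lists of at most `batch_size` samples, optionally in shuffled order.

    Hypothetical illustration of DataLoader-style batching; not PhysAugNet code.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    indices = list(range(len(samples)))
    if shuffle:
        # seeded RNG so the permutation is reproducible
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # the last batch may be shorter than batch_size
        yield [samples[i] for i in indices[start:start + batch_size]]
```

For example, ten samples with batch_size=4 produce batches of sizes 4, 4 and 2. The final short batch is kept, which matches DataLoader's default drop_last=False.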

◆ train()

train.VQVAETrainer.train (self)

Definition at line 23 of file train.py.
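This page does not show the body of train(). In PyTorch, a loop of this kind typically does the following for each batch: run the forward pass, compute the loss with criterion, backpropagate, and call optimizer.step(). The dependency-free toy below follows the same sequence on a one-parameter linear model. It derives the gradient by hand and uses plain gradient descent instead of Adam. It is an assumption about the general shape of the loop, not the PhysAugNet implementation, and train_scalar_model is a hypothetical name.

```python
from typing import List, Tuple


def train_scalar_model(
    data: List[Tuple[float, float]], lr: float = 0.1, epochs: int = 100
) -> Tuple[float, float]:
    """Fit y ~= w * x by gradient descent on the mean squared error.

    Toy, dependency-free stand-in for a trainer loop; all names are hypothetical.
    Returns the learned weight and the loss from the last forward pass.
    """
    w = 0.0  # the single model parameter
    n = len(data)
    loss = float("nan")
    for _ in range(epochs):
        # forward pass (the whole tiny dataset acts as one batch)
        preds = [w * x for x, _ in data]
        # loss: mean squared error, as torch.nn.MSELoss(reduction="mean") computes
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, data)) / n
        # backward pass: hand-derived d(loss)/dw = (2/n) * sum((w*x - y) * x)
        grad = 2.0 / n * sum((p - y) * x for p, (x, y) in zip(preds, data))
        # optimizer step: plain gradient descent here, where the trainer uses Adam
        w -= lr * grad
    return w, loss
```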

Member Data Documentation

◆ cfg

train.VQVAETrainer.cfg = config

Definition at line 10 of file train.py.

◆ criterion

train.VQVAETrainer.criterion = torch.nn.MSELoss()

Definition at line 16 of file train.py.
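With its default reduction="mean", torch.nn.MSELoss returns the mean of the squared element-wise differences between its two arguments, averaged over all elements. The plain-Python function below computes the same quantity on flat lists, so the formula can be checked by hand. mse_loss is a hypothetical name used only for illustration.

```python
from typing import Sequence


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean of squared element-wise differences (the reduction='mean' behavior)."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    if not pred:
        raise ValueError("inputs must be non-empty")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
```

For example, mse_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]) is (0 + 0 + 4) / 3, about 1.333.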

◆ device

str train.VQVAETrainer.device = "cuda" if torch.cuda.is_available() else "cpu"

Definition at line 11 of file train.py.

◆ logger

train.VQVAETrainer.logger = setup_logger("vqvae_trainer", config["log_dir"])

Definition at line 12 of file train.py.
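setup_logger is a project helper, and this page does not document its implementation. If it creates a named logger that writes into config["log_dir"], a standard-library version might look like the sketch below. setup_file_logger and the <name>.log file layout are hypothetical and may not match PhysAugNet's actual behavior.

```python
import logging
import os


def setup_file_logger(name: str, log_dir: str) -> logging.Logger:
    """Hypothetical helper: return logger `name` writing to <log_dir>/<name>.log."""
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    log_path = os.path.abspath(os.path.join(log_dir, f"{name}.log"))
    # loggers are process-wide singletons, so skip adding a duplicate handler
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_path)
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
        )
        logger.addHandler(handler)
    return logger
```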

◆ model

train.VQVAETrainer.model = VQVAE().to(self.device)

Definition at line 13 of file train.py.
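VQVAE() is constructed with no arguments, so its architecture and codebook size presumably come from defaults defined elsewhere. The step that gives a VQ-VAE its name is vector quantization: each encoder output vector is replaced by the nearest entry in a learned codebook. The sketch below shows only that nearest-neighbor lookup, using squared Euclidean distance on plain lists. quantize is a hypothetical helper and does not describe how PhysAugNet's VQVAE implements this step.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def quantize(z: Vector, codebook: Sequence[Vector]) -> Tuple[int, List[float]]:
    """Return the index and value of the codebook entry nearest to z (squared L2)."""
    if not codebook:
        raise ValueError("codebook must be non-empty")
    best_idx = min(
        range(len(codebook)),
        key=lambda k: sum((zi - ci) ** 2 for zi, ci in zip(z, codebook[k])),
    )
    return best_idx, list(codebook[best_idx])
```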

◆ optimizer

train.VQVAETrainer.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

Definition at line 15 of file train.py.
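torch.optim.Adam keeps running estimates of the mean and the uncentered variance of each parameter's gradient, and it scales each step by their bias-corrected ratio. The scalar sketch below applies one such update with PyTorch's default betas=(0.9, 0.999) and eps=1e-8, leaving out weight decay and amsgrad. adam_step is a hypothetical name. In the documented line, lr is a local variable whose definition does not appear on this page, so the learning rate that PhysAugNet actually uses is not documented here.

```python
import math
from typing import Tuple


def adam_step(
    param: float,
    grad: float,
    m: float,
    v: float,
    t: int,
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
) -> Tuple[float, float, float]:
    """One Adam update for a scalar parameter (no weight decay).

    t is the 1-based step count; returns the updated (param, m, v).
    """
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad  # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)  # bias correction for zero initialization
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

On the first step, m_hat / sqrt(v_hat) reduces to the sign of the gradient, so the initial update is roughly lr in magnitude regardless of the gradient's scale.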


The documentation for this class was generated from the following file:
train.py