neurovlm.train.Trainer
- class neurovlm.train.Trainer(model, loss_fn, lr, batch_size, n_epochs, optimizer, X_val=None, y_val=None, verbose=True, interval=None, use_tqdm=False, tensorboard_path=None, device=None)
Training loop.
- Parameters:
model (Module)
loss_fn (Callable)
lr (float)
batch_size (int)
n_epochs (int)
optimizer (Callable)
X_val (tensor | None)
y_val (tensor | None)
verbose (bool | None)
interval (int | None)
use_tqdm (bool | None)
tensorboard_path (str | None)
device (str | None)
- __init__(model, loss_fn, lr, batch_size, n_epochs, optimizer, X_val=None, y_val=None, verbose=True, interval=None, use_tqdm=False, tensorboard_path=None, device=None)
Initialize training parameters.
- Parameters:
model (torch.nn.Module) – Model to fit.
loss_fn (Callable) – Loss function, e.g. torch.nn.MSELoss.
lr (float) – Learning rate or step size.
batch_size (int) – Size of mini-batches.
n_epochs (int) – Number of epochs, i.e. full passes over the training data.
optimizer (Callable) – Uninitialized torch optimizer class (e.g. torch.optim.Adam). Use functools.partial to set optional kwargs if needed.
X_val (2D torch.Tensor, optional, default: None) – Validation input data.
y_val (2D torch.Tensor, optional, default: None) – Validation target data.
verbose (bool, optional, default: True) – Prints validation loss after every epoch if True. Requires X_val.
interval (int, optional, default: None) – How often to track validation loss, in epochs.
use_tqdm (bool, optional, default: False) – Shows a tqdm training progress bar if True.
tensorboard_path (str, optional, default: None) – Path where TensorBoard logs are written; TensorBoard renders them as an interactive webpage showing validation loss over epochs.
device ({None, "cuda", "mps", "cpu", "auto"}, optional, default: None) –
- Moves the model and tensors to the requested device:
None: leaves them on their current device
"auto": moves to a GPU if one is available
"mps": Apple Silicon GPU
"cuda": NVIDIA GPU
"cpu": CPU
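The optimizer argument is a factory rather than an optimizer instance, so optional kwargs are bound ahead of time with functools.partial while the learning rate and parameters are left for the trainer to supply. The sketch below assumes the trainer calls optimizer(model.parameters(), lr=lr) internally; that call pattern is not confirmed by this page. It also includes a hypothetical resolve_device helper showing one plausible reading of the device options; the CUDA, then MPS, then CPU fallback order for "auto" is an assumption, not neurovlm's documented behavior.

```python
from functools import partial

import torch

# Bind optional kwargs now; params and lr stay unbound.
optimizer = partial(torch.optim.AdamW, weight_decay=1e-2)

# Assumed internal call made by the trainer (not confirmed by the docs).
model = torch.nn.Linear(4, 1)
opt = optimizer(model.parameters(), lr=1e-3)


def resolve_device(device=None):
    """Hypothetical helper mapping the documented `device` options to a torch.device.

    The CUDA -> MPS -> CPU fallback order used for "auto" is an assumption.
    """
    if device is None:
        return None  # leave the model and tensors where they are
    if device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device)  # "cuda", "mps", or "cpu"
```

The partial object is then passed as the optimizer argument, and a string such as "auto" as device.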
Methods
__init__(model, loss_fn, lr, batch_size, ...) – Initialize training parameters.
fit(X_train[, y_train]) – Training loop.
predict(X) – Forward pass.
restore_best() – Restore the best model, based on validation loss.
save(path) – Save model.
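To show how fit, predict, and restore_best relate to the constructor parameters, here is a minimal plain-PyTorch sketch of a mini-batch loop that tracks validation loss every interval epochs and snapshots the weights with the lowest validation loss. fit_sketch is a hypothetical name, and this is not neurovlm's implementation. Unlike the documented fit, which treats y_train as optional, the sketch always requires targets and uses interval=1 in place of the documented default of None.

```python
import copy

import torch


def fit_sketch(model, loss_fn, lr, batch_size, n_epochs, optimizer,
               X_train, y_train, X_val, y_val, interval=1):
    """Hypothetical mini-batch loop with best-validation-loss checkpointing."""
    opt = optimizer(model.parameters(), lr=lr)  # optimizer is an uninitialized class
    best_loss, best_state = float("inf"), None
    n = X_train.shape[0]
    for epoch in range(n_epochs):
        model.train()
        perm = torch.randperm(n)  # reshuffle the training data each epoch
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            opt.zero_grad()
            loss = loss_fn(model(X_train[idx]), y_train[idx])
            loss.backward()
            opt.step()
        if (epoch + 1) % interval == 0:
            model.eval()
            with torch.no_grad():
                val_loss = loss_fn(model(X_val), y_val).item()
            if val_loss < best_loss:
                # Copy the weights so later updates do not overwrite the snapshot.
                best_loss, best_state = val_loss, copy.deepcopy(model.state_dict())
    return best_loss, best_state


# Usage on a noiseless linear regression problem.
torch.manual_seed(0)
X = torch.randn(256, 4)
y = X @ torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
with torch.no_grad():
    initial_loss = loss_fn(model(X[200:]), y[200:]).item()

best_loss, best_state = fit_sketch(model, loss_fn, 5e-2, 32, 100, torch.optim.Adam,
                                   X[:200], y[:200], X[200:], y[200:])
model.load_state_dict(best_state)  # analogous to restore_best()
with torch.no_grad():
    preds = model(X[200:])  # analogous to predict(X): a forward pass without gradients
```

Snapshotting with copy.deepcopy keeps the best weights independent of later optimizer steps, which is what makes restoring them after training possible.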