neurovlm.train.Trainer


class neurovlm.train.Trainer(model, loss_fn, lr, batch_size, n_epochs, optimizer, X_val=None, y_val=None, verbose=True, interval=None, use_tqdm=False, tensorboard_path=None, device=None)

Training loop.

Parameters:
  • model (Module)

  • loss_fn (Callable)

  • lr (float)

  • batch_size (int)

  • n_epochs (int)

  • optimizer (Callable)

  • X_val (tensor | None)

  • y_val (tensor | None)

  • verbose (bool | None)

  • interval (int | None)

  • use_tqdm (bool | None)

  • tensorboard_path (str | None)

  • device (str | None)

__init__(model, loss_fn, lr, batch_size, n_epochs, optimizer, X_val=None, y_val=None, verbose=True, interval=None, use_tqdm=False, tensorboard_path=None, device=None)

Initialize training parameters.

Parameters:
  • model (torch.nn.Module) – Model to fit.

  • loss_fn (Callable) – Loss function, e.g. torch.nn.MSELoss.

  • lr (float) – Learning rate or step size.

  • batch_size (int) – Size of mini-batches.

  • n_epochs (int) – Number of epochs, i.e. full passes over the training data.

  • optimizer (Callable) – Uninitialized torch optimizer class, e.g. torch.optim.Adam. Use functools.partial to set optional keyword arguments if needed.

  • X_val (2d torch.tensor, optional, default: None) – Validation input data.

  • y_val (2d torch.tensor, optional, default: None) – Validation target data.

  • verbose (bool, optional, default: True) – Prints the validation loss after every epoch if True. Requires X_val.

  • interval (int, optional, default: None) – How often to track the validation loss, in epochs.

  • use_tqdm (bool, optional, default: False) – Displays a training progress bar if True.

  • tensorboard_path (str, optional, default: None) – Path for TensorBoard output, viewable as an interactive webpage that displays the validation loss over epochs.

  • device ({None, "cuda", "mps", "cpu", "auto"}, optional, default: None) –

    Moves the model and tensors to the requested device.
    • None: leaves them on their current device

    • "auto": moves to a GPU based on availability

    • "mps": Apple silicon GPU (Metal Performance Shaders)

    • "cuda": NVIDIA GPU

    • "cpu": CPU
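One plausible way to resolve the device options above is sketched below. The helper name resolve_device and the "auto" preference order (CUDA first, then MPS, then a CPU fallback) are assumptions for illustration, not taken from neurovlm. Availability is passed in as booleans so the sketch runs without PyTorch installed.

```python
from typing import Optional

VALID_DEVICES = {"cuda", "mps", "cpu"}


def resolve_device(
    device: Optional[str], cuda_available: bool, mps_available: bool
) -> Optional[str]:
    """Hypothetical helper: map the `device` argument to a concrete device.

    Returns None when nothing should be moved. The "auto" preference order
    (CUDA, then MPS, then CPU) is an assumption made for this sketch.
    """
    if device is None:
        return None  # leave the model and tensors where they are
    if device == "auto":
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"  # assumed fallback when no GPU is available
    if device in VALID_DEVICES:
        return device
    raise ValueError(f"unsupported device: {device!r}")


print(resolve_device("auto", cuda_available=False, mps_available=True))  # mps
```

With PyTorch installed, the two flags would typically come from torch.cuda.is_available() and torch.backends.mps.is_available().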

Methods

__init__(model, loss_fn, lr, batch_size, ...)

Initialize training parameters.

fit(X_train[, y_train])

Training loop.

predict(X)

Forward pass.

restore_best()

Restore the best model, based on the lowest validation loss.

save(path)

Save model.
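The fit and restore_best methods suggest a common pattern: mini-batch updates each epoch, a validation check every interval epochs, and a snapshot of the parameters with the lowest validation loss kept for later restoration. The pure-Python sketch below illustrates that pattern on a toy 1-D linear model. SketchTrainer, SGD, mse, and mse_grads are hypothetical names, and this is not neurovlm's implementation. It also assumes the optimizer factory is called as optimizer(params, lr=lr), which is why functools.partial can pre-bind any other keyword arguments.

```python
import copy
import random
from functools import partial


class SGD:
    """Hypothetical stand-in for a torch optimizer: plain SGD on a dict of floats."""

    def __init__(self, params, lr, weight_decay=0.0):
        self.params = params  # shared dict, updated in place
        self.lr = lr
        self.weight_decay = weight_decay

    def step(self, grads):
        for name, grad in grads.items():
            grad += self.weight_decay * self.params[name]
            self.params[name] -= self.lr * grad


def mse(params, xs, ys):
    """Mean squared error of the linear model w * x + b."""
    w, b = params["w"], params["b"]
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grads(params, xs, ys):
    """Gradients of mse with respect to w and b."""
    w, b = params["w"], params["b"]
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    n = len(xs)
    return {
        "w": 2 * sum(r * x for r, x in zip(residuals, xs)) / n,
        "b": 2 * sum(residuals) / n,
    }


class SketchTrainer:
    """Hypothetical, simplified analogue of Trainer (not neurovlm's code)."""

    def __init__(self, lr, batch_size, n_epochs, optimizer,
                 X_val=None, y_val=None, interval=1, seed=0):
        self.params = {"w": 0.0, "b": 0.0}
        # Assumed convention: the factory receives the parameters and lr.
        self.optimizer = optimizer(self.params, lr=lr)
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.X_val, self.y_val = X_val, y_val
        self.interval = interval
        self.rng = random.Random(seed)
        self.best_loss = float("inf")
        self.best_params = None
        self.val_losses = []

    def fit(self, X_train, y_train):
        order = list(range(len(X_train)))
        for epoch in range(self.n_epochs):
            self.rng.shuffle(order)  # new mini-batch split every epoch
            for start in range(0, len(order), self.batch_size):
                batch = order[start:start + self.batch_size]
                xs = [X_train[i] for i in batch]
                ys = [y_train[i] for i in batch]
                self.optimizer.step(mse_grads(self.params, xs, ys))
            # Check the validation loss every `interval` epochs.
            if self.X_val is not None and (epoch + 1) % self.interval == 0:
                loss = mse(self.params, self.X_val, self.y_val)
                self.val_losses.append(loss)
                if loss < self.best_loss:  # snapshot the best parameters
                    self.best_loss = loss
                    self.best_params = copy.deepcopy(self.params)
        return self

    def restore_best(self):
        """Reload the parameters with the lowest validation loss."""
        if self.best_params is not None:
            self.params.update(self.best_params)
        return self


X = [i / 50 for i in range(50)]
y = [2 * x + 1 for x in X]
X_val = [0.05 + i / 25 for i in range(25)]
y_val = [2 * x + 1 for x in X_val]
trainer = SketchTrainer(lr=0.1, batch_size=8, n_epochs=200,
                        optimizer=partial(SGD, weight_decay=0.0),
                        X_val=X_val, y_val=y_val)
trainer.fit(X, y).restore_best()
```

Updating self.params in place (rather than rebinding it) keeps the optimizer's reference valid after restore_best, which mirrors how restoring a torch model's state leaves an existing optimizer pointing at the same parameter objects.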