linear_model.ZINB_grad.train_ZINB_with_val
- linear_model.ZINB_grad.train_ZINB_with_val(x, val_data, optimizer, model, device, X_val=None, epochs=150, PATH='/home/longlab/Data/Thesis/Data/', early_stop=False)
Trains a ZINB-Grad model with validation.
The function trains a ZINB-Grad model with the given optimizer for a number of epochs, evaluating it on validation data, and returns the losses, negative log-likelihoods, and validation losses recorded during training.
The function saves the model with the best validation loss to PATH. If early_stop is True, it uses early stopping to avoid overfitting; when training stops early, the model with the best validation loss is loaded.
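The checkpoint-and-restore behaviour described above can be illustrated with a minimal, library-independent sketch. It is not the ZINB-Grad implementation: the stopping criterion is not documented here, so the `patience` counter is an assumption, and `fit_with_early_stopping`, `train_epoch`, `validate`, `save_best`, and `load_best` are hypothetical names standing in for the real training, validation, and model save/load steps.

```python
from typing import Callable, List, Tuple


def fit_with_early_stopping(
    train_epoch: Callable[[], float],
    validate: Callable[[], float],
    save_best: Callable[[], None],
    load_best: Callable[[], None],
    epochs: int = 150,
    early_stop: bool = False,
    patience: int = 10,  # assumption: not a parameter of train_ZINB_with_val
) -> Tuple[List[float], List[float]]:
    """Generic train/validate loop that checkpoints the best model."""
    losses: List[float] = []
    val_losses: List[float] = []
    best_val = float("inf")
    epochs_without_improvement = 0

    for _ in range(epochs):
        losses.append(train_epoch())
        val_loss = validate()
        val_losses.append(val_loss)

        if val_loss < best_val:
            # New best validation loss: checkpoint the model.
            best_val = val_loss
            epochs_without_improvement = 0
            save_best()
        else:
            epochs_without_improvement += 1

        if early_stop and epochs_without_improvement >= patience:
            # Stop and restore the checkpoint with the best validation loss.
            load_best()
            break

    return losses, val_losses


# Simulated run with made-up validation losses (illustration only).
val_curve = iter([5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 3.8])
events: List[str] = []
losses, val_losses = fit_with_early_stopping(
    train_epoch=lambda: 1.0,
    validate=lambda: next(val_curve),
    save_best=lambda: events.append("save"),
    load_best=lambda: events.append("load"),
    epochs=7,
    early_stop=True,
    patience=2,
)
```

In this simulated run the validation loss improves three times and then worsens twice, so the loop stops after five epochs and restores the checkpoint saved at the third.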
Parameters
- x: torch.Tensor
The training data, a Tensor of shape (n_samples, n_features).
- val_data: torch.Tensor
The validation data, a Tensor of shape (n_samples_val, n_features).
- optimizer: An object of torch.optim.Optimizer
For more details, please refer to the PyTorch documentation.
- model: An object of the ZINB_Grad class
Please refer to the example.
- device: A torch.device object
Please refer to the PyTorch documentation for more details.
- X_val: torch.Tensor (optional, default=None)
The X parameter of the ZINB-Grad model for the validation samples, a Tensor of shape (n_samples_val, M).
- epochs: int (optional, default=150)
Number of training epochs.
- early_stop: bool (optional, default=False)
If True, the function will use early stopping.
- PATH: str (optional, default='/home/longlab/Data/Thesis/Data/')
The path where the best model is saved.
Returns
- losses: list
A list consisting of the loss of each epoch.
- neg_log_liks: list
A list consisting of the negative log-likelihood of each epoch.
- val_losses: list
A list consisting of the validation losses of each validation step.
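As one way to use the returned lists, the short sketch below finds the validation step with the lowest loss. The helper `best_validation_step` is hypothetical and not part of the library, and the numbers are made up for illustration; in practice `val_losses` would be the list returned by the function.

```python
from typing import List, Tuple


def best_validation_step(val_losses: List[float]) -> Tuple[int, float]:
    """Return (index, value) of the smallest validation loss."""
    if not val_losses:
        raise ValueError("val_losses is empty")
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best_index, val_losses[best_index]


# Made-up validation losses standing in for the returned val_losses.
example_val_losses = [5.0, 4.0, 3.0, 3.5, 3.6]
best_index, best_value = best_validation_step(example_val_losses)
```

The index is zero-based, so a result of 2 refers to the third validation step.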