Source code for linear_model.ZINB_grad

"""
The ZINB-Grad model: a gradient-based ZINB GLMM with GPU acceleration, high-performance
scalability, and memory-efficient estimation.
"""

import torch
from pyro.distributions import ZeroInflatedNegativeBinomial as ZINB
from torch import nn


class ZINB_Grad(nn.Module):
    """
    The ZINB-Grad model.

    A gradient-descent-based stochastic optimization process for ZINB-WaVE that
    overcomes the scalability and efficiency challenges inherent in its original
    optimization procedure. The result of this combination is ZINB-Grad.

    Parameters
    ----------
    Y : torch.Tensor
        Tensor of shape (n_samples, n_features).
    device : A `torch.device` object
        Please refer to the PyTorch documentation for more details.
    K : int (optional, default=1)
        Number of latent space dimensions.
    W : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, K).
    X : torch.Tensor (optional, default=None)
        Known tensor of shape (n_samples, M). M is user defined and is the number
        of covariates in X. When X = None, X is a column of ones.
    V : torch.Tensor (optional, default=None)
        Known tensor of shape (n_features, L). L is user defined and is the number
        of covariates in V. When V = None, V is a column of ones.
    alpha_mu : torch.Tensor (optional, default=None)
        Tensor of shape (K, n_features).
    alpha_pi : torch.Tensor (optional, default=None)
        Tensor of shape (K, n_features).
    beta_mu : torch.Tensor (optional, default=None)
        Tensor of shape (M, n_features).
    beta_pi : torch.Tensor (optional, default=None)
        Tensor of shape (M, n_features).
    gamma_mu : torch.Tensor (optional, default=None)
        Tensor of shape (L, n_samples).
    gamma_pi : torch.Tensor (optional, default=None)
        Tensor of shape (L, n_samples).
    log_theta : torch.Tensor (optional, default=None)
        Tensor of shape (1, n_features). The natural logarithm of the theta
        parameter in the ZINB distribution.
    O_mu : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, n_features).
    O_pi : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, n_features).

    Attributes
    ----------
    n : int
        The number of samples.
    J : int
        The number of features (genes).
    M : int
        The number of covariates (columns) in X.

    Examples
    --------
    >>> import ZINB_grad
    >>> import torch
    >>> import data_prep
    >>> from torch.utils.data import DataLoader
    >>> cortex = data_prep.CORTEX()
    >>> y, labels = next(iter(DataLoader(cortex, batch_size=cortex.n_cells, shuffle=True)))
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device)
    """

    def __init__(
        self,
        Y,
        device,
        W=None,
        alpha_mu=None,
        alpha_pi=None,
        beta_mu=None,
        beta_pi=None,
        gamma_mu=None,
        gamma_pi=None,
        log_theta=None,
        X=None,
        V=None,
        O_mu=None,
        O_pi=None,
        K=1,
    ):
        super().__init__()

        self.n, self.J = Y.size()  # n samples and J genes
        self.X = X
        self.V = V
        self.K = K

        if log_theta is not None:
            self.log_theta = log_theta
        else:
            self.log_theta = nn.Parameter(torch.rand((1, self.J)))

        if X is not None:
            _, self.M = X.size()
            self.X = X.to(device)
            if beta_mu is not None:
                self.beta_mu = beta_mu
            else:
                self.beta_mu = nn.Parameter(torch.rand(self.M, self.J))
            if beta_pi is not None:
                self.beta_pi = beta_pi
            else:
                self.beta_pi = nn.Parameter(torch.rand(self.M, self.J))
        else:
            self.X = torch.ones((self.n, 1)).to(device)
            if beta_mu is not None:
                self.beta_mu = beta_mu
            else:
                self.beta_mu = nn.Parameter(torch.rand((1, self.J)))
            if beta_pi is not None:
                self.beta_pi = beta_pi
            else:
                self.beta_pi = nn.Parameter(torch.rand((1, self.J)))

        if V is not None:
            _, self.L = V.size()
            self.V = V.to(device)
            if gamma_mu is not None:
                self.gamma_mu = gamma_mu
            else:
                self.gamma_mu = nn.Parameter(torch.rand((self.L, self.n)))
            if gamma_pi is not None:
                self.gamma_pi = gamma_pi
            else:
                self.gamma_pi = nn.Parameter(torch.rand((self.L, self.n)))
        else:
            self.V = torch.ones((self.J, 1)).to(device)
            if gamma_mu is not None:
                self.gamma_mu = gamma_mu
            else:
                self.gamma_mu = nn.Parameter(torch.rand((1, self.n)))
            if gamma_pi is not None:
                self.gamma_pi = gamma_pi
            else:
                self.gamma_pi = nn.Parameter(torch.rand((1, self.n)))

        if W is not None:
            self.W = W
        else:
            self.W = nn.Parameter(torch.rand((self.n, self.K)))

        if alpha_mu is not None:
            self.alpha_mu = alpha_mu
        else:
            self.alpha_mu = nn.Parameter(torch.rand((self.K, self.J)))

        if alpha_pi is not None:
            self.alpha_pi = alpha_pi
        else:
            self.alpha_pi = nn.Parameter(torch.rand((self.K, self.J)))
    def forward(self, x):
        """
        The forward method of class Module in `torch.nn`.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (n_samples, n_features).

        Returns
        -------
        p : torch.Tensor
            Tensor of shape (n_samples, n_features), the probability of failure
            for each element of the data in the ZINB distribution.
        """
        self.log_mu = (
            self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu
        )
        self.log_pi = (
            self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi
        )

        self.mu = torch.exp(self.log_mu)
        self.theta = torch.exp(self.log_theta)

        # Adaptive regularization parameters are applied:
        p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)

        return p
    def _loss(self, x, p):
        """
        Returns the loss.

        A method to calculate the negative log-likelihood along with the
        regularization penalty. The regularization is applied to avoid
        overfitting.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (n_samples, n_features).
        p : torch.Tensor
            Tensor of shape (n_samples, n_features), the probabilities returned
            by `forward`.

        Returns
        -------
        loss : torch.Tensor
            Sum of the negative log-likelihood over all samples.
        pen : torch.Tensor
            The regularization penalty.
        """
        J = x.shape[1]
        n = x.shape[0]

        # Penalty weights for each regularized parameter group
        eps_W = J / n
        eps_alpha_mu = 1
        eps_alpha_pi = 1
        eps_theta = J

        loss = (
            ZINB(
                total_count=torch.exp(self.log_theta), probs=p, gate_logits=self.log_pi
            )
            .log_prob(x)
            .sum()
        )

        # Frobenius-norm penalties on W and alpha, plus a variance penalty on
        # log_theta (kept as tensors so they contribute gradients)
        pen = (
            eps_W * torch.linalg.norm(self.W, ord="fro").square() / 2
            + eps_alpha_mu * torch.linalg.norm(self.alpha_mu, ord="fro").square() / 2
            + eps_alpha_pi * torch.linalg.norm(self.alpha_pi, ord="fro").square() / 2
            + eps_theta * torch.var(self.log_theta) / 2
        )

        return -loss, pen
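A minimal usage sketch of the class above (illustrative only; the synthetic count matrix `y` below is a hypothetical stand-in for the CORTEX loader used in the docstring examples): construct the model, run a forward pass, and evaluate the penalized loss.

>>> import torch
>>> import ZINB_grad
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> y = torch.poisson(torch.rand(200, 500) * 5).to(device)  # 200 cells x 500 genes, synthetic
>>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device).to(device)
>>> p = model(y)                          # ZINB probabilities, shape (200, 500)
>>> neg_log_lik, pen = model._loss(y, p)  # negative log-likelihood and penalty
>>> loss = neg_log_lik + pen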
def train_ZINB(x, optimizer, model, epochs=150, val=False):
    """
    Trains a ZINB-Grad model.

    The function will train a ZINB-Grad model using an optimizer for a number of
    epochs, and it will return both the losses and the negative log-likelihoods
    obtained during the training procedure.

    Parameters
    ----------
    x : torch.Tensor
        The data for training, a Tensor of shape (n_samples, n_features).
    optimizer : An object of `torch.optim.Optimizer`
        For more details, please refer to the PyTorch documentation.
    model : An object of the `ZINB_Grad` class
        Please refer to the example.
    epochs : int (optional, default=150)
        Number of iterations for training.
    val : bool (optional, default=False)
        Whether this is a validation or a training process.

    Returns
    -------
    losses : list
        A list consisting of the loss of each epoch.
    neg_log_liks : list
        A list consisting of the negative log-likelihood of each epoch.

    Examples
    --------
    >>> import ZINB_grad
    >>> import data_prep
    >>> import torch
    >>> from torch.utils.data import DataLoader
    >>> cortex = data_prep.CORTEX()
    >>> y, labels = next(iter(DataLoader(cortex, batch_size=cortex.n_cells, shuffle=True)))
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.08)
    >>> losses, neg_log_liks = ZINB_grad.train_ZINB(y, optimizer, model, epochs=300)
    """
    losses = []
    neg_log_liks = []

    for i in range(epochs):
        i += 1
        batch = x

        p = model(batch)
        neg_log_lik, pen = model._loss(batch, p)
        loss = neg_log_lik + pen

        losses.append(loss.item())
        neg_log_liks.append(neg_log_lik.item())

        if i % 50 == 1:
            if not val:
                print(f"epoch: {i:3} loss: {loss.item():10.2f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if val:
        print(f"validation loss: {loss.item():10.2f}")
    else:
        print(f"epoch: {i:3} loss: {loss.item():10.2f}")  # print the last line

    return losses, neg_log_liks
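Beyond the docstring example above, the fitted parameters can be inspected after training; a hedged sketch, assuming the trained `model` from that example (the variable names below are illustrative, not part of the module):

>>> W_latent = model.W.detach().cpu()       # sample-level latent factors, shape (n_samples, K)
>>> mu_fitted = model.mu.detach().cpu()     # fitted ZINB means from the last forward pass
>>> theta_fit = torch.exp(model.log_theta.detach()).cpu()  # gene-wise theta parameter of the ZINB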
def train_ZINB_with_val(
    x,
    val_data,
    optimizer,
    model,
    device,
    X_val=None,
    epochs=150,
    PATH="/home/longlab/Data/Thesis/Data/",
    early_stop=False,
):
    """
    Trains a ZINB-Grad model with validation.

    The function will train a ZINB-Grad model with validation using an optimizer
    for a number of epochs, and it will return the losses, negative
    log-likelihoods, and validation losses obtained during the training
    procedure. The function will save the model with the best validation loss,
    and it uses early stopping to avoid overfitting. With early stopping, the
    model with the best validation loss will be loaded.

    Parameters
    ----------
    x : `torch.Tensor`
        The data for training, a Tensor of shape (n_samples, n_features).
    val_data : `torch.Tensor`
        The validation data, a Tensor of shape (n_samples_val, n_features).
    optimizer : An object of `torch.optim.Optimizer`
        For more details, please refer to the PyTorch documentation.
    model : An object of the `ZINB_Grad` class
        Please refer to the example.
    device : A `torch.device` object
        Please refer to the PyTorch documentation for more details.
    X_val : `torch.Tensor` (optional, default=None)
        The X parameter of the ZINB-Grad model for the validation samples, a
        Tensor of shape (n_samples_val, M).
    epochs : int (optional, default=150)
        Number of iterations for training.
    early_stop : bool (optional, default=False)
        If True, the function will use early stopping.
    PATH : str
        The path for saving the best model.

    Returns
    -------
    losses : list
        A list consisting of the loss of each epoch.
    neg_log_liks : list
        A list consisting of the negative log-likelihood of each epoch.
    val_losses : list
        A list consisting of the validation loss of each validation step.
    """
    losses = []
    neg_log_liks = []
    val_losses = []

    val_loss, _ = val_ZINB(val_data, model, device, X_val=X_val)
    val_losses.append(val_loss)

    # To avoid an error when training does not make the model any better
    torch.save(model.state_dict(), PATH + "best_trained_model.pt")

    for i in range(epochs):
        i += 1
        batch = x

        p = model(batch)
        neg_log_lik, pen = model._loss(batch, p)
        loss = neg_log_lik + pen

        losses.append(loss.item())
        neg_log_liks.append(neg_log_lik.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 1:
            print(f"epoch: {i:3} loss: {loss.item():10.2f}")

            val_loss_last, _ = val_ZINB(val_data, model, device, X_val=X_val)
            val_losses.append(val_loss_last)

            if val_loss_last <= val_loss:
                val_loss = val_loss_last
                # save model checkpoint
                torch.save(model.state_dict(), PATH + "best_trained_model.pt")
            elif early_stop:
                model.load_state_dict(torch.load(PATH + "best_trained_model.pt"))
                break

    val_loss_last, _ = val_ZINB(val_data, model, device, X_val=X_val)
    val_losses.append(val_loss_last)

    if val_loss_last <= val_loss:
        val_loss = val_loss_last
        # save model checkpoint
        torch.save(model.state_dict(), PATH + "best_trained_model.pt")

    print(f"epoch: {i:3} loss: {loss.item():10.2f}")  # print the last line

    return losses, neg_log_liks, val_losses
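A hedged sketch of a train/validation split for `train_ZINB_with_val`, continuing the synthetic `y` and `device` from the earlier sketch (the split fraction, learning rate, and checkpoint path are illustrative choices, and the directory passed as `PATH` must already exist):

>>> n = y.shape[0]
>>> perm = torch.randperm(n)
>>> y_val, y_train = y[perm[:n // 5]], y[perm[n // 5:]]
>>> model = ZINB_grad.ZINB_Grad(Y=y_train, K=10, device=device).to(device)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.08)
>>> losses, neg_log_liks, val_losses = ZINB_grad.train_ZINB_with_val(
...     y_train, y_val, optimizer, model, device,
...     epochs=300, PATH='./checkpoints/', early_stop=True)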
def val_ZINB(val_data, model, device, epochs=15, X_val=None):
    """
    Returns the validation loss and negative log-likelihood.

    The function will perform the validation on a ZINB-Grad model. The following
    parameters remain the same during the validation process and will not be
    updated: `log_theta`, `beta_mu`, `beta_pi`, `alpha_mu`, and `alpha_pi`.
    However, `W`, `gamma_mu`, and `gamma_pi` do change because their dimensions
    depend on the number of samples, i.e., they are sample specific.

    Parameters
    ----------
    val_data : `torch.Tensor`
        The validation data, a Tensor of shape (n_samples_val, n_features).
    model : An object of the `ZINB_Grad` class
        Please refer to the example.
    device : A `torch.device` object
        Please refer to the PyTorch documentation for more details.
    epochs : int (optional, default=15)
        Number of iterations for training.
    X_val : `torch.Tensor` (optional, default=None)
        The X parameter of the ZINB-Grad model for the validation samples, a
        Tensor of shape (n_samples_val, M).

    Returns
    -------
    loss : float
        The validation loss.
    neg_log_lik : float
        The validation negative log-likelihood.
    """
    model_val = ZINB_Grad(
        Y=val_data,
        X=X_val,
        device=device,
        K=model.K,
        alpha_mu=model.alpha_mu.detach(),
        alpha_pi=model.alpha_pi.detach(),
        beta_mu=model.beta_mu.detach(),
        beta_pi=model.beta_pi.detach(),
        log_theta=model.log_theta.detach(),
    )

    # Tuning the validation model parameters (W and gammas)
    model_val.to(device)
    optimizer = torch.optim.Adam(model_val.parameters(), lr=0.1)
    losses, neg_log_liks = train_ZINB(
        val_data, optimizer, model_val, epochs=epochs, val=True
    )

    return losses[-1], neg_log_liks[-1]
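`val_ZINB` can also be called directly to score held-out samples with the trained gene-level parameters frozen; a short sketch, assuming the `y_val`, `model`, and `device` from the split above:

>>> val_loss, val_neg_log_lik = ZINB_grad.val_ZINB(y_val, model, device, epochs=30)
>>> print(f"held-out loss: {val_loss:10.2f}")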