"""
ZINB-Grad: a gradient-based ZINB GLMM with GPU acceleration, high-performance
scalability, and memory-efficient estimation.
"""
import torch
from pyro.distributions import ZeroInflatedNegativeBinomial as ZINB
from torch import nn
class ZINB_Grad(nn.Module):
    """
    The ZINB-Grad model.
    A gradient descent-based stochastic optimization of the ZINB-WaVE model, designed
    to overcome the scalability and efficiency limitations of ZINB-WaVE's original
    optimization procedure. The result of this combination is ZINB-Grad.
    Parameters
    ----------
    Y : torch.Tensor
        Tensor of shape (n_samples, n_features).
    device : A `torch.device` object
        Please refer to Pytorch documentation for more details.
    K : int (optional, default=1)
        Number of latent space dimensions.
    W : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, K).
    X : torch.Tensor (optional, default=None)
        Known tensor of shape (n_samples, M). M is user defined and is the number of
        covariates in X. When X = None, X is a column of ones.
    V : torch.Tensor (optional, default=None)
        Known tensor of shape (n_features, L). L is user defined and is the number of
        covariates in V. When V = None, V is a column of ones.
    alpha_mu : torch.Tensor (optional, default=None)
        Tensor of shape (K, n_features).
    alpha_pi : torch.Tensor (optional, default=None)
        Tensor of shape (K, n_features).
    beta_mu : torch.Tensor (optional, default=None)
        Tensor of shape (M, n_features).
    beta_pi : torch.Tensor (optional, default=None)
        Tensor of shape (M, n_features).
    gamma_mu : torch.Tensor (optional, default=None)
        Tensor of shape (L, n_samples).
    gamma_pi : torch.Tensor (optional, default=None)
        Tensor of shape (L, n_samples).
    log_theta : torch.Tensor (optional, default=None)
        Tensor of shape (1, n_features). The natural logarithm of the theta parameter in
        the ZINB distribution.
    O_mu : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, n_features).
    O_pi : torch.Tensor (optional, default=None)
        Tensor of shape (n_samples, n_features).
    Attributes
    ----------
    n : int
        The number of samples
    J : int
        The number of features (genes)
    M : int
        The number of covariates (columns) in X.
    Examples
    --------
    >>> import ZINB_grad
    >>> import torch
    >>> import data_prep
    >>> from torch.utils.data import DataLoader
    >>> cortex = data_prep.CORTEX()
    >>> y, labels = next(iter(DataLoader(cortex,
    ...                              batch_size=cortex.n_cells,
    ...                              shuffle=True)))
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device)
    """
    def __init__(
        self,
        Y,
        device,
        W=None,
        alpha_mu=None,
        alpha_pi=None,
        beta_mu=None,
        beta_pi=None,
        gamma_mu=None,
        gamma_pi=None,
        log_theta=None,
        X=None,
        V=None,
        O_mu=None,
        O_pi=None,
        K=1,
    ):
        super().__init__()
        self.n, self.J = Y.size()  # n samples and J genes
        self.X = X
        self.V = V
        self.K = K
        if log_theta is not None:
            self.log_theta = log_theta
        else:
            self.log_theta = nn.Parameter(torch.rand((1, self.J)))
        if X is not None:
            _, self.M = X.size()
            self.X = X.to(device)
            if beta_mu is not None:
                self.beta_mu = beta_mu
            else:
                self.beta_mu = nn.Parameter(torch.rand(self.M, self.J))
            if beta_pi is not None:
                self.beta_pi = beta_pi
            else:
                self.beta_pi = nn.Parameter(torch.rand(self.M, self.J))
        else:
            self.X = torch.ones((self.n, 1)).to(device)
            if beta_mu is not None:
                self.beta_mu = beta_mu
            else:
                self.beta_mu = nn.Parameter(torch.rand((1, self.J)))
            if beta_pi is not None:
                self.beta_pi = beta_pi
            else:
                self.beta_pi = nn.Parameter(torch.rand((1, self.J)))
        if V is not None:
            _, self.L = V.size()
            self.V = V.to(device)
            if gamma_mu is not None:
                self.gamma_mu = gamma_mu
            else:
                self.gamma_mu = nn.Parameter(torch.rand((self.L, self.n)))
            if gamma_pi is not None:
                self.gamma_pi = gamma_pi
            else:
                self.gamma_pi = nn.Parameter(torch.rand((self.L, self.n)))
        else:
            self.V = torch.ones((self.J, 1)).to(device)
            if gamma_mu is not None:
                self.gamma_mu = gamma_mu
            else:
                self.gamma_mu = nn.Parameter(torch.rand((1, self.n)))
            if gamma_pi is not None:
                self.gamma_pi = gamma_pi
            else:
                self.gamma_pi = nn.Parameter(torch.rand((1, self.n)))
        if W is not None:
            self.W = W
        else:
            self.W = nn.Parameter(torch.rand((self.n, self.K)))
        if alpha_mu is not None:
            self.alpha_mu = alpha_mu
        else:
            self.alpha_mu = nn.Parameter(torch.rand((self.K, self.J)))
        if alpha_pi is not None:
            self.alpha_pi = alpha_pi
        else:
            self.alpha_pi = nn.Parameter(torch.rand((self.K, self.J)))
    def forward(self, x):
        """
        The forward method of `torch.nn.Module`.
        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (n_samples, n_features).
        Returns
        -------
        p : torch.Tensor
            Tensor of shape (n_samples, n_features); the probability of failure for
            each element of the data under the ZINB distribution.
        """
        self.log_mu = (
            self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu
        )
        self.log_pi = (
            self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi
        )
        self.mu = torch.exp(self.log_mu)
        self.theta = torch.exp(self.log_theta)
        # Small epsilon terms keep the denominator positive and p strictly below 1:
        p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)
        return p 
    def _loss(self, x, p):
        """
        Compute the negative log-likelihood and the regularization penalty.
        The penalty is added to the negative log-likelihood to avoid overfitting.
        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (n_samples, n_features).
        p : torch.Tensor
            Probability tensor of shape (n_samples, n_features) returned by `forward`.
        Returns
        -------
        loss : torch.Tensor
            Sum of the negative log-likelihood over all samples.
        pen : torch.Tensor
            The regularization penalty.
        """
        J = x.shape[1]
        n = x.shape[0]
        eps_W = J / n
        eps_alpha_mu = 1
        eps_alpha_pi = 1
        eps_theta = J
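        # Ridge-style penalty:
        #   pen = eps_W/2 * ||W||_F^2 + eps_alpha_mu/2 * ||alpha_mu||_F^2
        #         + eps_alpha_pi/2 * ||alpha_pi||_F^2 + eps_theta/2 * Var(log_theta)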
        loss = (
            ZINB(
                total_count=torch.exp(self.log_theta), probs=p, gate_logits=self.log_pi
            )
            .log_prob(x)
            .sum()
        )
        pen = (
            # Keep these terms as tensors so the penalty contributes gradients
            # during backpropagation.
            eps_W * torch.linalg.norm(self.W, ord="fro").square() / 2
            + eps_alpha_mu * torch.linalg.norm(self.alpha_mu, ord="fro").square() / 2
            + eps_alpha_pi * torch.linalg.norm(self.alpha_pi, ord="fro").square() / 2
            + eps_theta * torch.var(self.log_theta) / 2
        )
        return -loss, pen 
def train_ZINB(x, optimizer, model, epochs=150, val=False):
    """
    Trains a ZINB-Grad model.
    The function trains a ZINB-Grad model with the given optimizer for a number of
    epochs and returns the losses and negative log-likelihoods obtained during
    training.
    Parameters
    ----------
    x : torch.Tensor
        The training data, a tensor of shape (n_samples, n_features).
    optimizer: An object of torch.optim.Optimizer
        For more details, please refer to Pytorch documentation.
    model: An object of the ZINB_Grad class
        Please refer to the example.
    epochs : int (optional, default=150)
        Number of training epochs.
    val : bool (optional, default=False)
        If True, the call is treated as a validation fit: per-epoch progress is not
        printed and only the final validation loss is reported.
    Returns
    -------
    losses : list
         A list consisting of the loss of each epoch.
    neg_log_liks : list
         A list consisting of the negative log-likelihood of each epoch.
    Examples
    --------
    >>> import ZINB_grad
    >>> import data_prep
    >>> import torch
    >>> from torch.utils.data import DataLoader
    >>> cortex = data_prep.CORTEX()
    >>> y, labels = next(iter(DataLoader(cortex,
    ...                              batch_size=cortex.n_cells,
    ...                              shuffle=True)))
    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr = 0.08)
    >>> losses, neg_log_liks = ZINB_grad.train_ZINB(y, optimizer, model, epochs = 300)
    """
    losses = []
    neg_log_liks = []
    for i in range(1, epochs + 1):
        batch = x
        p = model(batch)
        neg_log_lik, pen = model._loss(batch, p)
        loss = neg_log_lik + pen
        losses.append(loss.item())
        neg_log_liks.append(neg_log_lik.item())
        if i % 50 == 1:
            if not val:
                print(f"epoch: {i:3}  loss: {loss.item():10.2f}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if val:
        print(f"validation loss: {loss.item():10.2f}")
    else:
        print(f"epoch: {i:3}  loss: {loss.item():10.2f}")  # print the last line
    return losses, neg_log_liks 
def train_ZINB_with_val(
    x,
    val_data,
    optimizer,
    model,
    device,
    X_val=None,
    epochs=150,
    PATH="/home/longlab/Data/Thesis/Data/",
    early_stop=False,
):
    """
    Trains a ZINB-Grad model with validation.
    The function trains a ZINB-Grad model with the given optimizer for a number of
    epochs, periodically evaluating it on held-out data, and returns the training
    losses, negative log-likelihoods, and validation losses obtained during training.
    The model with the best validation loss is saved to disk. If early stopping is
    enabled, training stops once the validation loss stops improving, and the model
    with the best validation loss is loaded back.
    Parameters
    ----------
    x : `torch.Tensor`
        The training data, a tensor of shape (n_samples, n_features).
    val_data : `torch.Tensor`
        The validation data, a tensor of shape (n_samples_val, n_features).
    optimizer: An object of `torch.optim.Optimizer`
        For more details, please refer to Pytorch documentation.
    model: An object of the `ZINB_Grad` class
        Please refer to the example.
    device : A `torch.device` object
        Please refer to Pytorch documentation for more details.
    X_val : `torch.Tensor` (optional, default=None)
        The X covariate matrix of the ZINB-Grad model for the validation samples, a
        tensor of shape (n_samples_val, M).
    epochs : int (optional, default=150)
        Number of training epochs.
    early_stop : bool (optional, default=False)
        If True, the function uses early stopping.
    PATH : str
        The directory in which to save the best model.
    Returns
    -------
    losses : list
         A list consisting of the loss of each epoch.
    neg_log_liks : list
         A list consisting of the negative log-likelihood of each epoch.
    val_losses : list
         A list consisting of the validation losses of each validation step.
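    Examples
    --------
    A minimal usage sketch; it assumes `y` (training counts), `y_val` (held-out
    counts), and `device` are prepared as in the `train_ZINB` example, and that
    `save_dir` is a writable directory path ending with a path separator.
    >>> model = ZINB_grad.ZINB_Grad(Y=y, K=10, device=device)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.08)
    >>> losses, neg_log_liks, val_losses = ZINB_grad.train_ZINB_with_val(
    ...     y, y_val, optimizer, model, device, epochs=300, PATH=save_dir,
    ...     early_stop=True)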
    """
    losses = []
    neg_log_liks = []
    val_losses = []
    val_loss, _ = val_ZINB(val_data, model, device, X_val=X_val)
    val_losses.append(val_loss)
    # Save an initial checkpoint so a best model exists even if training never
    # improves the validation loss.
    torch.save(model.state_dict(), PATH + "best_trained_model.pt")
    for i in range(1, epochs + 1):
        batch = x
        p = model(batch)
        neg_log_lik, pen = model._loss(batch, p)
        loss = neg_log_lik + pen
        losses.append(loss.item())
        neg_log_liks.append(neg_log_lik.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
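        # Every 50 epochs: print progress, re-evaluate on the validation set, and
        # checkpoint the model whenever the validation loss improves.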
        if i % 50 == 1:
            print(f"epoch: {i:3}  loss: {loss.item():10.2f}")
            val_loss_last, _ = val_ZINB(val_data, model, device, X_val=X_val)
            val_losses.append(val_loss_last)
            if val_loss_last <= val_loss:
                val_loss = val_loss_last
                # save model checkpoint
                torch.save(model.state_dict(), PATH + "best_trained_model.pt")
            elif early_stop:
                model.load_state_dict(torch.load(PATH + "best_trained_model.pt"))
                break
    val_loss_last, _ = val_ZINB(val_data, model, device, X_val=X_val)
    val_losses.append(val_loss_last)
    if val_loss_last <= val_loss:
        val_loss = val_loss_last
        # save model checkpoint
        torch.save(model.state_dict(), PATH + "best_trained_model.pt")
    print(f"epoch: {i:3}  loss: {loss.item():10.2f}")  # print the last line
    return losses, neg_log_liks, val_losses 
def val_ZINB(val_data, model, device, epochs=15, X_val=None):
    """
    Returns the validation loss and negative log-likelihood.
    The function performs validation of a ZINB-Grad model on held-out data.
    The gene-level parameters `log_theta`, `beta_mu`, `beta_pi`, `alpha_mu`, and
    `alpha_pi` are kept fixed and are not updated during validation. In contrast,
    `W`, `gamma_mu`, and `gamma_pi` are re-fitted because their dimensions depend
    on the number of samples, i.e., they are sample-specific.
    Parameters
    ----------
    val_data : `torch.Tensor`
        The validation data, a tensor of shape (n_samples_val, n_features).
    model: An object of the `ZINB_Grad` class
        Please refer to the example.
    device : A `torch.device` object
        Please refer to Pytorch documentation for more details.
    epochs : int (optional, default=15)
        Number of epochs used to fit the sample-specific parameters on the
        validation data.
    X_val : `torch.Tensor` (optional, default=None)
        The X covariate matrix of the ZINB-Grad model for the validation samples, a
        tensor of shape (n_samples_val, M).
    Returns
    -------
    loss : float
         The validation loss
    neg_log_lik : float
         The validation negative log-likelihood
    """
    model_val = ZINB_Grad(
        Y=val_data,
        X=X_val,
        device=device,
        K=model.K,
        alpha_mu=model.alpha_mu.detach(),
        alpha_pi=model.alpha_pi.detach(),
        beta_mu=model.beta_mu.detach(),
        beta_pi=model.beta_pi.detach(),
        log_theta=model.log_theta.detach(),
    )
    # Tuning the validation model parameters (W and gammas)
    model_val.to(device)
    optimizer = torch.optim.Adam(model_val.parameters(), lr=0.1)
    losses, neg_log_liks = train_ZINB(
        val_data, optimizer, model_val, epochs=epochs, val=True
    )
    return losses[-1], neg_log_liks[-1]