Source code for diagnnose.probe.logreg

from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
from torch import Tensor


class LogRegModule(nn.Module):
    """Linear classification head, optionally factorised through a bottleneck.

    Args:
        ninp: Size of the last dimension of the input tensor.
        nout: Number of outputs produced for each input vector.
        rank: When ``None`` (default) a single ``nn.Linear(ninp, nout)`` is
            used. Otherwise two linear layers ``ninp -> rank -> nout`` are
            stacked with no non-linearity in between, so the composed linear
            map has matrix rank of at most ``rank``.
    """

    def __init__(self, ninp: int, nout: int, rank: Optional[int] = None):
        super().__init__()

        # The attribute is called ``classifier`` in both configurations so
        # that ``forward`` does not need to know which variant was built.
        self.classifier: nn.Module
        if rank is None:
            self.classifier = nn.Linear(ninp, nout)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(ninp, rank),
                nn.Linear(rank, nout),
            )

    def forward(self, inp: Tensor, create_softmax: bool = True) -> Tensor:
        """Apply the classifier to ``inp``.

        Args:
            inp: Input tensor; its last dimension is fed to the first linear
                layer.
            create_softmax: If ``True`` (default) the output is normalised
                with a softmax over the last dimension. If ``False`` the raw
                classifier output is returned.

        Returns:
            Softmax probabilities or raw classifier outputs, depending on
            ``create_softmax``.
        """
        logits = self.classifier(inp)
        if create_softmax:
            return F.softmax(logits, dim=-1)
        return logits
# https://github.com/skorch-dev/skorch/blob/master/docs/user/neuralnet.rst#subclassing-neuralnet
class L1NeuralNetClassifier(NeuralNetClassifier):
    """``NeuralNetClassifier`` whose loss carries an additional L1 penalty.

    Args:
        *args: Forwarded unchanged to ``NeuralNetClassifier.__init__``.
        lambda1: Multiplier applied to the L1 penalty term. Defaults to 0.01.
        **kwargs: Forwarded unchanged to ``NeuralNetClassifier.__init__``.
    """

    def __init__(self, *args, lambda1=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda1 = lambda1

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the parent's loss plus the scaled L1 norm of the parameters.

        The penalty is the sum of absolute values over every tensor yielded
        by ``self.module_.parameters()``, multiplied by ``self.lambda1``.

        Args:
            y_pred: Predictions, passed through to the parent ``get_loss``.
            y_true: Targets, passed through to the parent ``get_loss``.
            X: Passed through to the parent ``get_loss``.
            training: Passed through to the parent ``get_loss``.

        Returns:
            The parent class's loss with the L1 penalty added.
        """
        base_loss = super().get_loss(y_pred, y_true, X=X, training=training)
        l1_norm = sum(param.abs().sum() for param in self.module_.parameters())
        # Combined out-of-place so the tensor returned by the parent class is
        # not mutated.
        return base_loss + self.lambda1 * l1_norm