Source code for diagnnose.activations.activation_index

from typing import Iterable, Optional, Sized

from torch import Tensor, long

from diagnnose.typedefs.activations import ActivationIndex


def activation_index_to_iterable(
    activation_index: ActivationIndex, stop_index: Optional[int] = None
) -> Iterable:
    """Transform an activation index into an iterable object.

    Parameters
    ----------
    activation_index : ActivationIndex
        An ``int``, a ``slice``, or an iterable of indices such as a list,
        ``np.ndarray`` or ``torch.Tensor``.
    stop_index : int, optional
        Stop value used for a slice whose own ``stop`` is ``None``.

    Returns
    -------
    Iterable
        ``[activation_index]`` for an int, a ``range`` for a slice, the
        tensor cast to ``torch.long`` for a Tensor, and any other iterable
        unchanged.

    Raises
    ------
    AssertionError
        If a slice has no ``stop`` and no ``stop_index`` is provided.
    ValueError
        If ``activation_index`` is of an unsupported type.
    """
    if isinstance(activation_index, Tensor):
        # Presumably cast so the tensor can serve as integer indices — confirm.
        activation_index = activation_index.to(long)

    if isinstance(activation_index, Iterable):
        return activation_index

    if isinstance(activation_index, int):
        return [activation_index]

    if isinstance(activation_index, slice):
        # Compare against None explicitly: a stop of 0 is a valid (empty)
        # bound and must not fall back to ``stop_index``.
        if activation_index.stop is not None:
            stop_index = activation_index.stop
        assert stop_index is not None, "Stop index of slice should be provided"
        return range(
            activation_index.start or 0, stop_index, activation_index.step or 1
        )

    raise ValueError(
        f"Activation index of incorrect type: {type(activation_index)}, "
        f"should be one of {{int, slice, List[int], np.ndarray or torch.Tensor}}"
    )
def activation_index_len(
    activation_index: ActivationIndex, stop_index: Optional[int] = None
) -> int:
    """Return the number of items in an activation index.

    Parameters
    ----------
    activation_index : ActivationIndex
        An ``int``, a ``slice``, or an iterable of indices such as a list,
        ``np.ndarray`` or ``torch.Tensor``.
    stop_index : int, optional
        Forwarded to ``activation_index_to_iterable``; needed to measure a
        slice whose own ``stop`` is ``None``.

    Returns
    -------
    int
        ``len()`` of the iterable produced by
        ``activation_index_to_iterable``.

    Raises
    ------
    ValueError
        If the index is of an unsupported type, or its iterable form has no
        length (e.g. a generator).
    """
    activation_iterable = activation_index_to_iterable(
        activation_index, stop_index=stop_index
    )

    if isinstance(activation_iterable, Sized):
        return len(activation_iterable)

    raise ValueError(
        f"Activation index of incorrect type: {type(activation_index)}, "
        f"should be one of {{int, slice, List[int], np.ndarray or torch.Tensor}}"
    )