Source code for diagnnose.attribute.gcd_tensor

import torch

from .shapley_tensor import ShapleyTensor


[docs]class GCDTensor(ShapleyTensor):
[docs] def mul_contributions(self, *args, **kwargs): gate, source = args if not isinstance(gate, ShapleyTensor): contributions = [ torch.mul(gate, contribution, **kwargs) for contribution in source.contributions ] elif not isinstance(source, ShapleyTensor): contributions = [ torch.mul(contribution, source, **kwargs) for contribution in gate.contributions ] else: gate_contributions = sum(gate.contributions) contributions = [ torch.mul(gate_contributions, source_contribution, **kwargs) for source_contribution in source.contributions ] return contributions