diff --git a/metaseq/modules/multihead_attention.py b/metaseq/modules/multihead_attention.py
index cd66613ad78c82632367e4cbbaa08bb90d06c3ff..65bc1737518c9e758214407771fa9fae1e46ec4b 100644
--- a/metaseq/modules/multihead_attention.py
+++ b/metaseq/modules/multihead_attention.py
@@ -338,8 +338,6 @@ class MultiheadAttention(nn.Module):
         )
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
-
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
 
         if attn_mask is not None:
@@ -427,6 +425,3 @@ class MultiheadAttention(nn.Module):
         buffer: Dict[str, Optional[Tensor]],
     ):
         return self.set_incremental_state(incremental_state, "attn_state", buffer)
-
-    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
-        return attn_weights