diff --git a/metaseq/models/base_decoder.py b/metaseq/models/base_decoder.py
index 4ef9e5b0793b0f72e1487053e8e1932c7445b1eb..a882a35502580856b00f714b514cdc17c943e831 100644
--- a/metaseq/models/base_decoder.py
+++ b/metaseq/models/base_decoder.py
@@ -15,7 +15,6 @@ class BaseDecoder(nn.Module):
     def __init__(self, dictionary):
         super().__init__()
         self.dictionary = dictionary
-        self.onnx_trace = False
 
     def forward(self, prev_output_tokens, **kwargs):
         """
@@ -61,13 +60,10 @@ class BaseDecoder(nn.Module):
     def get_normalized_probs_scriptable(self, logits: Tensor, log_probs: bool):
         """Get normalized probabilities (or log probs) from a net's output."""
         if log_probs:
-            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+            return utils.log_softmax(logits, dim=-1)
         else:
-            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+            return utils.softmax(logits, dim=-1)
 
     def max_positions(self):
         """Maximum input length supported by the decoder."""
         return 1e6  # an arbitrary large number
-
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
diff --git a/metaseq/models/base_model.py b/metaseq/models/base_model.py
index b05ac6f14fbbe100de981e274bdc880129913df7..fb2f5b84472aa09df93d8522e45194f867b1021e 100644
--- a/metaseq/models/base_model.py
+++ b/metaseq/models/base_model.py
@@ -152,21 +152,6 @@ class BaseModel(nn.Module):
         self.eval()
         self.train = train
 
-    def prepare_for_onnx_export_(self, **kwargs):
-        """Make model exportable via ONNX trace."""
-        seen = set()
-
-        def apply_prepare_for_onnx_export_(module):
-            if (
-                module != self
-                and hasattr(module, "prepare_for_onnx_export_")
-                and module not in seen
-            ):
-                seen.add(module)
-                module.prepare_for_onnx_export_(**kwargs)
-
-        self.apply(apply_prepare_for_onnx_export_)
-
     @classmethod
     def from_pretrained(
         cls,
diff --git a/metaseq/modules/learned_positional_embedding.py b/metaseq/modules/learned_positional_embedding.py
index 8b4e97536a65d1e63c41279062611d60988077a5..8aa6c31092de54073da41098ceed2f58a883886b 100644
--- a/metaseq/modules/learned_positional_embedding.py
+++ b/metaseq/modules/learned_positional_embedding.py
@@ -22,7 +22,6 @@ class LearnedPositionalEmbedding(nn.Embedding):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
-        self.onnx_trace = False
         if self.padding_idx is not None:
             self.max_positions = self.num_embeddings - self.padding_idx - 1
         else:
@@ -41,11 +40,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
         # we cannot use incremental state here because we must be aware of
         # padding.
+        if positions is None and self.padding_idx is not None:
+            positions = utils.make_positions(input, self.padding_idx)
-        if positions is None:
-            positions = utils.make_positions(
-                input, self.padding_idx, onnx_trace=self.onnx_trace
-            )
         return F.embedding(
             positions,
             self.weight,
diff --git a/metaseq/modules/multihead_attention.py b/metaseq/modules/multihead_attention.py
index cd66613ad78c82632367e4cbbaa08bb90d06c3ff..4075de83314bdf8ba3da0ffd605cd9c45b04554b 100644
--- a/metaseq/modules/multihead_attention.py
+++ b/metaseq/modules/multihead_attention.py
@@ -98,14 +98,8 @@ class MultiheadAttention(nn.Module):
             self.bias_k = self.bias_v = None
 
         self.add_zero_attn = add_zero_attn
-
         self.reset_parameters()
 
-        self.onnx_trace = False
-
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     def reset_parameters(self):
         def _init_method_bias(weight, bias):
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
@@ -179,8 +173,7 @@
                 assert src_len, bsz == value.shape[:2]
 
         if (
-            not self.onnx_trace
-            and incremental_state is None
+            incremental_state is None
             and not static_kv
             # A workaround for quantization to work. Otherwise JIT compilation
             # treats bias in linear module as method.
@@ -346,10 +339,7 @@
             # Replace any non-finite values with finite equivalents, since otherwise
             # we may get NaN when adding attn_mask or computing softmax.
             attn_weights = torch.nan_to_num(attn_weights)
-
             attn_mask = attn_mask.unsqueeze(0)
-            if self.onnx_trace:
-                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
             attn_weights += attn_mask
 
         if key_padding_mask is not None:
@@ -364,21 +354,14 @@
         if before_softmax:
             return attn_weights, v
 
-        attn_weights_float = utils.softmax(
-            attn_weights, dim=-1, onnx_trace=self.onnx_trace
-        )
+        attn_weights_float = utils.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights_float.type_as(attn_weights)
         attn_probs = self.dropout_module(attn_weights)
 
         assert v is not None
         attn = torch.bmm(attn_probs, v)
         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        if self.onnx_trace and attn.size(1) == 1:
-            # when ONNX tracing a single decoder step (sequence length == 1)
-            # the transpose is a no-op copy before view, thus unnecessary
-            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
-        else:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
diff --git a/metaseq/modules/sinusoidal_positional_embedding.py b/metaseq/modules/sinusoidal_positional_embedding.py
index dce84c3d7e0f377bb6c51d5312cf6b91ac6eada8..a33631e7295ec1ff39690d8650b9dbf736c365e1 100644
--- a/metaseq/modules/sinusoidal_positional_embedding.py
+++ b/metaseq/modules/sinusoidal_positional_embedding.py
@@ -26,13 +26,9 @@
         self.weights = SinusoidalPositionalEmbedding.get_embedding(
             init_size, embedding_dim, padding_idx
         )
-        self.onnx_trace = False
         self.register_buffer("_float_tensor", torch.FloatTensor(1))
         self.max_positions = int(1e5)
 
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     @staticmethod
     def get_embedding(
         num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
@@ -79,26 +75,9 @@
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
-            if self.onnx_trace:
-                return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
-                    .unsqueeze(1)
-                    .repeat(bsz, 1, 1)
-                )
             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
-        positions = utils.make_positions(
-            input, self.padding_idx, onnx_trace=self.onnx_trace
-        )
-        if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
-            embedding_shape = torch.cat(
-                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
-            )
-            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
-                flat_embeddings, embedding_shape
-            )
-            return embeddings
+        positions = utils.make_positions(input, self.padding_idx)
         return (
             self.weights.index_select(0, positions.view(-1))
             .view(bsz, seq_len, -1)
diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py
index 343bda0dbf4dd4c721f9923f845a0d40bbef2b2a..9b170fb80b8ac2a75e5f4b5cf55cab2a08c61185 100644
--- a/metaseq/modules/transformer_decoder_layer.py
+++ b/metaseq/modules/transformer_decoder_layer.py
@@ -106,8 +106,6 @@
         self.final_layer_norm = LayerNorm(self.embed_dim, elementwise_affine=affine_ln)
         self.final_layer_norm.to(device).to(dtype)
-
-        self.onnx_trace = False
 
         self.args = args
 
     # Refer to model_parallel's transformer layer for why fc1 and fc2 are separate methods.
@@ -167,9 +165,6 @@
             dtype=utils.get_model_init_dtype(args),
         )
 
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     def residual_connection(self, x, residual):
         return residual + x
 
@@ -234,18 +229,6 @@
             )
             l_aux = None
         x = self.residual_connection(x, residual)
-        if self.onnx_trace and incremental_state is not None:
-            saved_state = self.self_attn._get_input_buffer(incremental_state)
-            assert saved_state is not None
-            if self_attn_padding_mask is not None:
-                self_attn_state = [
-                    saved_state["prev_key"],
-                    saved_state["prev_value"],
-                    saved_state["prev_key_padding_mask"],
-                ]
-            else:
-                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
-            return x, attn, self_attn_state
         return x, attn, None, l_aux
 
     def make_generation_fast_(self, **kwargs):
diff --git a/metaseq/utils.py b/metaseq/utils.py
index e7d04f45bb11b5896ff661399440230c13d6d672..a94ca46289fd6f052ff83adf165419143c394034 100644
--- a/metaseq/utils.py
+++ b/metaseq/utils.py
@@ -194,7 +194,7 @@
     return hypo_tokens, hypo_str, alignment
 
 
-def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
+def make_positions(tensor, padding_idx: int):
     """Replace non-padding symbols with their position numbers.
 
     Position numbers begin at padding_idx+1. Padding symbols are ignored.
@@ -444,18 +444,12 @@
     )
 
 
-def softmax(x, dim: int, onnx_trace: bool = False):
-    if onnx_trace:
-        return F.softmax(x.float(), dim=dim)
-    else:
-        return F.softmax(x, dim=dim, dtype=torch.float32)
+def softmax(x, dim: int):
+    return F.softmax(x, dim=dim, dtype=torch.float32)
 
 
-def log_softmax(x, dim: int, onnx_trace: bool = False):
-    if onnx_trace:
-        return F.log_softmax(x.float(), dim=dim)
-    else:
-        return F.log_softmax(x, dim=dim, dtype=torch.float32)
+def log_softmax(x, dim: int):
+    return F.log_softmax(x, dim=dim, dtype=torch.float32)
 
 
 def get_perplexity(loss, round=2, base=2):
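
The net effect of the utils.softmax / utils.log_softmax change can be seen in a small standalone sketch (not part of the patch; the helper definitions below merely mirror the simplified post-patch versions): passing dtype=torch.float32 to F.softmax upcasts inside the call, so fp16 logits no longer need the separate onnx_trace branch that upcast manually with .float().

# Standalone sketch, assuming helpers equivalent to the simplified metaseq ones.
import torch
import torch.nn.functional as F


def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # dtype=torch.float32 upcasts the input before the reduction, which covers
    # what the removed onnx_trace branch did with x.float().
    return F.softmax(x, dim=dim, dtype=torch.float32)


def log_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    return F.log_softmax(x, dim=dim, dtype=torch.float32)


if __name__ == "__main__":
    logits = torch.randn(2, 5, dtype=torch.float16)  # e.g. fp16 decoder logits
    probs = softmax(logits, dim=-1)
    assert probs.dtype == torch.float32              # upcast happened inside the call
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))
    assert torch.allclose(log_softmax(logits, dim=-1).exp(), probs)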