diff --git a/metaseq/models/base_decoder.py b/metaseq/models/base_decoder.py
index 4ef9e5b0793b0f72e1487053e8e1932c7445b1eb..a882a35502580856b00f714b514cdc17c943e831 100644
--- a/metaseq/models/base_decoder.py
+++ b/metaseq/models/base_decoder.py
@@ -15,7 +15,6 @@ class BaseDecoder(nn.Module):
     def __init__(self, dictionary):
         super().__init__()
         self.dictionary = dictionary
-        self.onnx_trace = False
 
     def forward(self, prev_output_tokens, **kwargs):
         """
@@ -61,13 +60,10 @@ class BaseDecoder(nn.Module):
     def get_normalized_probs_scriptable(self, logits: Tensor, log_probs: bool):
         """Get normalized probabilities (or log probs) from a net's output."""
         if log_probs:
-            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+            return utils.log_softmax(logits, dim=-1)
         else:
-            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+            return utils.softmax(logits, dim=-1)
 
     def max_positions(self):
         """Maximum input length supported by the decoder."""
         return 1e6  # an arbitrary large number
-
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
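# Sketch (not part of the patch, toy tensors): with the onnx_trace flag gone,
# get_normalized_probs_scriptable reduces to the two utils calls below. Both
# force float32 via the dtype argument, so half-precision logits still
# normalize stably and the probabilities sum to one over the vocabulary dim.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5).half()  # stand-in for a decoder's net output

log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)  # log_probs=True path
probs = F.softmax(logits, dim=-1, dtype=torch.float32)          # log_probs=False path

assert log_probs.dtype == probs.dtype == torch.float32
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))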
diff --git a/metaseq/models/base_model.py b/metaseq/models/base_model.py
index b05ac6f14fbbe100de981e274bdc880129913df7..fb2f5b84472aa09df93d8522e45194f867b1021e 100644
--- a/metaseq/models/base_model.py
+++ b/metaseq/models/base_model.py
@@ -152,21 +152,6 @@ class BaseModel(nn.Module):
         self.eval()
         self.train = train
 
-    def prepare_for_onnx_export_(self, **kwargs):
-        """Make model exportable via ONNX trace."""
-        seen = set()
-
-        def apply_prepare_for_onnx_export_(module):
-            if (
-                module != self
-                and hasattr(module, "prepare_for_onnx_export_")
-                and module not in seen
-            ):
-                seen.add(module)
-                module.prepare_for_onnx_export_(**kwargs)
-
-        self.apply(apply_prepare_for_onnx_export_)
-
     @classmethod
     def from_pretrained(
         cls,
diff --git a/metaseq/modules/learned_positional_embedding.py b/metaseq/modules/learned_positional_embedding.py
index 8b4e97536a65d1e63c41279062611d60988077a5..8aa6c31092de54073da41098ceed2f58a883886b 100644
--- a/metaseq/modules/learned_positional_embedding.py
+++ b/metaseq/modules/learned_positional_embedding.py
@@ -22,7 +22,6 @@ class LearnedPositionalEmbedding(nn.Embedding):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
-        self.onnx_trace = False
         if self.padding_idx is not None:
             self.max_positions = self.num_embeddings - self.padding_idx - 1
         else:
@@ -41,11 +40,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
 
         # we cannot use incremental state here because we must be aware of
         # padding.
+        if positions is None and self.padding_idx is not None:
+            positions = utils.make_positions(input, self.padding_idx)
 
-        if positions is None:
-            positions = utils.make_positions(
-                input, self.padding_idx, onnx_trace=self.onnx_trace
-            )
         return F.embedding(
             positions,
             self.weight,
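# Illustration (not part of the patch) of the utils.make_positions contract the
# simplified forward relies on: positions start at padding_idx + 1 and padding
# slots keep padding_idx. The two cumsum lines below are an inline stand-in for
# make_positions, written here only for demonstration.
import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, 1, 1],      # 1 == pad
                       [1, 8, 9, 10, 11]])

mask = tokens.ne(padding_idx).long()
positions = torch.cumsum(mask, dim=1) * mask + padding_idx
# tensor([[2, 3, 4, 1, 1],
#         [1, 2, 3, 4, 5]])  -- these rows then index the embedding table
print(positions)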
diff --git a/metaseq/modules/multihead_attention.py b/metaseq/modules/multihead_attention.py
index cd66613ad78c82632367e4cbbaa08bb90d06c3ff..4075de83314bdf8ba3da0ffd605cd9c45b04554b 100644
--- a/metaseq/modules/multihead_attention.py
+++ b/metaseq/modules/multihead_attention.py
@@ -98,14 +98,8 @@ class MultiheadAttention(nn.Module):
             self.bias_k = self.bias_v = None
 
         self.add_zero_attn = add_zero_attn
-
         self.reset_parameters()
 
-        self.onnx_trace = False
-
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     def reset_parameters(self):
         def _init_method_bias(weight, bias):
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
@@ -179,8 +173,7 @@ class MultiheadAttention(nn.Module):
                 assert src_len, bsz == value.shape[:2]
 
         if (
-            not self.onnx_trace
-            and incremental_state is None
+            incremental_state is None
             and not static_kv
             # A workaround for quantization to work. Otherwise JIT compilation
             # treats bias in linear module as method.
@@ -346,10 +339,7 @@ class MultiheadAttention(nn.Module):
             # Replace any non-finite values with finite equivalents, since otherwise
             # we may get NaN when adding attn_mask or computing softmax.
             attn_weights = torch.nan_to_num(attn_weights)
-
             attn_mask = attn_mask.unsqueeze(0)
-            if self.onnx_trace:
-                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
             attn_weights += attn_mask
 
         if key_padding_mask is not None:
@@ -364,21 +354,14 @@ class MultiheadAttention(nn.Module):
         if before_softmax:
             return attn_weights, v
 
-        attn_weights_float = utils.softmax(
-            attn_weights, dim=-1, onnx_trace=self.onnx_trace
-        )
+        attn_weights_float = utils.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights_float.type_as(attn_weights)
         attn_probs = self.dropout_module(attn_weights)
 
         assert v is not None
         attn = torch.bmm(attn_probs, v)
         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        if self.onnx_trace and attn.size(1) == 1:
-            # when ONNX tracing a single decoder step (sequence length == 1)
-            # the transpose is a no-op copy before view, thus unnecessary
-            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
-        else:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
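# Shape-only sketch (toy sizes, not part of the patch) of the single remaining
# reshape path: attention output goes from (bsz * num_heads, tgt_len, head_dim)
# back to (tgt_len, bsz, embed_dim). tgt_len == 1 mirrors the single decoder
# step that the deleted onnx_trace branch used to special-case.
import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 1, 8
embed_dim = num_heads * head_dim

attn = torch.randn(bsz * num_heads, tgt_len, head_dim)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
assert attn.shape == (tgt_len, bsz, embed_dim)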
diff --git a/metaseq/modules/sinusoidal_positional_embedding.py b/metaseq/modules/sinusoidal_positional_embedding.py
index dce84c3d7e0f377bb6c51d5312cf6b91ac6eada8..a33631e7295ec1ff39690d8650b9dbf736c365e1 100644
--- a/metaseq/modules/sinusoidal_positional_embedding.py
+++ b/metaseq/modules/sinusoidal_positional_embedding.py
@@ -26,13 +26,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
         self.weights = SinusoidalPositionalEmbedding.get_embedding(
             init_size, embedding_dim, padding_idx
         )
-        self.onnx_trace = False
         self.register_buffer("_float_tensor", torch.FloatTensor(1))
         self.max_positions = int(1e5)
 
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     @staticmethod
     def get_embedding(
         num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
@@ -79,26 +75,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
-            if self.onnx_trace:
-                return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
-                    .unsqueeze(1)
-                    .repeat(bsz, 1, 1)
-                )
             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
-        positions = utils.make_positions(
-            input, self.padding_idx, onnx_trace=self.onnx_trace
-        )
-        if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
-            embedding_shape = torch.cat(
-                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
-            )
-            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
-                flat_embeddings, embedding_shape
-            )
-            return embeddings
+        positions = utils.make_positions(input, self.padding_idx)
         return (
             self.weights.index_select(0, positions.view(-1))
             .view(bsz, seq_len, -1)
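# Minimal sketch (toy tensors, not part of the patch) of the one lookup path the
# forward keeps: gather rows of the sinusoidal table by the flattened positions
# and reshape back to (bsz, seq_len, embed_dim); it matches plain indexing.
import torch

bsz, seq_len, embed_dim = 2, 3, 4
weights = torch.randn(10, embed_dim)      # stands in for self.weights
positions = torch.tensor([[2, 3, 4],      # stands in for utils.make_positions(...)
                          [1, 2, 3]])

emb = weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
assert emb.shape == (bsz, seq_len, embed_dim)
assert torch.equal(emb, weights[positions])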
diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py
index 343bda0dbf4dd4c721f9923f845a0d40bbef2b2a..9b170fb80b8ac2a75e5f4b5cf55cab2a08c61185 100644
--- a/metaseq/modules/transformer_decoder_layer.py
+++ b/metaseq/modules/transformer_decoder_layer.py
@@ -106,8 +106,6 @@ class TransformerDecoderLayer(nn.Module):
 
         self.final_layer_norm = LayerNorm(self.embed_dim, elementwise_affine=affine_ln)
         self.final_layer_norm.to(device).to(dtype)
-
-        self.onnx_trace = False
         self.args = args
 
     # Refer to model_parallel's transformer layer for why fc1 and fc2 are separate methods.
@@ -167,9 +165,6 @@ class TransformerDecoderLayer(nn.Module):
             dtype=utils.get_model_init_dtype(args),
         )
 
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
     def residual_connection(self, x, residual):
         return residual + x
 
@@ -234,18 +229,6 @@ class TransformerDecoderLayer(nn.Module):
         )
         l_aux = None
         x = self.residual_connection(x, residual)
-        if self.onnx_trace and incremental_state is not None:
-            saved_state = self.self_attn._get_input_buffer(incremental_state)
-            assert saved_state is not None
-            if self_attn_padding_mask is not None:
-                self_attn_state = [
-                    saved_state["prev_key"],
-                    saved_state["prev_value"],
-                    saved_state["prev_key_padding_mask"],
-                ]
-            else:
-                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
-            return x, attn, self_attn_state
         return x, attn, None, l_aux
 
     def make_generation_fast_(self, **kwargs):
diff --git a/metaseq/utils.py b/metaseq/utils.py
index e7d04f45bb11b5896ff661399440230c13d6d672..a94ca46289fd6f052ff83adf165419143c394034 100644
--- a/metaseq/utils.py
+++ b/metaseq/utils.py
@@ -194,7 +194,7 @@ def post_process_prediction(
     return hypo_tokens, hypo_str, alignment
 
 
-def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
+def make_positions(tensor, padding_idx: int):
     """Replace non-padding symbols with their position numbers.
 
     Position numbers begin at padding_idx+1. Padding symbols are ignored.
@@ -444,18 +444,12 @@ def import_user_module(args):
                 )
 
 
-def softmax(x, dim: int, onnx_trace: bool = False):
-    if onnx_trace:
-        return F.softmax(x.float(), dim=dim)
-    else:
-        return F.softmax(x, dim=dim, dtype=torch.float32)
+def softmax(x, dim: int):
+    return F.softmax(x, dim=dim, dtype=torch.float32)
 
 
-def log_softmax(x, dim: int, onnx_trace: bool = False):
-    if onnx_trace:
-        return F.log_softmax(x.float(), dim=dim)
-    else:
-        return F.log_softmax(x, dim=dim, dtype=torch.float32)
+def log_softmax(x, dim: int):
+    return F.log_softmax(x, dim=dim, dtype=torch.float32)
 
 
 def get_perplexity(loss, round=2, base=2):
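# Quick check (not part of the patch) that the collapsed helpers match what the
# deleted onnx_trace branches computed: passing dtype=torch.float32 upcasts the
# input before the op, just like calling .float() explicitly did.
import torch
import torch.nn.functional as F

x = torch.randn(3, 7).half()

assert torch.allclose(F.softmax(x, dim=-1, dtype=torch.float32),
                      F.softmax(x.float(), dim=-1))
assert torch.allclose(F.log_softmax(x, dim=-1, dtype=torch.float32),
                      F.log_softmax(x.float(), dim=-1))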