Recursively unwrap fully sharded model in `convert_to_singleton.py`

Merged (!72): Administrator requested to merge `github/fork/thomasw21/convert_to_singleton` into `convert_to_singleton` on May 09, 2022.

Created by: thomasw21

Patch Description

  • Related to: #60

`convert_to_singleton.py` doesn't seem to recursively unwrap fully sharded modules, leaving the following parameters in `restored.pt` (notice the `flat_param` substring in each layer key):

```python
{
    'decoder.embed_positions.weight': torch.Size([2050, 2048]),
    'decoder.embed_tokens.weight': torch.Size([50272, 2048]),
    'decoder.layer_norm.bias': torch.Size([2048]),
    'decoder.layer_norm.weight': torch.Size([2048]),
    'decoder.layers.0.flat_param_0': torch.Size([25185280]),
    'decoder.layers.1.flat_param_0': torch.Size([25185280]),
    'decoder.layers.10.flat_param_0': torch.Size([25185280]),
    'decoder.layers.11.flat_param_0': torch.Size([25185280]),
    'decoder.layers.12.flat_param_0': torch.Size([25185280]),
    'decoder.layers.13.flat_param_0': torch.Size([25185280]),
    'decoder.layers.14.flat_param_0': torch.Size([25185280]),
    'decoder.layers.15.flat_param_0': torch.Size([25185280]),
    'decoder.layers.16.flat_param_0': torch.Size([25185280]),
    'decoder.layers.17.flat_param_0': torch.Size([25185280]),
    'decoder.layers.18.flat_param_0': torch.Size([25185280]),
    'decoder.layers.19.flat_param_0': torch.Size([25185280]),
    'decoder.layers.2.flat_param_0': torch.Size([25185280]),
    'decoder.layers.20.flat_param_0': torch.Size([25185280]),
    'decoder.layers.21.flat_param_0': torch.Size([25185280]),
    'decoder.layers.22.flat_param_0': torch.Size([25185280]),
    'decoder.layers.23.flat_param_0': torch.Size([25185280]),
    'decoder.layers.3.flat_param_0': torch.Size([25185280]),
    'decoder.layers.4.flat_param_0': torch.Size([25185280]),
    'decoder.layers.5.flat_param_0': torch.Size([25185280]),
    'decoder.layers.6.flat_param_0': torch.Size([25185280]),
    'decoder.layers.7.flat_param_0': torch.Size([25185280]),
    'decoder.layers.8.flat_param_0': torch.Size([25185280]),
    'decoder.layers.9.flat_param_0': torch.Size([25185280]),
    'decoder.version': torch.Size([1]),
}
```
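For context, the fix amounts to descending into the wrapped model so that no sharded wrapper is left around the decoder layers before the state dict is exported. Below is a minimal, illustrative sketch of that recursive unwrapping; `ShardedWrapper` and `unwrap_recursively` are made-up names for this example only, not the metaseq/fairscale classes this MR actually touches.

```python
import torch.nn as nn

# Illustrative only: `ShardedWrapper` stands in for an FSDP-style wrapper that
# exposes its inner module via `.module`; the real wrapper in metaseq is
# fairscale's fully sharded machinery (assumption for this sketch).
class ShardedWrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def unwrap_recursively(module: nn.Module, wrapper_cls=ShardedWrapper) -> nn.Module:
    # Peel off any wrappers at this node (they can be nested).
    while isinstance(module, wrapper_cls):
        module = module.module
    # Recurse into every child and re-attach the unwrapped result, so wrappers
    # buried inside e.g. decoder.layers.* are also removed.
    for name, child in module.named_children():
        setattr(module, name, unwrap_recursively(child, wrapper_cls))
    return module
```

With every layer unwrapped this way, `model.state_dict()` should expose the individual per-parameter keys instead of the flattened `decoder.layers.*.flat_param_0` entries shown above.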

Testing steps

Tested on the 1B3 checkpoint; the keys in `restored.pt` now correspond to their unwrapped versions. Haven't tested logits/generation, as the model should already be loaded from the checkpoint beforehand.
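For reference, a small check along these lines can confirm that no flattened parameters are left in the exported file. It assumes `restored.pt` holds the state dict of tensors directly; if the weights are nested under a `"model"` key, index into that first.

```python
import torch

# Load the converted singleton checkpoint and look for leftover flattened keys.
state_dict = torch.load("restored.pt", map_location="cpu")
flat_keys = [k for k in state_dict if "flat_param" in k]
assert not flat_keys, f"still-flattened parameters remain: {flat_keys}"

# Print the per-parameter shapes, which should match the unwrapped module names.
print({k: tuple(v.shape) for k, v in state_dict.items()})
```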

cc @stephenroller
