Training AI models at a large scale isn’t easy. Aside from the need for large amounts of computing power and resources, there is also considerable engineering complexity behind training very large models. At Facebook AI Research (FAIR) Engineering, we have been working on building tools and infrastructure to make training large AI models easier. Our recent work in areas such as intra-layer model parallelism, pipeline model parallelism, optimizer state+gradient sharding, and mixture of experts is just part of our work to make training advanced AI models for any number of tasks more efficient.
Fully Sharded Data Parallel (FSDP) is the newest tool we’re introducing. It shards an AI model’s parameters across data parallel workers and can optionally offload part of the training computation to the CPUs. As its name suggests, FSDP is a type of data-parallel training algorithm. Although the parameters are sharded to different GPUs, the computation for each microbatch of data is still local to each GPU worker. This conceptual simplicity makes FSDP easier to understand and more applicable to a wide range of usage scenarios (compared with intra-layer parallelism and pipeline parallelism). Compared with optimizer state+gradient sharding data parallel methods, FSDP shards parameters more uniformly and is capable of better performance via communication and computation overlapping during training.
With FSDP, it is now possible to more efficiently train models that are orders of magnitude larger using fewer GPUs. FSDP has been implemented in the FairScale library and allows engineers and developers to scale and optimize the training of their models with simple APIs. At Facebook, FSDP has already been integrated and tested for training some of our NLP and Vision models.
NLP research is one particular area where we can see the importance of efficiently leveraging compute for training AI. Last year, OpenAI announced that they had trained GPT-3, the largest-ever neural language model, with 175 billion parameters. It is estimated to have taken roughly 355 GPU years to train GPT-3, or the equivalent of 1,000 GPUs working continuously for more than four months.
Besides requiring a lot of compute and engineering resources, most approaches to scaling like this introduce additional communication costs and require engineers to carefully evaluate trade-offs between memory use and computational efficiency. For example, typical data parallel training requires maintaining redundant copies of the model on each GPU, and model parallel training introduces additional communication costs to move activations between workers (GPUs).
FSDP is relatively free of trade-offs in comparison. It improves memory efficiency by sharding model parameters, gradients, and optimizer states across GPUs, and improves computational efficiency by decomposing the communication and overlapping it with both the forward and backward passes. FSDP produces the same results as standard distributed data parallel (DDP) training and is available in an easy-to-use interface that’s a drop-in replacement for PyTorch’s DistributedDataParallel module. Our early testing has shown that FSDP can enable scaling to trillions of parameters.
In standard DDP training, every worker processes a separate batch and the gradients are summed across workers using an all-reduce operation. While DDP has become very popular, it takes more GPU memory than it needs because the model weights and optimizer states are replicated across all DDP workers.
One method to reduce this replication is to apply a process called full parameter sharding, where only the subset of model parameters, gradients, and optimizer states needed for a local computation is made available. An implementation of this method, ZeRO-3, has already been popularized by Microsoft.
The key insight that unlocks full parameter sharding is that the all-reduce operation used in standard DDP can be decomposed into separate reduce-scatter and all-gather operations. We can then rearrange the reduce-scatter and all-gather so that each DDP worker needs to store only a single shard of parameters and optimizer states. The figure below illustrates standard DDP training (top) and FSDP training (bottom):
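To make the decomposition concrete, here is a minimal, purely illustrative sketch built from torch.distributed primitives. The function name sharded_all_reduce is ours, and this is not FSDP’s actual implementation; it only shows the reduce-scatter + all-gather equivalence and assumes the tensor size divides evenly across workers:

import torch
import torch.distributed as dist

def sharded_all_reduce(grad: torch.Tensor) -> torch.Tensor:
    # An all-reduce is equivalent to a reduce-scatter followed by an all-gather.
    # Between the two steps, each worker holds (and could update) only its own
    # shard, which is what FSDP exploits to save memory.
    world_size = dist.get_world_size()
    flat = grad.flatten()
    # Assumes flat.numel() is divisible by world_size for simplicity.
    shards = [s.contiguous() for s in flat.chunk(world_size)]
    my_shard = torch.empty_like(shards[0])
    dist.reduce_scatter(my_shard, shards)   # each rank receives the sum of one shard
    gathered = [torch.empty_like(my_shard) for _ in range(world_size)]
    dist.all_gather(gathered, my_shard)     # reassemble the full reduced tensor
    return torch.cat(gathered).view_as(grad)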
To maximize memory efficiency, we can discard the full weights after each layer’s forward pass, saving memory for subsequent layers. This can be implemented by applying the FSDP wrapper to every layer in the network (with reshard_after_forward=True).
FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i
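As a rough sketch of this per-layer wrapping, a toy model might look like the following (the model and layer sizes are purely illustrative, and we assume torch.distributed is already initialized):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Each layer gets its own FSDP wrapper, so its full weights are materialized
# only for the duration of that layer's computation.
layers = nn.Sequential(
    FSDP(nn.Linear(1024, 4096), reshard_after_forward=True),
    nn.ReLU(),
    FSDP(nn.Linear(4096, 1024), reshard_after_forward=True),
)
# An outer wrapper shards whatever parameters remain unwrapped.
model = FSDP(layers, reshard_after_forward=True)

FairScale also provides helpers to automate this wrapping pattern rather than applying it by hand.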
There are several ways to use FSDP in large-scale AI research. At this time, we offer four solutions to adapt to different needs.
For language models, FSDP is supported in the fairseq framework via the following new arguments:
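At the time of writing, the relevant options are roughly the following (check the fairseq documentation for the exact, current flags):

–ddp-backend=fully_sharded: enables full sharding via FSDP
–cpu-offload: offloads the optimizer state and FP32 model copy to CPU (combine with –optimizer=cpu_adam)
–no-reshard-after-forward: increases training speed for large models at the cost of extra memory

Other popular options (–fp16, –update-freq, –checkpoint-activations, –offload-activations, and so on) continue to work as normal.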
See the fairseq tutorial for instructions on using FSDP to train a 13B-parameter model on eight GPUs or on a single GPU with FSDP + CPU offloading.
For computer vision models, FSDP is supported in VISSL and tested on RegNet architectures. Layers like BatchNorm and ReLU are seamlessly handled and tested for convergence.
FSDP is enabled through VISSL’s yaml configuration. See this section of the yaml config for the options to configure FSDP within VISSL.
For easier integration with more general use cases, FSDP is supported as a beta feature by PyTorch Lightning. This tutorial contains a detailed example of how to use the FSDP plugin with PyTorch Lightning. At a high level, adding plugins='fsdp' as shown below activates it.
from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
trainer.fit(model)

trainer.test()
trainer.predict()
The main library where FSDP has been developed, and where you can find the latest updates, is FairScale. You can use FSDP directly from FairScale, as in the example below, by simply replacing DDP(my_module) with FSDP(my_module):
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
# Wrap the model with FSDP instead of DDP; the rest of the training loop is unchanged.
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
for sample, label in dataload.next_batch:
    out = sharded_module(x=sample, y=3, z=torch.Tensor())
    loss = criterion(out, label)
    loss.backward()
    optim.step()
The FSDP library in FairScale exposes low-level options for many important aspects of large-scale training. Below are a few important areas to consider when applying FSDP at full power.
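As one illustration of what these options look like, the FairScale FSDP wrapper accepts keyword arguments along the following lines (names reflect the FairScale API at the time of writing; consult the FairScale documentation for the authoritative and complete list):

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Illustrative, non-exhaustive configuration; defaults and names may change.
sharded_module = FSDP(
    my_module,
    reshard_after_forward=True,  # free full weights after each layer's forward pass
    mixed_precision=True,        # keep FP16 shards and compute in FP16
    flatten_parameters=True,     # flatten parameters into a single buffer for faster communication
)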
FSDP is open source, and early users have already tried it and contributed to it. We think it can benefit the entire research community, and we look forward to working with everyone to make it better.
FSDP is currently available directly from the FairScale library.
Thanks for sticking with us thus far. Please try FSDP in your research or production work. We would love to hear your feedback, and, as always, pull requests are welcome!