Breaking down GPU VRAM consumption

26 Dec 2023

Highlight: Check out my GPU VRAM Calculator

Highlight: Check out the HN discussion

I’ve always been curious about the GPU VRAM required for training and fine-tuning transformer-based language models. What factors influence VRAM consumption? How does it vary with different model settings? I dug into the topic and conducted my own measurements.

Other great resources on this topic include Stas Bekman’s section in his ML Engineering book, which was the core inspiration for Hugging Face’s model memory anatomy article. Also, check out Eleuther’s blog, which covers compute costs as well.

Quick note: This post doesn’t dig into the memory usage of quantized models or of PEFT fine-tuning techniques like LoRA or QLoRA.

Prerequisites for experiments

When we talk about RAM, we often use GB (10**9 bytes) and GiB (2**30 bytes) interchangeably, but in reality we’re dealing with GiB. Take the Nvidia RTX 3090’s “24 GB” of VRAM: it’s actually 24 GiB, or about 25.77 GB. To keep things clear, I’ll stick with MiB and GiB.
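A quick sanity check of the unit conversion, as a minimal plain-Python sketch (the helper name `gib_to_gb` is just for illustration):

```python
GB = 10**9   # decimal gigabyte
GiB = 2**30  # binary gibibyte
MiB = 2**20  # binary mebibyte

def gib_to_gb(gib: float) -> float:
    """Convert gibibytes to decimal gigabytes."""
    return gib * GiB / GB

print(f"{gib_to_gb(24):.2f} GB")  # 25.77 GB
print(24 * GiB // MiB, "MiB")     # 24576 MiB
```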

To measure VRAM usage accurately, we need to delete the variable, run garbage collection, clear the CUDA cache, and then measure the VRAM difference. Here’s an example:

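import gc
import torch
def get_vram():  # one possible definition: device-wide used VRAM in MiB
    free, total = torch.cuda.mem_get_info()
    return (total - free) / 2**20
x = torch.empty(4, 8192, 32000, device="cuda")  # float32, 4 bytes per element
total_vram = get_vram()
del x; gc.collect(); torch.cuda.empty_cache()
x_vram = total_vram - get_vram()
# 4000 MiB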

The ipyexperiments Python package automates this after each cell execution, which is pretty convenient.

Before assessing memory usage, it’s important to perform warm-up steps (essentially running the same code twice) to load CUDA kernels that weren’t loaded during the initial setup. Also, we should disable the decoder’s key/value cache (use_cache=False), which is used during inference to avoid recomputing the keys and values of previous tokens [1].

Mixed precision training

Understanding mixed precision training is key, as it’s commonly used in pretraining and fine-tuning. Normally, model parameters are stored in float32 format, taking up 4 bytes per parameter. Mixed precision training performs most computations in float16, which speeds them up on GPUs with fast half-precision arithmetic and reduces the size of activations.

But why “mixed”? The training isn’t entirely in half precision. Lower precision can lead to imprecise weight updates or even gradients turning to zero. So, in mixed precision training, the master copy of the weights is kept and updated in fp32, and before each forward pass, these weights are copied into fp16 format.
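A small numerical sketch of both failure modes, using NumPy’s IEEE half-precision type as a stand-in for GPU fp16 (the gradient and update values are illustrative assumptions): a tiny gradient rounds to zero, and repeated small updates vanish when applied directly to an fp16 weight but accumulate in an fp32 master copy.

```python
import numpy as np

# A tiny gradient underflows to zero in fp16 (smallest positive fp16 is ~6e-8)
print(np.float16(1e-8))  # 0.0

# The gap between 1.0 and the next fp16 value is 2**-10 (~0.001),
# so adding an update of 1e-4 to an fp16 weight of 1.0 rounds back to 1.0
update = 1e-4
w_fp16 = np.float16(1.0)   # weight updated directly in half precision
w_fp32 = np.float32(1.0)   # fp32 "master" copy of the same weight
for _ in range(10):
    w_fp16 = np.float16(w_fp16 + np.float16(update))
    w_fp32 = np.float32(w_fp32 + np.float32(update))

# Before the forward pass, the fp32 master weight is copied to fp16
w_forward = w_fp32.astype(np.float16)
```

After ten steps the fp16-only weight is still exactly 1.0, while the master copy has moved to about 1.001, and its fp16 copy used in the forward pass has moved off 1.0 as well.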

For a deeper dive into mixed precision, check out this fast.ai documentation, which includes a detailed illustration, and Aleksey Bilogur’s blog, which offers practical PyTorch code examples.

Handling multi-GPU scenarios

What if a model doesn’t fit on a single GPU? There are two scenarios:

  1. Inference: Use model parallelism to distribute layers across GPUs. This is done automatically in transformers with device_map="auto". Learn more in the accelerate docs.
  2. Training: Distribute layers, optimizer states, and gradients across GPUs. Depending on your setup, you might use different DeepSpeed ZeRO stages or FSDP [2] for full sharding. The more you shard, the slower training will be because of the extra communication between GPUs. For a comparison of multi-GPU training approaches, check out Hugging Face’s documentation.
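To get a feel for how much sharding helps, here is a rough back-of-the-envelope sketch. It reuses the per-parameter byte counts discussed in the sections below (mixed precision with Adam: 6 bytes of weights, 4 bytes of gradients, 8 bytes of optimizer states) and assumes the usual description of the stages: ZeRO-1 shards optimizer states, ZeRO-2 also shards gradients, ZeRO-3 (similar to FSDP full sharding) also shards weights. It is a simplification that ignores activations, outputs, CUDA kernels, and communication buffers, and the function name is hypothetical.

```python
def per_gpu_static_gib(n_params: float, n_gpus: int, stage: int) -> float:
    """Rough per-GPU memory (GiB) for weights, gradients and Adam states.

    stage 0: nothing sharded (plain data parallelism)
    stage 1: optimizer states sharded
    stage 2: optimizer states + gradients sharded
    stage 3: optimizer states + gradients + weights sharded
    """
    weights, grads, optim = 6.0, 4.0, 8.0  # bytes per parameter
    if stage >= 1:
        optim /= n_gpus
    if stage >= 2:
        grads /= n_gpus
    if stage >= 3:
        weights /= n_gpus
    return n_params * (weights + grads + optim) / 2**30

# Example: a 7B-parameter model on 8 GPUs
for stage in range(4):
    print(f"stage {stage}: {per_gpu_static_gib(7e9, 8, stage):.1f} GiB per GPU")
```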

Breaking down the components

Memory consumption consists of the following components:

Component | Train | Inference
---|---|---
CUDA Kernels | ✓ | ✓
Parameters | ✓ | ✓
Activations | ✓ | ✓
Gradients | ✓ |
Optimizer States | ✓ |
Outputs | ✓ | ✓

An interesting aspect of PyTorch is its approach to memory allocation. Essentially, PyTorch rarely releases memory once it’s been allocated. For instance, during the forward pass, activations are calculated and stored in memory. Even after these activations are no longer needed following the backward pass, the memory they occupy isn’t released. This strategy is adopted to avoid the overhead associated with frequent memory allocation calls [3].

CUDA Kernels

Upon first using the GPU, CUDA kernels will allocate between 300 MiB and 2000 MiB. This varies with the GPU, driver, and PyTorch version. It can be measured by creating any small tensor on the GPU and checking device-wide VRAM usage (e.g. with nvidia-smi) before and after:

import torch
x = torch.ones(1).cuda()  # the first CUDA call loads the kernels

Parameters

When measuring the amount of memory used by parameters, it is important to understand the difference between parameters and buffers. Parameters are the actual weights that are trained and updated by the optimizer; they can be retrieved by calling model.parameters(). Apart from parameters, a model may hold fixed tensors that are needed in some computations but are never updated. These are called buffers and can be retrieved by calling model.buffers(). One example of a buffer is precomputed positional encodings [4]. So, in this section, by ‘parameters’ I mean ‘parameters’ + ‘buffers’.
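A minimal sketch of counting both, using a hypothetical `TinyBlock` module with one trainable layer and one registered buffer (sizes chosen arbitrarily):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical module: one trainable layer plus one fixed buffer."""
    def __init__(self, dim: int = 16, max_len: int = 32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # parameters: dim*dim weights + dim biases
        # Fixed, non-trainable tensor (e.g. precomputed positional encodings)
        self.register_buffer("pos", torch.zeros(max_len, dim))

def tensor_bytes(tensors) -> int:
    """Total size in bytes of an iterable of tensors."""
    return sum(t.numel() * t.element_size() for t in tensors)

model = TinyBlock()
param_bytes = tensor_bytes(model.parameters())   # (16*16 + 16) * 4 bytes
buffer_bytes = tensor_bytes(model.buffers())     # 32*16 * 4 bytes
print(param_bytes, buffer_bytes, param_bytes + buffer_bytes)  # 1088 2048 3136
```

The buffer is saved in the state dict and moved with .to()/.cuda() like a parameter, but the optimizer never sees it.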

During inference, the memory needed for parameters is straightforward: it’s just the number of parameters multiplied by the number of bytes per parameter. You specify the number of bytes per parameter when loading a model, e.g. .from_pretrained(..., torch_dtype=torch.float16). For instance, a 7B-parameter model like Mistral, loaded in half precision (float16), takes 7.51 × 10**9 × 2 bytes, or about 14324 MiB.

When training in full precision, each parameter occupies 4 bytes. Mixed precision training is more common, though. In that case, we have to keep both the half-precision weights (for the forward pass, 2 bytes per parameter) and the full-precision weights (for applying updates, 4 bytes per parameter), so in total it takes 6 bytes per parameter.
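The same arithmetic as a small sketch, using the 7.51 × 10**9 parameter count from the Mistral example (the dictionary keys and function name are illustrative):

```python
BYTES_PER_PARAM = {
    "fp16 inference": 2,            # half-precision weights only
    "fp32 training": 4,             # full-precision weights
    "mixed-precision training": 6,  # fp16 copy (2) + fp32 master weights (4)
}

def param_memory_mib(n_params: float, mode: str) -> float:
    """Memory for the model weights alone, in MiB."""
    return n_params * BYTES_PER_PARAM[mode] / 2**20

for mode in BYTES_PER_PARAM:
    print(f"{mode}: {param_memory_mib(7.51e9, mode):,.0f} MiB")
```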

Activations

‘Activations’ are the intermediate outputs needed for backpropagation. They are usually the memory bottleneck in transformer training, especially since part of them scales quadratically with sequence length: we have to store the output of softmax(Q×K.T), which has shape Batch Size × Number of Attention Heads × Sequence Length ** 2. Section 4.1 of the “Reducing Activation Recomputation in Large Transformer Models” paper gives good estimates of activation size per layer, although the exact activations differ from model to model. For example, the paper also counts dropout masks, whereas newer architectures like Llama don’t use dropout at all.

During training, we store all layer activations for backprop, but in inference, we only keep the current (single) layer’s activations.
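To see the quadratic term concretely, here is a sketch that sizes only the softmax(Q×K.T) tensor, one of several activations per layer. It assumes the attention matrix is actually materialized in fp16 (fused kernels such as FlashAttention avoid storing it), and the example values (batch 1, 32 heads, 32 layers) are assumptions roughly matching a 7B-class model.

```python
def attention_scores_mib(batch: int, heads: int, seq_len: int,
                         bytes_per_elem: int = 2) -> float:
    """Size of one layer's softmax(Q @ K.T) output: batch x heads x seq_len**2."""
    return batch * heads * seq_len**2 * bytes_per_elem / 2**20

def attention_scores_total_mib(batch: int, heads: int, seq_len: int, n_layers: int,
                               training: bool = True, bytes_per_elem: int = 2) -> float:
    per_layer = attention_scores_mib(batch, heads, seq_len, bytes_per_elem)
    # Training keeps every layer's activations for backprop;
    # inference only needs the current layer's
    return per_layer * (n_layers if training else 1)

print(attention_scores_mib(1, 32, 4096))                        # 1024.0 MiB per layer
print(attention_scores_mib(1, 32, 8192))                        # 4096.0: 2x sequence, 4x memory
print(attention_scores_total_mib(1, 32, 4096, n_layers=32))     # 32768.0 MiB for training
```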

We can reduce the size of activations during training at the cost of training speed (a slowdown of around 20%) by discarding activations during the forward pass and recomputing them when they are needed during the backward pass. This is called gradient checkpointing.
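A minimal sketch with PyTorch’s torch.utils.checkpoint.checkpoint on a small, hypothetical MLP block (in transformers, model.gradient_checkpointing_enable() applies the same idea per layer). Only the block’s input is saved; its internal activations are recomputed during backward, so outputs and gradients match the regular run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# Regular forward: the Linear/GELU outputs are kept in memory for backward
y_plain = block(x)
y_plain.sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed forward: intermediate activations are dropped and the block
# is re-run during backward to recompute them
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(y_plain, y_ckpt), torch.allclose(grad_plain, grad_ckpt))
```

The extra forward computation inside backward is where the slowdown comes from.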

Gradients

In the mixed precision setup described above, gradients are stored in full precision, taking 4 bytes per parameter.

Optimizer states

Optimizers like Adam and SGD have their own memory needs. SGD with momentum and Adam both store a moving average of gradients for each parameter in full precision. Additionally, Adam keeps a moving average of squared gradients.

Optimizer | First Moments | Second Moments | Bytes per Param
---|---|---|---
SGD | | | 0
SGD with momentum | ✓ | | 4
Adam | ✓ | ✓ | 8
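These numbers can be checked empirically. The sketch below takes one optimizer step on a small nn.Linear layer and sums the per-parameter state tensors; it skips 0-dimensional bookkeeping such as Adam’s step counter. The helper name is hypothetical.

```python
import torch
import torch.nn as nn

def optimizer_state_bytes_per_param(make_optimizer) -> float:
    """Run one step on a small model and measure optimizer state per parameter."""
    torch.manual_seed(0)
    model = nn.Linear(128, 128)
    opt = make_optimizer(model.parameters())
    model(torch.randn(4, 128)).sum().backward()
    opt.step()  # PyTorch optimizers create their state lazily on the first step
    state_bytes = sum(
        t.numel() * t.element_size()
        for state in opt.state.values()
        for t in state.values()
        if torch.is_tensor(t) and t.dim() > 0  # skip scalars like Adam's step
    )
    n_params = sum(p.numel() for p in model.parameters())
    return state_bytes / n_params

sgd = optimizer_state_bytes_per_param(lambda p: torch.optim.SGD(p, lr=0.1))
sgd_momentum = optimizer_state_bytes_per_param(
    lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.9))
adam = optimizer_state_bytes_per_param(lambda p: torch.optim.Adam(p, lr=1e-3))
print(sgd, sgd_momentum, adam)
```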

Outputs

Finally, the output tensors (Batch Size × Sequence Length × Vocabulary Size) are almost always in float32. This holds even if the model was loaded in lower precision, because the model itself usually casts its outputs to float32 [5] [6].

While training, we also need to store the probabilities F.softmax(logits, dim=-1), which are the same size as the output tensor.
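As a worked example, the 4 × 8192 × 32000 float32 tensor from the measurement snippet at the top has exactly the shape of the logits for batch size 4, sequence length 8192, and a 32,000-token vocabulary. A small sketch (the function name is illustrative):

```python
def output_memory_mib(batch: int, seq_len: int, vocab_size: int,
                      training: bool = True) -> float:
    """float32 logits, plus softmax probabilities of the same shape when training."""
    logits = batch * seq_len * vocab_size * 4  # 4 bytes per float32 element
    probs = logits if training else 0          # same shape and dtype as the logits
    return (logits + probs) / 2**20

print(output_memory_mib(4, 8192, 32000, training=False))  # 4000.0
print(output_memory_mib(4, 8192, 32000, training=True))   # 8000.0
```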

Problems

In my VRAM measurements in the notebook, there is a persistent mismatch between what the experiments show and the calculated figures, particularly for the size of activations during the training forward pass. So there is still something to figure out!

Acknowledgements

Thanks to Stas Bekman for helping me shape my understanding, and to Quentin Anthony for his Python gist for VRAM calculation.


  1. What is the purpose of ‘use_cache’ in decoder? (discuss.huggingface.co)

  2. Introducing PyTorch Fully Sharded Data Parallel (FSDP) API (pytorch.org/blog)

  3. What exactly is occupying the GPU cache? (discuss.pytorch.org)

  4. What is the difference between register_buffer and register_parameter of nn.Module (discuss.pytorch.org)

  5. Llama 2 casts output tensor to float32 (github.com/facebookresearch/llama)

  6. Mistral casts output tensor to float32 (github.com/mistralai)

