Highlight: Check out my GPU VRAM Calculator
Highlight: Check out HN discussion
I’ve always been curious about the GPU VRAM required for training and fine-tuning transformer-based language models. What factors influence VRAM consumption? How does it vary with different model settings? I dug into the topic and conducted my own measurements.
Other great resources on this topic include Stas Bekman’s section from his ML Engineering book, the core inspiration for Hugging Face’s model memory anatomy article. Also, check out Eleuther’s blog which also covers compute costs.
Quick note: This post doesn’t dig into the memory usage of quantized models and PEFT fine-tuning techniques like LoRA or QLoRA.
When we talk about RAM, we often use GB (10**9 bytes) and GiB (2**30 bytes) interchangeably. But in reality, we’re dealing with GiB. Take the Nvidia 3090’s “24 GB VRAM” – it’s actually 24 GiB, or about 25.77 GB. To keep things clear, I’ll stick with MiB and GiB.
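The conversion between the two units is plain arithmetic. Here is a minimal sketch; the constant and helper names are my own and only serve as illustration:

```python
GB = 10**9    # gigabyte: decimal unit, 10**9 bytes
GiB = 2**30   # gibibyte: binary unit, 2**30 bytes
MiB = 2**20   # mebibyte: binary unit, 2**20 bytes


def gib_to_gb(gib: float) -> float:
    """Convert gibibytes to decimal gigabytes."""
    return gib * GiB / GB


def bytes_to_mib(n_bytes: float) -> float:
    """Convert a byte count to mebibytes."""
    return n_bytes / MiB


print(round(gib_to_gb(24), 2))  # 25.77
print(bytes_to_mib(GiB))        # 1024.0
```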
To measure VRAM usage accurately, we need to delete the variable, run garbage collection, clear the CUDA cache, and then measure the VRAM difference. Here’s an example:
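import gc
import torch

# get_vram() is a helper assumed to return the VRAM currently in use, in MiB
x = torch.Tensor(4, 8192, 32000).cuda()
total_vram = get_vram()
del x; gc.collect(); torch.cuda.empty_cache()
x_vram = total_vram - get_vram()  # 4000 MiB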
The ipyexperiments Python package automates this after each cell execution, which is pretty convenient.
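The 4000 MiB figure from the measurement above can be cross-checked with arithmetic alone, since torch.Tensor creates float32 elements of 4 bytes each. A small sketch (the tensor_mib helper is hypothetical, not part of PyTorch):

```python
import math


def tensor_mib(shape, bytes_per_element=4):
    """Size in MiB of a dense tensor with the given shape and element width."""
    return math.prod(shape) * bytes_per_element / 2**20


# Same shape as the tensor measured above; float32 means 4 bytes per element.
print(tensor_mib((4, 8192, 32000)))  # 4000.0
```

If the measured value and the computed value disagree, the measurement probably picked up something else, such as kernels being loaded for the first time.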
Before assessing memory usage, it’s important to perform warm-up steps, essentially running the same code twice, to load CUDA kernels that weren’t loaded during the initial setup. Also, we should disable the cache in the decoder, which is used during inference to prevent re-computation of hidden states [1].
Understanding mixed precision training is key, as it’s commonly used in pretraining and fine-tuning. Normally, model parameters are stored in float32 format, taking up 4 bytes per parameter. Mixed precision training runs most computations in float16, which can roughly halve the calculation time and reduces the size of activations.
But why “mixed”? The training isn’t entirely in half precision. Lower precision can lead to imprecise weight updates or even gradients turning to zero. So, in mixed precision training, the master copy of the weights is kept and updated in fp32, and before each forward pass, these weights are copied into fp16 format.
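Both failure modes can be reproduced without a GPU. The sketch below uses the half-precision format of Python's struct module to round values to fp16; the to_fp16 helper is my own, and a regular Python float stands in for the fp32 master copy, so treat it as an illustration rather than how a training framework does it:

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# A tiny gradient underflows to zero in half precision.
print(to_fp16(1e-8))  # 0.0

# Imprecise updates: near 1.0 the gap between neighbouring fp16 values is
# 2**-10, so an update of 1e-4 is rounded away on every single step.
w_half = 1.0
w_master = 1.0  # Python float, standing in for the fp32 master copy
for _ in range(100):
    w_half = to_fp16(w_half + 1e-4)
    w_master = w_master + 1e-4

print(w_half)               # 1.0
print(round(w_master, 4))   # 1.01
```

The half-precision weight never moves, while the higher-precision copy accumulates all 100 updates, which is why the master weights are kept in fp32.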
For a deeper dive into mixed precision, check out this fast.ai documentation, which includes a detailed illustration, and Aleksey Bilogur’s blog, which offers practical PyTorch code examples.
What if a model doesn’t fit on a single GPU? There are two scenarios:
device_map="auto". Learn more in the accelerate docs.
Memory consumption consists of the following components:
Component | Train | Inference |
---|---|---|
CUDA Kernels | ✅ | ✅ |
Parameters | ✅ | ✅ |
Activations | ✅ | ✅ |
Gradients | ✅ | ❌ |
Optimizer States | ✅ | ❌ |
Outputs | ✅ | ✅ |
An interesting aspect of PyTorch is its approach to memory allocation. Essentially, PyTorch rarely releases memory once it’s been allocated. For instance, during the forward pass, activations are calculated and stored in memory. Even after these activations are no longer needed following the backward pass, the memory they occupy isn’t released. This strategy is adopted to avoid the overhead associated with frequent memory allocation calls [3].
Upon first using the GPU, CUDA kernels will allocate between 300 MiB and 2000 MiB. This can vary based on GPU, driver, and PyTorch versions. It can be measured by initializing any small tensor and moving it to the GPU:
x = torch.ones(1).cuda()
When measuring the amount of memory used by parameters, it is important to understand the difference between parameters and buffers. Parameters are the actual weights that are trained and updated by the optimizer. They can be retrieved by calling model.parameters(). Apart from parameters, there are fixed tensors that are needed in some computations but are never updated. These are called buffers and can be retrieved by calling model.buffers(). One example of buffers is precomputed positional encodings [4]. So, in this section, ‘parameters’ means ‘parameters’ + ‘buffers’.
During inference, the memory needed for parameters is straightforward: it’s just the number of parameters multiplied by the number of bytes per parameter. You specify the number of bytes per parameter when loading a model, for example with .from_pretrained(..., torch_dtype=torch.float16). For instance, a 7B-parameter model like Mistral, when loaded in half-precision (float16), would take 7.51 × 10**9 × 2 bytes, equating to 14324 MiB.
When training in full precision, each parameter occupies 4 bytes. Mixed precision training is more common, though. In this case, we have to maintain both half precision weights (for the forward pass, 2 bytes per parameter) and full precision weights (for applying updates, 4 bytes per parameter), so in total it takes 6 bytes per parameter.
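The per-parameter byte counts above translate directly into a small calculator. This is a sketch using only the figures from this section; the mode names and the param_memory_mib helper are made up for illustration:

```python
MiB = 2**20

# Bytes per parameter for each setting described above.
BYTES_PER_PARAM = {
    "inference_fp16": 2,
    "inference_fp32": 4,
    "train_full_precision": 4,
    "train_mixed_precision": 2 + 4,  # fp16 working copy + fp32 master copy
}


def param_memory_mib(n_params: float, mode: str) -> float:
    """Memory taken by the weights alone, in MiB."""
    return n_params * BYTES_PER_PARAM[mode] / MiB


n_params = 7.51e9  # the parameter count used in the example above
print(round(param_memory_mib(n_params, "inference_fp16")))  # 14324
```

Mixed precision training therefore needs three times the weight memory of float16 inference, before gradients and optimizer states are counted.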
‘Activations’ refer to the intermediate outputs essential for backpropagation. They are usually the memory bottleneck in transformer training, especially since their size scales quadratically with sequence length (we have to store the output of softmax(Q×K.T), which has shape Batch Size × Number of Attention Heads × Sequence Length ** 2). There are good estimates of activation size per layer in section 4.1 of the paper “Reducing Activation Recomputation in Large Transformer Models”, although activations will differ for each model. For example, the paper also counts dropout masks, whereas newer architectures like Llama don’t use dropout at all.
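To get a feel for the quadratic term, here is a rough sketch that sizes only the softmax(Q×K.T) output of a single layer. The function name and the example settings (1 sequence, 32 heads, float16) are assumptions picked for illustration, not measurements of any particular model:

```python
def attention_scores_mib(batch_size, n_heads, seq_len, bytes_per_element=2):
    """Memory of one layer's softmax(Q×K.T) output, in MiB."""
    return batch_size * n_heads * seq_len**2 * bytes_per_element / 2**20


# Hypothetical settings, chosen only to show the scaling.
short = attention_scores_mib(batch_size=1, n_heads=32, seq_len=2048)
long = attention_scores_mib(batch_size=1, n_heads=32, seq_len=4096)
print(short, long)  # 256.0 1024.0
```

Doubling the sequence length quadruples this tensor, and during training one such tensor is kept per layer.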
During training, we store all layer activations for backprop, but in inference, we only keep the current (single) layer’s activations.
We can reduce the size of activations during training at the cost of training speed (a slowdown of around 20%) by discarding the activations during the forward pass and recalculating them when needed during the backward pass. This is called gradient checkpointing.
Gradients are always stored in full precision taking 4 bytes per parameter.
Optimizers like Adam and SGD have their own memory needs. SGD with momentum and Adam both store a moving average of gradients for each parameter in full precision. Additionally, Adam keeps a moving average of squared gradients.
Optimizer | First Moments | Second Moments | Bytes per Param |
---|---|---|---|
SGD | ❌ | ❌ | 0 |
SGD w momentum | ✅ | ❌ | 4 |
Adam | ✅ | ✅ | 8 |
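Putting the per-parameter figures for weights, gradients, and optimizer states together gives the part of training memory that does not depend on batch size or sequence length. A minimal sketch using only the numbers quoted in this post (the dictionary keys and the training_state_mib helper are hypothetical names):

```python
# Bytes per parameter, taken from the figures in this post.
PARAM_BYTES = {"full": 4, "mixed": 6}   # weights
GRADIENT_BYTES = 4                      # gradients, stored in full precision
OPTIMIZER_BYTES = {"sgd": 0, "sgd_momentum": 4, "adam": 8}


def training_state_mib(n_params, precision="mixed", optimizer="adam"):
    """Weights + gradients + optimizer states in MiB.

    Activations, outputs and CUDA kernels are not included.
    """
    per_param = (
        PARAM_BYTES[precision] + GRADIENT_BYTES + OPTIMIZER_BYTES[optimizer]
    )
    return n_params * per_param / 2**20


# Mixed precision with Adam: 6 + 4 + 8 = 18 bytes per parameter.
print(training_state_mib(2**20, "mixed", "adam"))  # 18.0
```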
Finally, the output tensors (Batch Size × Sequence Length × Vocabulary Size) are almost always in float32. This remains true even if the model was loaded in a lower precision, because the model itself casts outputs to float32 most of the time [5] [6].
While training, we also need to store the probabilities F.softmax(logits, dim=-1), which are the same size as the output tensor.
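The output size is easy to estimate from the shape alone. The sketch below assumes float32 logits and, when training, a second tensor of the same size for the softmax probabilities; output_mib is an illustrative helper, not a library function:

```python
def output_mib(batch_size, seq_len, vocab_size, training=False):
    """float32 logits, plus same-sized softmax probabilities when training."""
    logits = batch_size * seq_len * vocab_size * 4 / 2**20
    return 2 * logits if training else logits


print(output_mib(4, 8192, 32000))                 # 4000.0
print(output_mib(4, 8192, 32000, training=True))  # 8000.0
```

With a 32000-token vocabulary and long sequences, the outputs alone can take several GiB, so they are worth counting separately from activations.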
In my experiments measuring VRAM usage in the notebook, I see a persistent mismatch between the measured and the calculated figures, particularly for the size of activations during the forward pass in training. So there is still something to figure out!
Thanks to Stas Bekman for helping me shape my understanding, and to Quentin Anthony for his Python gist for VRAM calculation.
[1] What is the purpose of ‘use_cache’ in decoder? (discuss.huggingface.co)
[2] Introducing PyTorch Fully Sharded Data Parallel (FSDP) API (pytorch.org/blog)
[3] What exactly is occupying the GPU cache? (discuss.pytorch.org)
[4] What is the difference between register_buffer and register_parameter of nn.Module (discuss.pytorch.org)
[5] Llama 2 casts output tensor to float32 (github.com/facebookresearch/llama)
[6] Mistral casts output tensor to float32 (github.com/mistralai)