TorchTune: Fine-Tuning Made Easy
Co-Author: Dikshachandra
In today’s rapidly advancing world of Generative AI, fine-tuning Large Language Models (LLMs) has become a pivotal task for researchers and developers. However, it can often be challenging, especially when aiming for both performance and simplicity. Enter TorchTune, a PyTorch-based library designed to simplify and democratize the fine-tuning of LLMs, making it accessible across various hardware configurations without compromising on functionality or precision.
In this article, we will explore fine-tuning the Meta Llama-3-8B-Instruct model with torchtune. We’ll walk through downloading the model weights and tokenizer from Hugging Face, then set up torchtune’s QLoRA fine-tuning recipe. Finally, we’ll log our training runs to Weights & Biases (W&B) so we can track and visualize the training process in real time.
What is TorchTune?
Torchtune is a dedicated library built on PyTorch, designed specifically for fine-tuning large language models (LLMs). It streamlines the process, allowing developers and researchers to easily configure, run, and experiment with advanced LLMs. The core philosophy behind Torchtune revolves around four essential pillars:
- Simplicity and Extensibility: The library is native to PyTorch, ensuring ease of use and reusability.
- Correctness: A high standard of accuracy and reliability is maintained through rigorous testing of every component.
- Stability: Just as PyTorch is known for its robustness, Torchtune aims for consistent and reliable functionality.
- Democratization of Fine-Tuning: TorchTune enables effective fine-tuning across various hardware configurations, making advanced model optimization accessible to users with different levels of computational resources.
PyTorch: The Foundation of Torchtune
At its core, Torchtune is built on PyTorch, one of the most widely used deep learning libraries. PyTorch is popular in the ML community for its flexibility, dynamic computation graphs, and ease of debugging. This makes it a natural foundation for a library like Torchtune, which needs fine-grained control over model training while remaining compatible with existing PyTorch workflows.
Because Torchtune is written in native PyTorch, developers who already know PyTorch can read, debug, and extend its code directly. Torchtune adds a layer of abstraction that makes fine-tuning LLMs much simpler while keeping full control and transparency over the process.
Key Features of Torchtune
Torchtune is packed with features that make it a powerful tool for LLM fine-tuning. Below are some of its standout features:
1. Modular Implementations of Popular LLMs
Torchtune provides out-of-the-box implementations of many popular LLM architectures, written in native PyTorch from reusable building blocks. This modular approach means you can easily select a model, swap components, and experiment with variants.
2. Checkpoint Conversion for Model Interoperability
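The library comes with utilities for converting model checkpoints between formats, such as Hugging Face checkpoints, Meta’s original checkpoint format, and torchtune’s native format. This lets you fine-tune weights downloaded from model hubs and save the results back in the format they were loaded from, so other tools can use them.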
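At its simplest, checkpoint conversion renames the keys of a state dict from one naming scheme to another. The sketch below shows only that renaming step. It uses a small, hypothetical mapping table modeled on Hugging Face Llama key names. Torchtune’s real converters use their own, more complete tables and also reshape some attention weights, which this sketch does not attempt.

```python
import re

# Hypothetical, abbreviated mapping from Hugging Face-style Llama keys to a
# native layout. "{}" stands for the layer index.
_HF_TO_NATIVE = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}


def convert_key(hf_key: str) -> str:
    """Map one Hugging Face key to its native name (KeyError if unknown)."""
    # Replace the first numeric path segment (the layer index) with "{}" so the
    # key can be looked up as a template.
    match = re.search(r"\.(\d+)\.", hf_key)
    if match is None:
        return _HF_TO_NATIVE[hf_key]
    template = hf_key[: match.start()] + ".{}." + hf_key[match.end():]
    return _HF_TO_NATIVE[template].format(match.group(1))


def convert_state_dict(hf_state: dict) -> dict:
    """Rename every key; tensor values are passed through unchanged."""
    return {convert_key(k): v for k, v in hf_state.items()}


print(convert_key("model.layers.3.self_attn.q_proj.weight"))
```

Raising KeyError on unknown keys, rather than silently copying them, makes an incomplete mapping fail loudly instead of producing a checkpoint with missing weights.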
3. Comprehensive Training Recipes
One of Torchtune’s most powerful features is its training recipes, which are end-to-end pipelines for fine-tuning and evaluating LLMs. These recipes include best practices for specific tasks, such as full fine-tuning or parameter-efficient methods like LoRA (Low-Rank Adaptation).
4. Seamless Integration with Hugging Face Datasets
Data is the foundation of any fine-tuning task. Torchtune integrates effortlessly with Hugging Face Datasets, allowing users to quickly access and preprocess datasets for training. This integration reduces the hassle of data management, making it easier to focus on model performance.
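Before training, each dataset record is turned into token IDs plus labels. A common choice for instruction tuning is to mask the prompt tokens so the loss is computed only on the response. The sketch below makes several assumptions:

- Each record is Alpaca-style, with instruction and output fields.
- A toy whitespace tokenizer (hypothetical) stands in for a real one.
- Ignored positions are marked with -100, the default ignore_index of PyTorch’s cross-entropy loss.
- The train_on_input flag is a parameter of this sketch.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

# Alpaca-style prompt template (variant for records without an "input" field).
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def toy_tokenize(text: str, vocab: dict) -> list:
    """Hypothetical tokenizer: one ID per whitespace-separated word."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def build_example(record: dict, vocab: dict, train_on_input: bool = False):
    """Return (tokens, labels) for one instruction/response record."""
    prompt = PROMPT_TEMPLATE.format(instruction=record["instruction"])
    prompt_ids = toy_tokenize(prompt, vocab)
    response_ids = toy_tokenize(record["output"], vocab)
    tokens = prompt_ids + response_ids
    if train_on_input:
        labels = list(tokens)
    else:
        # Mask the prompt so only response tokens contribute to the loss.
        labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    return tokens, labels


vocab = {}
tokens, labels = build_example({"instruction": "Say hi", "output": "hi there"}, vocab)
print(labels[-2:] == tokens[-2:])
```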
5. Distributed Training with FSDP
For users with large models and datasets, Torchtune supports Fully Sharded Data Parallel (FSDP) training across multiple GPUs. FSDP shards the model parameters, gradients, and optimizer states across devices, so each GPU holds only a fraction of the training state. This makes it possible to fine-tune models that would not fit in a single GPU’s memory.
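A back-of-the-envelope estimate shows why sharding helps. Full fine-tuning with AdamW keeps four values per parameter: the weight, its gradient, and two optimizer moments. The sketch assumes all four are stored in bf16 (2 bytes each) and that FSDP splits them evenly across GPUs. It ignores activations, temporarily gathered layers, and framework overhead, so real usage will be higher.

```python
def sharded_state_gb(num_params: float, world_size: int,
                     bytes_per_value: int = 2, values_per_param: int = 4) -> float:
    """Approximate per-GPU memory (decimal GB) for weights, grads and AdamW state.

    values_per_param=4 counts weights, gradients and AdamW's two moment buffers.
    Activations and other overheads are deliberately ignored.
    """
    total_bytes = num_params * bytes_per_value * values_per_param
    return total_bytes / world_size / 1e9


# An 8B-parameter model: unsharded on one GPU vs. sharded across 8 GPUs.
print(sharded_state_gb(8e9, world_size=1))  # 64.0
print(sharded_state_gb(8e9, world_size=8))  # 8.0
```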
6. YAML Configuration for Easy Experiment Setup
Torchtune uses YAML files for configuring your training runs. These configurations can be used to set hyperparameters, model choices, dataset paths, and more. The use of YAML ensures that your experiment setups remain consistent and easy to share.
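Because a config is nested key/value data, you can vary a single experiment from the command line instead of editing the file. Torchtune’s tune run accepts key=value overrides after the config path. The sketch below is a hypothetical, simplified illustration of merging such dotted overrides into a nested dictionary. It is not torchtune’s implementation, and it only converts booleans, ints, and floats.

```python
def parse_value(text: str):
    """Best-effort conversion of an override string to bool, int or float."""
    if text in ("True", "true"):
        return True
    if text in ("False", "false"):
        return False
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text  # fall back to the raw string


def apply_overrides(config: dict, overrides: list) -> dict:
    """Apply "a.b.c=value" strings to a nested dict (modified in place)."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})  # create missing levels
        node[leaf] = parse_value(raw_value)
    return config


config = {"batch_size": 2, "optimizer": {"lr": 3e-4, "weight_decay": 0.01}}
apply_overrides(config, ["batch_size=8", "optimizer.lr=1e-4"])
print(config)  # {'batch_size': 8, 'optimizer': {'lr': 0.0001, 'weight_decay': 0.01}}
```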
Key Concepts to Master in Torchtune
As you dive deeper into using Torchtune, it’s important to understand the following key concepts:
- Configs: The YAML-based configuration files enable you to control various aspects of the training process, including model selection, dataset paths, checkpoints, and hyperparameters like learning rate and batch size. With configs, there’s no need to manually modify code for each experiment.
- Recipes: Recipes are pre-built training and evaluation pipelines, each implementing one fine-tuning technique such as full fine-tuning, LoRA, or DPO. They can use techniques like FSDP (in distributed recipes), gradient accumulation, activation checkpointing, and reduced-precision training. Model-specific settings live in the configs, so a single recipe can be paired with configs for many model families.
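Gradient accumulation, one of the techniques listed above, lets a recipe simulate a larger batch on limited memory. Gradients from several micro-batches are summed before a single optimizer step, so the effective batch size is the per-step batch size times the number of accumulation steps. The pure-Python sketch below uses a one-parameter linear model with mean-squared-error loss, chosen only for illustration. It checks that accumulating scaled micro-batch gradients reproduces the full-batch gradient.

```python
def grad_mse(w: float, batch: list) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: list, num_micro_batches: int) -> float:
    """Sum per-micro-batch gradients, each scaled by 1 / num_micro_batches."""
    size = len(batch) // num_micro_batches  # assumes an even split
    total = 0.0
    for i in range(num_micro_batches):
        micro = batch[i * size:(i + 1) * size]
        # Dividing by the number of micro-batches keeps the result equal to
        # the full-batch mean, as if the whole batch had been processed at once.
        total += grad_mse(w, micro) / num_micro_batches
    return total


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 7.0)]
full = grad_mse(0.5, data)
accum = accumulated_grad(0.5, data, num_micro_batches=2)
print(abs(full - accum) < 1e-12)  # True
```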
Fine-Tuning Recipes
TorchTune offers a range of fine-tuning recipes designed to cater to various hardware setups and training scenarios:
For Distributed Training across 1 to 8 GPUs, users can opt for:
- Full Fine-Tuning: Training the entire model end-to-end.
- LoRA (Low-Rank Adaptation): A parameter-efficient fine-tuning technique that freezes the pretrained weights and trains a pair of small low-rank matrices added to selected layers. This greatly reduces the number of trainable parameters, and with them the memory needed for gradients and optimizer states, while typically maintaining performance close to full fine-tuning.
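Here is a concrete look at the parameter savings. LoRA keeps a frozen weight W of shape d_out × d_in and learns B (d_out × r) and A (r × d_in). It then computes y = W x + (alpha / r) B A x, with B usually initialized to zero so training starts from the pretrained behavior.

The sketch below counts trainable parameters for one 4096 × 4096 projection at rank 8. That is the shape of the query projection in Llama 3 8B, but the layer and rank are illustrative choices. The sketch also includes a tiny pure-Python forward pass.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int):
    """Return (frozen base parameters, trainable LoRA parameters)."""
    full = d_in * d_out           # frozen base weight W
    lora = rank * (d_in + d_out)  # trainable A (rank x d_in) + B (d_out x rank)
    return full, lora


def matvec(m: list, v: list) -> list:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, alpha: float, rank: int) -> list:
    """y = W x + (alpha / rank) * B (A x); W stays frozen, A and B are trained."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * d for b, d in zip(base, delta)]


full, lora = lora_param_counts(4096, 4096, rank=8)
print(full, lora, f"{100 * lora / full:.2f}%")  # 16777216 65536 0.39%
```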
On a Single Device with limited memory (1 GPU), TorchTune provides:
- Full Fine-Tuning: Tailored for lower memory availability.
- LoRA and QLoRA: Besides plain LoRA, QLoRA (Quantized Low-Rank Adaptation) is available. QLoRA stores the frozen base model weights in 4-bit precision (the NF4 data type) and trains LoRA adapters on top, further reducing memory usage while largely preserving model quality.
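The memory savings come from storing the frozen weights in 4 bits, with one scale per small block of values. QLoRA itself uses the NF4 (4-bit NormalFloat) data type, which maps codes to a fixed table of levels. To keep it short, the sketch below substitutes plain symmetric absmax integer quantization. The block size of 64 is an illustrative choice, and the storage estimate ignores the per-block scales.

```python
def quantize_block(values: list):
    """Symmetric absmax quantization of one block to signed 4-bit integers."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 7 if absmax > 0 else 1.0  # map [-absmax, absmax] to [-7, 7]
    codes = [max(-8, min(7, round(v / scale))) for v in values]
    return codes, scale


def dequantize_block(codes: list, scale: float) -> list:
    return [c * scale for c in codes]


def quantize(weights: list, block_size: int = 64):
    """Quantize a flat list of weights block by block (one scale per block)."""
    return [quantize_block(weights[start:start + block_size])
            for start in range(0, len(weights), block_size)]


def dequantize(blocks: list) -> list:
    out = []
    for codes, scale in blocks:
        out.extend(dequantize_block(codes, scale))
    return out


# Rough weight-storage estimate for 8B parameters (decimal GB, scales ignored).
num_params = 8e9
print(num_params * 2 / 1e9, "GB in bf16")     # 16.0 GB in bf16
print(num_params * 0.5 / 1e9, "GB in 4-bit")  # 4.0 GB in 4-bit
```

Each dequantized value is within half a quantization step (scale / 2) of the original. Smaller blocks give tighter scales at the cost of storing more of them.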
For Single Device Training (1 GPU), additional options include:
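- DPO (Direct Preference Optimization): A method for aligning language models with human preferences that trains directly on pairs of preferred and rejected responses, without fitting a separate reward model or running a reinforcement learning loop.
- RLHF (Reinforcement Learning from Human Feedback) using PPO (Proximal Policy Optimization): A reward model trained on human preference data scores the model’s outputs. PPO then updates the model to increase that reward while limiting how far each update moves the policy.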
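Both objectives are compact enough to write down. The sketch below computes two quantities:

- The DPO loss for one preference pair, from summed log-probabilities of the chosen and rejected responses under the policy and under a frozen reference model.
- The per-sample clipped surrogate objective that PPO maximizes.

The values beta = 0.1 and epsilon = 0.2 are common illustrative settings, not torchtune defaults. The inputs are plain floats rather than tensors.

```python
import math


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_ratio = policy_chosen - ref_chosen
    rejected_ratio = policy_rejected - ref_rejected
    logits = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(z)) == log(1 + exp(-z)); branch on the sign to avoid overflow.
    if logits >= 0:
        return math.log1p(math.exp(-logits))
    return -logits + math.log1p(math.exp(logits))


def ppo_clipped_objective(ratio: float, advantage: float,
                          epsilon: float = 0.2) -> float:
    """min(r * A, clip(r, 1 - eps, 1 + eps) * A), the quantity PPO maximizes."""
    clipped_ratio = max(1 - epsilon, min(1 + epsilon, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# When the policy equals the reference, the DPO loss is log(2), about 0.6931.
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# A large ratio with a positive advantage is clipped to 1 + epsilon.
print(ppo_clipped_objective(1.5, 1.0))  # 1.2
```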
These recipes ensure flexibility and performance across diverse hardware environments.
Ecosystem Integration
TorchTune seamlessly integrates with popular tools and libraries, ensuring flexibility and ease of use throughout the fine-tuning process. These integrations allow users to efficiently access resources, optimize their workflows, and track progress without hassle:
- Hugging Face Hub: Provides access to a vast repository of pre-trained model weights for easy fine-tuning.
- Hugging Face Datasets: Offers a collection of training and evaluation datasets, simplifying data handling.
- PyTorch FSDP: Supports distributed training across multiple GPUs, making it scalable and efficient.
- bitsandbytes: Provides memory-efficient optimizers (such as 8-bit and paged AdamW) that reduce memory usage for single-device fine-tuning on limited hardware.
- Weights & Biases: Allows detailed logging of metrics, checkpoints, and tracking of training runs.
- Comet: An alternative platform for logging and experiment tracking.
- ExecuTorch: Enables on-device inference, making it suitable for deploying models on mobile and edge devices.
- torchao: Facilitates post-training quantization, reducing memory consumption and speeding up inference.
Supported Models
TorchTune currently supports a wide range of leading large language models (LLMs), allowing users to fine-tune models tailored to their specific tasks. As the library evolves, more models will be added to further expand its capabilities.
- Llama3: Available in 8B and 70B variants for versatile fine-tuning.
- Llama2: Comes in 7B, 13B, and 70B models, providing different sizes for various applications.
- Code-Llama2: Similar to Llama2, but tailored for code-related tasks.
- Mistral: A powerful 7B model designed for performance.
- Gemma: Offers smaller models with 2B and 7B versions, suitable for a wider range of hardware setups.
- Microsoft Phi3: Mini-sized for efficient performance.
- Qwen2: Features models ranging from 0.5B to 7B, providing flexibility in scaling.
Practical TorchTune Walkthrough
Now that we’ve covered the theory, let’s dive into the practical side. In this section, we’ll walk through four steps:
- Setting up the environment.
- Downloading the instruction-tuned Llama 3 8B model.
- Fine-tuning it with QLoRA.
- Tracking the training process in the Weights & Biases (W&B) dashboard.
Install required packages
Install or upgrade PyTorch and TorchTune, and install W&B to track training progress.
pip install torch --upgrade
pip install torchtune --upgrade
pip install wandb --upgrade
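To confirm what was installed, you can query package metadata from Python. The snippet below uses the standard library’s importlib.metadata. It assumes the distributions are named torch, torchtune, and wandb, as in the pip commands above, and reports None for anything missing.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torchtune", "wandb")) -> dict:
    """Return {package: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```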
Explore torchtune commands
Check available commands in TorchTune.
tune --help
usage: tune [-h] {download,ls,cp,run,validate} ...
Welcome to the torchtune CLI!
options:
-h, --help show this help message and exit
subcommands:
{download,ls,cp,run,validate}
download Download a model from the Hugging Face Hub.
ls List all built-in recipes and configs
cp Copy a built-in recipe or config to a local path.
run Run a recipe. For distributed recipes, this supports all torchrun
arguments.
validate Validate a config and ensure that it is well-formed.
List all built-in recipes and configs
To view the available fine-tuning recipes and their associated configs, run tune ls.
Each recipe consists of three components:
- Configurable parameters: specified through YAML configs and command-line overrides.
- Recipe script: the entry point that puts everything together, including parsing and validating configs, setting up the environment, and correctly using the recipe class.
- Recipe class: the core logic needed for training, exposed to users through a set of APIs.
tune ls
RECIPE                          CONFIG
full_finetune_single_device     llama2/7B_full_low_memory
                                code_llama2/7B_full_low_memory
                                llama3/8B_full_single_device
                                llama3_1/8B_full_single_device
                                mistral/7B_full_low_memory
                                phi3/mini_full_low_memory
full_finetune_distributed       llama2/7B_full
                                llama2/13B_full
                                llama3/8B_full
                                llama3_1/8B_full
                                llama3/70B_full
                                llama3_1/70B_full
                                mistral/7B_full
                                gemma/2B_full
                                gemma/7B_full
                                phi3/mini_full
lora_finetune_single_device     llama2/7B_lora_single_device
                                llama2/7B_qlora_single_device
                                code_llama2/7B_lora_single_device
                                code_llama2/7B_qlora_single_device
                                llama3/8B_lora_single_device
                                llama3_1/8B_lora_single_device
                                llama3/8B_qlora_single_device
                                llama3_1/8B_qlora_single_device
                                llama2/13B_qlora_single_device
                                mistral/7B_lora_single_device
                                mistral/7B_qlora_single_device
                                gemma/2B_lora_single_device
                                gemma/2B_qlora_single_device
                                gemma/7B_lora_single_device
                                gemma/7B_qlora_single_device
                                phi3/mini_lora_single_device
                                phi3/mini_qlora_single_device
lora_dpo_single_device          llama2/7B_lora_dpo_single_device
lora_dpo_distributed            llama2/7B_lora_dpo
lora_finetune_distributed       llama2/7B_lora
                                llama2/13B_lora
                                llama2/70B_lora
                                llama3/70B_lora
                                llama3_1/70B_lora
                                llama3/8B_lora
                                llama3_1/8B_lora
                                mistral/7B_lora
                                gemma/2B_lora
                                gemma/7B_lora
                                phi3/mini_lora
lora_finetune_fsdp2             llama2/7B_lora
                                llama2/13B_lora
                                llama2/70B_lora
                                llama2/7B_qlora
                                llama2/70B_qlora
generate                        generation
eleuther_eval                   eleuther_evaluation
quantize                        quantization
qat_distributed                 llama2/7B_qat_full
                                llama3/8B_qat_full
Download Llama3 model
- Create a directory for the model
Prepare a folder to store the downloaded model:
mkdir /tmp/Meta-Llama-3-8B-Instruct
- Getting access
Follow the instructions on Meta’s official Llama-3-8B-Instruct model page on Hugging Face to request access. Next, create a Hugging Face access token (Settings → Access Tokens in your Hugging Face account). Once access is granted, we can run the following command to download the model weights and tokenizer to our local system. The --ignore-patterns option skips repository files that match the given pattern, in this case the original consolidated checkpoint.
tune download meta-llama/Meta-Llama-3-8B-Instruct \
--output-dir /tmp/Meta-Llama-3-8B-Instruct \
--ignore-patterns "original/consolidated*" \
--hf-token <HF_TOKEN>
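The sketch below shows how a shell-style glob pattern selects files to skip, so you can check before downloading what a pattern would exclude. It uses Python’s fnmatch on a hypothetical, abbreviated file listing. It assumes the download honors fnmatch-style matching of repository paths and is not part of torchtune.

```python
from fnmatch import fnmatch

# Hypothetical, abbreviated listing of files in the model repository.
repo_files = [
    "config.json",
    "model-00001-of-00004.safetensors",
    "tokenizer.json",
    "original/consolidated.00.pth",
    "original/params.json",
    "original/tokenizer.model",
]


def split_by_patterns(files: list, ignore_patterns: list):
    """Return (kept, skipped); a file is skipped if any pattern matches it."""
    kept, skipped = [], []
    for path in files:
        if any(fnmatch(path, pattern) for pattern in ignore_patterns):
            skipped.append(path)
        else:
            kept.append(path)
    return kept, skipped


kept, skipped = split_by_patterns(repo_files, ["original/consolidated*"])
print(skipped)  # ['original/consolidated.00.pth']
```

Whatever the pattern excludes, make sure the checkpointer section of your config points only at files you actually downloaded.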
Configure Fine-Tuning Setup
We will fine-tune the model with a QLoRA configuration on a single device.
- Modify config
Copy the built-in QLoRA config for Llama 3 8B to a local file that you can edit.
tune cp llama3/8B_qlora_single_device ./custom_config.yaml
- Add metric logger
Add (or update) the following logging section in custom_config.yaml so that Weights & Biases logs the training metrics for you. Replace <PROJECT_NAME> with the name of your W&B project.
# Logging
output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.utils.metric_logging.WandBLogger
project: <PROJECT_NAME>
log_every_n_steps: 1
log_peak_memory_stats: True
# Environment
enable_activation_checkpointing: True
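The _component_ key names the Python object that the config entry should build, by its dotted import path. The remaining keys are passed to it as keyword arguments. The sketch below is a hypothetical, simplified version of that idea using importlib. It is not torchtune’s own config.instantiate. A standard-library class stands in for WandBLogger so the example runs without torchtune or wandb installed.

```python
import importlib


def instantiate(node: dict, **extra_kwargs):
    """Build the object named by node["_component_"] with the other keys as kwargs."""
    kwargs = {k: v for k, v in node.items() if k != "_component_"}
    kwargs.update(extra_kwargs)  # caller-supplied values win over config values
    module_path, _, attr_name = node["_component_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Stand-in for the metric_logger section: a stdlib class instead of WandBLogger.
logger_cfg = {"_component_": "logging.Logger", "name": "qlora_run"}
logger = instantiate(logger_cfg)
print(type(logger).__name__, logger.name)  # Logger qlora_run
```

Keeping the class path in the config means switching loggers, models, or optimizers only requires editing YAML, not recipe code.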
- Validate Configuration
Ensure that the configuration file is properly set up.
tune validate ./custom_config.yaml
Set Up Weights & Biases (W&B) for Logging
Log in to your W&B account to enable training tracking.
wandb login <API_KEY>
Run the Fine-Tuning Job
Run the fine-tuning process on a single device using QLoRA for memory efficiency.
tune run lora_finetune_single_device --config ./custom_config.yaml
Because the tune run command points at our local config file, the W&B logger automatically records that config alongside the run. We can then monitor the logged metrics in the W&B project named in the config.
We can also monitor and analyze GPU performance metrics, such as power consumption, memory allocation, temperature, and time spent accessing memory, as shown in the image below.
Conclusion
Torchtune is a powerful tool for anyone looking to fine-tune large language models, offering both simplicity and robustness. Whether you’re just starting with LLMs or you’re an experienced practitioner, Torchtune’s modular design and seamless PyTorch integration make it an essential addition to your ML toolkit. From fine-tuning large models on distributed systems to tweaking configurations for maximum efficiency, Torchtune has the flexibility and power you need to take your projects to the next level.