Reading Recommendations

This is a long paper, but it’s full of gems. Here’s a reading recommendation guide:

  • Strapped for time: sections 1 (Introduction), 2 (General Overview). It's just a couple of pages and provides a good overview.
  • Love ML systems: 3.3 (Infrastructure, Scaling, and Efficiency). Covers hardware, architecture, training challenges, and parallelism optimizations
  • How to train a coding model: 4.3.1 (Code). Covers how they targeted specific coding abilities and generated synthetic datasets to bootstrap the model
  • Training model to perform tool use: 4.3.5 (Tool Use)
  • Post-training framework: 4.1 (Modeling). Covers their pipeline for reward modeling, supervised finetuning, and direct preference optimization
  • Extending to 128K context: 3.4.2 (Long Context Pre-Training) and 4.3.4 (Long Context)
  • Why a 405B model: 3.2.1 (Scaling Laws)
  • Optimizations for inference: 6 (Inference). Covers both pipeline parallelism and FP8 quantization; this is a short section
  • Results and benchmarks: 5.1 (Pre-trained Language Model), 5.2 (Post-trained Language Model), 5.3 (Human Evaluations)
  • Red teaming: 5.4.6 (Red Teaming)
  • Multi-modality: 7 (Vision Experiments), 8 (Speech Experiments), 9.2 (Multimodality)

Introduction

Introduces a new set of models (8B/70B/405B) that support:

  • multilinguality
  • coding
  • reasoning
  • tool usage

Largest model:

  • 405B parameters
  • 128k context window
  • Has instruction fine-tuned version
  • Pre-trained using \(3.8 \times 10^{25}\) FLOPs

Also introduced Llama Guard 3 model for input/output safety.

Pre-training

Pre-Training Data

Data Cleaning

Knowledge cutoff is the end of 2023. To ensure high-quality tokens, they performed de-duplication and data cleaning, and removed domains known to contain large amounts of PII or adult content.

Data cleaning:

  • extracts HTML content from web documents
  • done carefully for pages with math & code content to preserve structure
  • Markdown markers also removed

De-duplication:

  • URL-level, document-level, and line-level de-duplication (line-level removes common boilerplate; see the sketch below)
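
A minimal sketch (my own illustration, not the paper's pipeline) of line-level de-duplication: count how often each line occurs across documents and drop lines that repeat too often, which mostly removes boilerplate such as navigation menus and cookie notices. The threshold here is a placeholder.

```python
import hashlib

def dedup_lines(documents, max_count=6):
    """Drop lines that occur more than max_count times across the document set."""
    counts = {}
    for doc in documents:
        for line in set(doc.splitlines()):          # count each line once per document
            h = hashlib.sha1(line.encode()).hexdigest()
            counts[h] = counts.get(h, 0) + 1
    return ["\n".join(line for line in doc.splitlines()
                      if counts[hashlib.sha1(line.encode()).hexdigest()] <= max_count)
            for doc in documents]

docs = ["Accept cookies\nReal article text A", "Accept cookies\nReal article text B"]
print(dedup_lines(docs, max_count=1))               # drops the repeated boilerplate line
```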

Used heuristics to filter other low-quality documents: logs/error messages, other adult websites, websites with excessive numbers of outlier tokens

Built a model-based classifier to sub-select high-quality tokens.

Built domain-specific pipelines to extract code & math-relevant web pages, including pages containing math deduction, pages containing code interleaved with natural language.

Used similar approaches as the above for other languages.

Data Mix

This ensures they have the right proportion of different data sources. They ended up with:

  • 50% general knowledge
  • 25% math & reasoning
  • 17% code
  • 8% multilingual

Knowledge classification: a classifier categorizes data by domain to inform the data mix. Used this to downsample categories over-represented on the web, like arts & entertainment.

Scaling laws for data mix: trained several small models on candidate data mixes and used them to predict the performance of a large model on each mix

Overview

  • 15.6T multilingual tokens (compare 1.8T for Llama 2)
  • Uses an 8K-token context window initially, followed by a continued pre-training stage that increases the supported context window to 128K tokens

Multi-modality

Encoders

Separate encoders trained for images and speech.

Image encoder:

  • Trained on image-text pairs

Speech encoder:

  • Self-supervised learning via masking
  • Masked parts reconstructed via a discrete-token representation

Adapters

TBD

Annealing Data

Performed annealing on small amounts of high-quality code and mathematical data. Annealing here means upsampling this high-quality data while the learning rate is annealed at the end of pre-training.

Found improvements for Llama 3 8B on GSM8k and MATH, but not 405B.

Model Architecture

Architecture

  • Uses dense Transformer architecture instead of MoE for training stability
  • Similar to Llama and Llama 2, performance gains mostly from improvements in data quality & diversity, and training scale
  • Grouped query attention with 8 key-value heads: improves inference speed and reduces the size of the KV cache during decoding
  • Attention mask prevents self-attention between different documents packed into the same sequence (why not just put them in different sequences? maybe to take advantage of parallelism?). Limited impact during standard pre-training, helpful for continued pre-training on very long sequences (see the sketch after this list)
  • RoPE for positional embeddings (500,000 base frequency hyperparameter instead of 10k in original paper, this is helpful for longer context), SwiGLU activation
  • 128K token vocabulary, based on the tiktoken tokenizer plus an extra 28K tokens for better non-English support. The new tokenizer improves the compression rate from 3.17 to 3.94 characters per token compared to the Llama 2 tokenizer.
  • Llama 3 405B: 126 layers (!!), model dimension 16,384, 128 attention heads
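
Here is a minimal sketch (my own, not the paper's implementation) of the cross-document attention mask mentioned above: when several documents are packed into one training sequence, combine the usual causal mask with a same-document constraint.

```python
import numpy as np

def document_causal_mask(doc_ids: np.ndarray) -> np.ndarray:
    """doc_ids[i] is the index of the document that token i belongs to in the
    packed sequence. Returns a [seq, seq] boolean mask where True = may attend."""
    causal = np.tril(np.ones((len(doc_ids), len(doc_ids)), dtype=bool))  # attend only to earlier tokens
    same_doc = doc_ids[:, None] == doc_ids[None, :]                      # ...within the same document
    return causal & same_doc

# Two documents of lengths 3 and 2 packed into one 5-token sequence.
print(document_causal_mask(np.array([0, 0, 0, 1, 1])).astype(int))
```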

Scaling Laws

Scaling laws are nice for predicting loss, but not helpful for understanding impact on downstream task performance.

To find relationship with downstream task performance they did:

  1. Find correlation between compute-optimal model’s loss on downstream tasks and training FLOPs
  2. Find correlation between loss and downstream task accuracy, using scaling law models

The scaling laws suggest that given their compute budget of \(3.8 \times 10^{25}\) FLOPs, a 402B model with 16.55T tokens is optimal, which led to their 405B model.
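
As a quick sanity check (mine, using the common \(C \approx 6ND\) FLOPs approximation rather than the paper's fitted scaling law), the quoted model size and token count land close to the stated compute budget:

```python
N = 402e9      # compute-optimal parameter count suggested by the scaling laws
D = 16.55e12   # training tokens
C = 6 * N * D  # standard dense-transformer training FLOPs approximation
print(f"{C:.2e} FLOPs")   # ~3.99e25, close to the 3.8e25 budget
```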

They also found their predictions to be quite accurate for the final downstream performance of their models.

Infrastructure, Scaling, and Efficiency

Compute:

  • 16K H100 GPUs, 700W TDP (thermal design power) with 80GB HBM3 (high-bandwidth memory stacked on the GPU package for very fast memory access)
  • Trained on Meta’s Grand Teton AI server platform, scheduling using MAST (Meta’s global-scale training scheduler)
  • Each server: 8 GPUs connected by NVLink, 2 CPUs

Storage:

  • Tectonic, Meta’s distributed file system
  • 240 PB storage over 7500 servers, 2TB/s sustainable throughput, 7TB/s peak throughput

Network:

  • Llama 3 405B used RDMA over Converged Ethernet (RoCE) fabric
  • Smaller models used Nvidia Quantum2 InfiniBand
  • Both fabrics provide 400 Gbps interconnects between GPUs

Parallelism for Model Scaling

Sharded the model across GPUs with parallelism so that each GPU's shard of the model parameters, optimizer states, gradients, and activations fits in its HBM.

4D parallelism:

  • tensor parallelism
  • pipeline parallelism
  • context parallelism
  • data parallelism

This parallelism setup achieved a BF16 Model FLOPs Utilization (MFU) of 38-43% (rough calculation sketched below)
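
For reference, a back-of-the-envelope MFU calculation (mine; the throughput and peak figures below are placeholders/approximations, not numbers from the paper):

```python
params = 405e9                   # model parameters
tokens_per_sec_per_gpu = 175     # hypothetical sustained training throughput per GPU
achieved_flops = 6 * params * tokens_per_sec_per_gpu    # per-GPU model FLOPs/s via the C ≈ 6·N·D rule
peak_bf16_flops = 989e12         # approximate H100 SXM dense BF16 peak
print(f"MFU ≈ {achieved_flops / peak_bf16_flops:.0%}")  # ≈ 43% with these placeholder numbers
```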

Reliability and Operational Challenges

  • 90% effective training time, even while supporting automated cluster maintenance (e.g., Linux kernel upgrades)

  • At least one training interruption daily

466 job interruptions:

  • 47 planned interruptions (e.g., maintenance)
  • 419 unexpected: mostly GPU/host component failures, suspected data corruption, unplanned maintenance
  • Significant manual intervention only required 3 times, rest handled by automation

Debugging

  • PyTorch’s built-in NCCL flight recorder helped diagnose issues quickly at scale
  • Mixed use of NVLink and RoCE complicated things

Others

  • Higher mid-day temperatures affected GPU dynamic voltage and frequency scaling, causing a diurnal 1-2% throughput variation
  • Tens of thousands of GPUs increasing or decreasing their power consumption in a correlated way (e.g., all waiting for checkpointing) causes power fluctuations on the order of tens of megawatts, stretching the limits of the power grid

Training Recipe

Initial pre-training:

  • AdamW optimizer
  • Linear warm-up followed by a cosine LR decay schedule (see the sketch after this list)
  • Start with lower batch size for training stability, increase subsequently for efficiency
  • Few loss spikes, no interventions needed to correct for training divergence
  • Upsampled non-English and math data, downsampled low-quality data
  • Added recent web data in final stages of pre-training to advance model knowledge cut-off
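
A minimal sketch (my own) of a linear-warm-up plus cosine-decay learning-rate schedule of the kind described above; the specific step counts and learning rates are placeholders, not values from these notes:

```python
import math

def lr_at(step, max_lr=8e-5, min_lr=8e-7, warmup_steps=8_000, total_steps=1_200_000):
    if step < warmup_steps:
        return max_lr * step / warmup_steps                        # linear warm-up from 0 to max_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))  # decays from 1 to 0
    return min_lr + (max_lr - min_lr) * cosine

print(lr_at(0), lr_at(8_000), lr_at(600_000), lr_at(1_200_000))
```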

Long context pre-training:

  • To support 128K context window
  • Long-context training isn't done earlier because self-attention compute grows quadratically with sequence length, making it too expensive
  • Increased context length by successive adaptation over 6 stages from 8K to 128K, 800B training tokens

Annealing:

  • On final 40M tokens, linearly annealed LR to 0, kept 128K context window
  • Upsampled data source of very high quality
  • Averaged model checkpoints during annealing to get final pre-trained model
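
A minimal sketch (my own) of averaging the model checkpoints collected during annealing into the final pre-trained weights:

```python
import torch

def average_checkpoints(state_dicts):
    """Element-wise average of a list of model state dicts with identical keys and shapes."""
    avg = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in state_dicts[0].items()}
    for sd in state_dicts:
        for k, v in sd.items():
            avg[k] += v.float() / len(state_dicts)
    return avg
```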

Post-Training

  • Several rounds of post-training, each starts with SFT followed by DPO
  • Examples collected by human annotations or generated synthetically
  • Custom

Modeling

  • Uses reward model (RM) and language model (LM)
  • RM trained on human-annotated preference data
  • Checkpoints aligned with DPO

  • Model supports tool use, which required designing multi-message chat protocol with special header and termination tokens

Reward Modeling

  • RM trained on top of the pre-trained checkpoint (a pairwise-loss sketch follows this list)
  • Preference pairs of either (chosen, rejected) or (chosen, rejected, edited), where edited > chosen > rejected
  • Filtered out preference data with similar responses
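
A minimal sketch (my own) of a pairwise reward-model loss over (chosen, rejected) preference pairs, i.e. maximizing \(\sigma(r_{\text{chosen}} - r_{\text{rejected}})\); this is the standard Bradley-Terry-style formulation, not necessarily the paper's exact objective:

```python
import torch
import torch.nn.functional as F

def rm_pairwise_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    """r_chosen / r_rejected are scalar reward-model scores for each preference pair."""
    return -F.logsigmoid(r_chosen - r_rejected).mean()

print(rm_pairwise_loss(torch.tensor([2.0, 1.5]), torch.tensor([0.5, 1.0])))
```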

Supervised Finetuning

  • The RM is used to perform rejection sampling over model responses to prompts from human annotation
  • Fine-tune pre-trained LM on the model-generated samples that are accepted

Direct Preference Optimization

  • Why not on-policy algorithms like PPO? DPO required less compute, performed better on instruction-following benchmarks
  • Used most recent batches of preference data from best-performing models during previous alignment rounds, ensures training data conforms better to distribution of policy model being optimized

Modified DPO:

  • Masked out formatting tokens (including header and termination tokens) in the DPO loss, which helps with stability. Including these tokens in the loss caused tail repetition or abruptly generated termination tokens; the hypothesis is that because these tokens are common to both chosen and rejected responses, the loss gives the model conflicting optimization objectives for them
  • Added regularization with negative log-likelihood (NLL) loss
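
A minimal sketch (my own) of the modified DPO objective described above: the standard DPO loss plus an NLL regularizer on the chosen response. Inputs are per-example summed log-probabilities of each response under the policy and reference models (with formatting tokens already masked out of those sums); \(\beta\) and the NLL coefficient are placeholders, not the paper's values.

```python
import torch
import torch.nn.functional as F

def dpo_with_nll(policy_chosen_logp, policy_rejected_logp,
                 ref_chosen_logp, ref_rejected_logp,
                 beta=0.1, nll_coef=0.2):
    # Standard DPO: push the policy to prefer chosen over rejected, relative to the reference.
    logits = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    dpo_loss = -F.logsigmoid(logits).mean()
    # Extra regularization: negative log-likelihood on the chosen responses.
    nll_loss = -policy_chosen_logp.mean()
    return dpo_loss + nll_coef * nll_loss
```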

Post-training Data

Preference Data

  • Sample two responses from two different models for each user prompt, labelled by human annotators
  • Annotators state strength of preference by 4 levels: significantly better, better, slightly better, marginally better
  • Allow editing step after annotation to further improve response
  • Only used responses significantly better or better for training

SFT Data

Finetuning data contains:

  • Prompts from human annotation collection with rejection-sampled (RS) responses
  • Synthetic data targeting specific capabilities
  • Small amounts of human-curated data

Datasets:

  • General English
  • Code
  • Multilingual
  • Exam-like
  • Reasoning and tools
  • Long context

Rejection sampling:

  • Choose prompt from human annotation collection
  • Sample 10-30 outputs from latest chat model policy
  • Use the RM to choose the best candidate (see the sketch after this list)
  • For later rounds of post-training, use system prompt to steer RS responses to conform with tone/style/formatting
  • Uses PagedAttention to make RS efficient
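
A minimal sketch (my own) of the rejection-sampling loop; `sample_response` and `score` are hypothetical stand-ins for the chat-model policy and the reward model:

```python
def rejection_sample(prompt, sample_response, score, k=20, system_prompt=None):
    """Draw k candidates from the current policy and keep the one the RM scores highest."""
    candidates = [sample_response(prompt, system_prompt) for _ in range(k)]  # paper samples 10-30
    scores = [score(prompt, c) for c in candidates]
    return candidates[scores.index(max(scores))]
```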

Data Processing and Quality Control

Most of the training data is model-generated, so it requires careful cleaning and quality control.

Data cleaning:

  • Rule-based data removal or modification strategies
  • Balance proportion of such samples in dataset

Data pruning:

  • Topic classification: fine-tuned Llama 3 8B into a topic classifier
  • Quality scoring: Use RM and Llama 3 checkpoint to rate content, keep examples marked as high quality by either RM or Llama. Both signals have high disagreement rates, and combining signals gives best recall on test set.
  • Difficulty scoring: used Llama 3 70B to perform intention-tagging, where more intentions implies more complexity. Also used it to measure difficulty of dialogs
  • Semantic deduplication: cluster dialogs using RoBERTa embeddings, sort examples within each cluster by quality score \(\times\) difficulty score, then iterate from best to worst and keep only examples whose maximum cosine similarity to already-kept examples is below a threshold (see the sketch below)
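
A minimal sketch (my own) of that semantic-deduplication step. Embeddings are assumed to be unit-normalized RoBERTa-style sentence embeddings for the examples in one cluster; the threshold is a placeholder.

```python
import numpy as np

def semantic_dedup(embeddings, quality, difficulty, threshold=0.95):
    """embeddings: [n, d] unit-normalized vectors; quality/difficulty: [n] scores.
    Returns indices of the examples kept after greedy deduplication."""
    order = np.argsort(-(quality * difficulty))        # best examples first
    kept = []
    for i in order:
        if kept:
            sims = embeddings[kept] @ embeddings[i]    # cosine similarities to already-kept examples
            if sims.max() >= threshold:
                continue                               # too close to something we already have
        kept.append(i)
    return kept
```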

Capabilities

Code

Capabilities:

  • Code generation
  • Documentation
  • Debugging
  • Review

Targeted languages: Python, Java, JavaScript, C/C++, TypeScript, Rust, PHP, HTML/CSS, SQL, bash/shell

Improved capabilities via:

  • Training a code expert
  • Generate synthetic data for SFT
  • Improve formatting with system prompt steering
  • Create quality filters to remove bad examples

Expert training:

  • Train code expert to obtain high quality human annotations for code
  • Approach similar to CodeLlama (scant on details, should probably check this paper)

Synthetic data generation:

  • Faced issues in code generation: difficulty following instructions, code syntax errors, incorrect code generation, and difficulty fixing bugs
  • Used Llama 3 and the code expert to generate 2.7M synthetic dialogs for SFT

During RS, used code specific system prompts to improve:

  • code readability
  • documentation
  • thoroughness
  • specificity

Synthetic data generation: execution feedback

  • Training on data generated by a stronger model helps smaller models via distillation, but the 405B doesn't improve by training on its own generations
  • Use execution feedback as source of truth, allow model to learn from own mistakes
    1. Problem description generation: generate programming problem descriptions, use random code snippets as inspiration
    2. Solution generation: Prompt Llama 3 to solve problem, use CoT in comments, add programming guidelines in system prompt
    3. Correctness analysis: static analysis (parsers and linters), plus model-generated unit tests that are executed against the solution (see the sketch after this list)
    4. Error feedback and iterative self-correction: prompt the model to revise solutions that fail, including feedback from the parser/linter/tests. The model can modify both the code and the unit tests to accommodate the new code. About 20% of initially incorrect solutions were self-corrected this way.
    5. Fine-tuning and iterative improvement: the process is iterated over multiple rounds, with higher-quality synthetic data generated in each subsequent round
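
A minimal sketch (my own) of the correctness-analysis step: first check the generated solution statically (does it even parse?), then run the model-generated unit tests against it and return the failure output, which can be fed back to the model for self-correction.

```python
import ast, os, subprocess, tempfile

def check_solution(solution_code: str, test_code: str, timeout: int = 10):
    """Returns (passed, feedback); feedback is a syntax error or the test run's stderr."""
    try:
        ast.parse(solution_code)                                   # static analysis: syntax check
    except SyntaxError as e:
        return False, f"syntax error: {e}"
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(solution_code + "\n\n" + test_code)                # solution followed by its unit tests
        path = f.name
    try:
        run = subprocess.run(["python", path], capture_output=True, text=True, timeout=timeout)
        return run.returncode == 0, run.stderr
    except subprocess.TimeoutExpired:
        return False, "execution timed out"
    finally:
        os.remove(path)
```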

Synthetic data generation: programming language translation

  • Noted performance gap between popular vs less common programming languages, due to difference in dataset size
  • Translate data from more common to less common languages
  • Ensure quality via syntax parsing, compilation, execution

Synthetic data generation: backtranslation

  • Some coding capabilities don't benefit as much from execution feedback, e.g., documentation & explanation
  • Generated 1.2M synthetic dialogs for code explanation, generation, documentation, debugging
  • Done as follows:
    1. Generate: prompt Llama 3 to generate data corresponding to the desired capability (e.g., add comments to code)
    2. Backtranslate: ask the model to backtranslate the synthetically generated data to the original code (e.g., generate code from only the comments)
    3. Filter: ask Llama 3 to determine the quality of the generated code, using the original code as reference. This self-verification step acts as a filter for good examples; only those with high scores are used for SFT

Tool Use

Trained Llama 3 to use search engine (Brave), Python interpreter, Wolfram Alpha API.

To train on tool use:

  • Start with training single-turn tool use, then tool use in dialog, and then multi-step tool use and data analysis
  • All synthetically generated: first synthetic user prompts which require calling out to tools, then the corresponding tool calls which are then executed, and then the final answer to user prompt
  • Multi-step tool use trained in a similar way synthetically
  • User prompts are based on a provided file, and ask to summarize the contents of the file, find and fix bugs, optimize a piece of code, perform data analysis or visualization
  • Augmented synthetic data with different system prompts to teach model to use tools only when activated
  • To prevent the model from using tools for simple queries, they added a dataset of simple math/reasoning questions where tool use is activated but the response doesn't use any tools

Factuality

To train the model to guard against hallucinations, they used a knowledge probe to find out what the model knows, and to generate training data of refusals for the things it doesn’t:

  1. Extract a data snippet from the pre-training data.
  2. Generate a factual question about these snippets (context) by prompting Llama 3.
  3. Sample responses from Llama 3 to the question.
  4. Score the correctness of the generations using the original context as a reference and Llama 3 as a judge.
  5. Score the informativeness of the generations using Llama 3 as a judge.
  6. Generate a refusal for responses which are consistently informative and incorrect across the generations, using Llama 3.

But because pre-training data is not always factually correct, they also collected factuality data for sensitive topics where contradictory or incorrect statements are prevalent.

Steerability

Remainder to be continued…