Continued Pretraining in Machine Learning

Continued pretraining (CPT) is the practice of further training a pretrained language model on a domain-specific corpus using the same self-supervised objective (typically next-token prediction) before any task-specific fine-tuning. It occupies a critical middle layer in the model adaptation stack: pretraining gives you general language understanding, continued pretraining injects domain knowledge, and fine-tuning teaches the model to follow instructions or perform specific tasks.

The intuition is straightforward. A model like Llama 3 has seen trillions of tokens of internet text, but its exposure to, say, Indian legal judgments, radiology reports, or semiconductor datasheets is sparse and shallow. Continued pretraining on a curated corpus of 10-100 billion domain tokens rewires the model's internal representations so that domain-specific concepts, terminology, and reasoning patterns become first-class citizens in the weight space -- not distant memories retrieved through clever prompting.

Since Gururangan et al.'s seminal 2020 paper "Don't Stop Pretraining," CPT has become the standard first step for building domain-specific LLMs. BloombergGPT was trained from scratch rather than continued from an existing checkpoint, but it showed the value of mixing 363 billion financial tokens with 345 billion general tokens. Code Llama continued pretraining Llama 2 on 500 billion tokens of code. Llemma adapted Code Llama on mathematical text from the Proof-Pile-2. In India, Sarvam AI's OpenHathi and Ola's Krutrim used continued pretraining to extend Llama-family models to Hindi and other Indic languages.

What makes CPT both powerful and tricky is the balancing act: you want the model to deeply absorb domain knowledge without forgetting its general capabilities. Get the data mix wrong, the learning rate too high, or the tokenizer mismatched, and you end up with a model that speaks fluent radiology but can't handle basic arithmetic. This guide covers every aspect of that balancing act -- from corpus curation to compute cost estimation.

Concept Snapshot

What It Is
A training phase where a pretrained language model undergoes additional self-supervised learning on a domain-specific corpus, adapting its internal representations to encode domain knowledge before any task-specific fine-tuning.
Category
Model Training
Complexity
Advanced
Inputs / Outputs
Inputs: pretrained base model + curated domain corpus + optional general data for replay. Outputs: domain-adapted pretrained model ready for downstream fine-tuning or direct prompting.
System Placement
Sits between initial pretraining and fine-tuning (SFT/RLHF/LoRA) in the ML pipeline. Applied after base model selection and corpus curation, before instruction tuning or task-specific adaptation.
Also Known As
Continual Pretraining, Domain-Adaptive Pretraining (DAPT), Second-Phase Pretraining, Adaptive Pretraining, Domain Continued Pretraining
Typical Users
ML Engineers, NLP Engineers, Applied Scientists, Foundation Model Teams, Domain AI Specialists
Prerequisites
Language model pretraining fundamentals (next-token prediction, causal LM), Transformer architecture (attention, MLP layers, positional encoding), Learning rate scheduling (cosine annealing, warmup), Data preprocessing for LLM training (tokenization, deduplication), Distributed training basics (FSDP, DeepSpeed, data parallelism)
Key Terms
DAPT (Domain-Adaptive Pretraining), TAPT (Task-Adaptive Pretraining), data mixing ratio, catastrophic forgetting, replay buffer, tokenizer adaptation, learning rate re-warming, domain corpus, perplexity

Why This Concept Exists

The Domain Knowledge Gap

Large language models are trained on massive, general-purpose web corpora -- Common Crawl, Wikipedia, books, code repositories. This gives them impressive breadth: they can discuss philosophy, write Python, and summarize news articles. But breadth is not depth.

Consider a model deployed for Indian legal tech. The general pretraining corpus might contain a few hundred thousand tokens of Indian court judgments, scattered across millions of web pages. But the Indian legal system has its own vocabulary ("writ petition," "Section 498A," "anticipatory bail"), its own reasoning patterns (precedent-based argumentation citing specific IPC sections), and its own document structures (FIR format, judgment headings). A general model has seen these patterns, but they make up a tiny fraction of its training data. It's like asking someone who read one chapter on Indian law in a 10,000-chapter encyclopedia to draft a bail application.

Continued pretraining solves this by dedicating substantial compute to domain-specific text, effectively "re-weighting" the model's knowledge distribution toward the target domain.

From DAPT to Modern CPT

The formal study of continued pretraining began with Gururangan et al. (2020), who introduced the terms DAPT (Domain-Adaptive Pretraining) and TAPT (Task-Adaptive Pretraining). Their key finding was elegant: taking RoBERTa and further pretraining it on domain text (biomedical papers, computer science papers, news, reviews) consistently improved downstream task performance, sometimes by 3-8 F1 points. Moreover, TAPT -- pretraining on just the unlabeled task data -- provided additional gains on top of DAPT.

This observation carried over to the LLM era. As models grew from RoBERTa's hundreds of millions of parameters to 7B-70B, continued pretraining remained the standard way to build domain-specialized models. Bloomberg took the alternative route of training a 50B-parameter model from scratch on mixed financial and general data (BloombergGPT, 2023). Meta took Llama 2 and continued pretraining on 500B code tokens to produce Code Llama (2023). Azerbayev et al. continued pretraining Code Llama on mathematical text to produce Llemma (2023), which outperformed the Minerva models on an equal-parameter basis.

Why Not Just Fine-Tune or Use RAG?

A natural question: if you want domain knowledge, why not just fine-tune on domain data or use retrieval-augmented generation (RAG)?

Fine-tuning (SFT, LoRA) teaches the model to follow specific instruction formats and perform specific tasks, but it doesn't fundamentally alter the model's world knowledge. You can fine-tune a model on 10,000 radiology report Q&A pairs, but if the base model doesn't deeply understand radiological concepts, the fine-tuned model will hallucinate confidently.

RAG retrieves relevant documents at inference time and includes them in the context. It's powerful for factual recall but limited by context window size, retrieval quality, and the model's ability to reason over retrieved text. If the model's internal representations don't align with the domain, it struggles to synthesize information from retrieved passages.

Continued pretraining operates at a deeper level: it rewires the model's internal representations so that domain concepts are natively encoded in the weights. After CPT, the model doesn't need to retrieve basic domain knowledge -- it's already there, enabling better reasoning, fewer hallucinations, and more natural domain fluency. The three approaches are complementary: CPT provides the knowledge foundation, fine-tuning adds task-specific behavior, and RAG provides access to dynamic or rare information.

Key Insight: Continued pretraining is the only approach that changes what the model knows. Fine-tuning changes how it behaves. RAG changes what it can access. For deep domain adaptation, you typically need all three.

Core Intuition & Mental Model

The Analogy: Becoming a Domain Expert

Imagine hiring a brilliant generalist consultant (the pretrained LLM) to work in your hospital's radiology department. This consultant has read widely -- they know some medical terminology, can discuss imaging modalities at a cocktail party level, and can follow basic clinical reasoning. But they're not a radiologist.

You have three options:

  1. Give them a reference library (RAG): They can look things up when asked, but they don't deeply understand what they're reading. They'll miss subtleties and can't synthesize across sources effectively.

  2. Train them on Q&A flashcards (fine-tuning): They learn to pattern-match specific question types, but their understanding is shallow. Ask something slightly outside the flashcard distribution and they're lost.

  3. Send them to radiology residency (continued pretraining): They spend months reading thousands of radiology textbooks, case reports, and imaging studies. The knowledge becomes part of their mental model. They develop intuition for what's normal and abnormal, learn the vocabulary natively, and can reason about novel cases.

Continued pretraining is option 3. It's expensive -- residency takes years (or in ML terms, significant GPU hours) -- but it produces a fundamentally more capable domain specialist.

The Representation Geometry Intuition

At a more technical level, think about what continued pretraining does to the model's internal representation space. In a general pretrained model, concepts like "myocardial infarction," "ST elevation," and "troponin levels" are scattered across the embedding space, loosely associated but not tightly clustered. The model knows these words exist but hasn't deeply encoded their relationships.

Continued pretraining on medical text forces the model to repeatedly predict medical text -- to predict that "ST elevation" is often followed by discussion of "myocardial infarction" and "troponin levels." This repetition tightens the relevant representations, creating dense clusters of domain knowledge where related concepts are geometrically close in the embedding space and the attention patterns learn domain-specific reasoning paths.

The result: domain concepts move from the periphery of the model's representation space to become well-organized, easily accessible regions. This is why continued pretraining improves not just domain-specific tasks but also the model's ability to follow complex domain reasoning chains -- the building blocks are better organized.

Mental Model: Continued pretraining is like renovating a library. The books (knowledge) were already somewhere in the building, but they were scattered across random shelves. CPT reorganizes them into a dedicated domain section with proper indexing, making retrieval and reasoning dramatically more efficient.

Technical Foundations

The Training Objective

Continued pretraining uses the same autoregressive language modeling objective as initial pretraining. Given a domain corpus $\mathcal{D} = \{x_1, x_2, \ldots, x_N\}$ where each $x_i$ is a document, the model parameters $\theta$ (initialized from pretrained weights $\theta_0$) are optimized to minimize the negative log-likelihood:

$$\mathcal{L}(\theta) = -\frac{1}{|\mathcal{D}|} \sum_{x \in \mathcal{D}} \sum_{t=1}^{|x|} \log P_\theta(x_t \mid x_{<t})$$

The key difference from initial pretraining is that $\theta$ starts from $\theta_0$ (pretrained weights) rather than random initialization, and the training corpus $\mathcal{D}$ is domain-specific rather than general.

Data Mixing Formulation

In practice, CPT rarely uses pure domain data. Instead, a mixed corpus $\mathcal{D}_{\text{mix}}$ is constructed:

$$\mathcal{D}_{\text{mix}} = \alpha \cdot \mathcal{D}_{\text{domain}} + (1 - \alpha) \cdot \mathcal{D}_{\text{general}}$$

where $\alpha \in [0, 1]$ is the domain data mixing ratio, i.e. the fraction of training tokens drawn from the domain corpus. Typical values range from $\alpha = 0.5$ to $\alpha = 0.9$. The general data component $\mathcal{D}_{\text{general}}$ acts as a replay buffer that mitigates catastrophic forgetting.

BloombergGPT, although trained from scratch rather than by continued pretraining, used $\alpha \approx 0.51$ (363B financial tokens / 708B total tokens). Code Llama used $\alpha \approx 0.85$ (85% code, 8% natural language about code, 7% general natural language). The optimal $\alpha$ depends on how far the domain is from the pretraining distribution and on the size of the domain corpus.
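Before launching a run, it helps to translate $\alpha$ into concrete token counts and check how many times the domain corpus will be repeated. A small sketch, assuming a fixed total token budget; the helper `mixing_plan` and the 20B-token domain corpus in the example are hypothetical:

```python
def mixing_plan(total_tokens: float, alpha: float, domain_corpus_tokens: float) -> dict:
    """Split a CPT token budget between domain data and general replay data.

    alpha is the domain mixing ratio; domain_epochs > 1 means domain documents repeat.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    domain_tokens = alpha * total_tokens
    general_tokens = (1.0 - alpha) * total_tokens
    return {
        "domain_tokens": domain_tokens,
        "general_tokens": general_tokens,
        "domain_epochs": domain_tokens / domain_corpus_tokens,
    }


# BloombergGPT's reported mix: 363B financial tokens out of 708B total
print(round(363e9 / 708e9, 2))  # 0.51

# Hypothetical run: 50B-token budget at alpha = 0.7 over a 20B-token domain corpus
plan = mixing_plan(total_tokens=50e9, alpha=0.7, domain_corpus_tokens=20e9)
print(round(plan["domain_epochs"], 2))  # 1.75
```

If `domain_epochs` comes out well above a handful of passes, that is a signal to collect more domain data or lower $\alpha$ rather than repeat the same documents many times.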

Learning Rate Scheduling for CPT

The learning rate schedule for CPT differs critically from initial pretraining. Let $\eta_0$ be the peak learning rate used during initial pretraining. For CPT, the typical approach is:

$$\eta_{\text{CPT}} = \beta \cdot \eta_0 \quad \text{where } \beta \in [0.01, 0.1]$$

The CPT learning rate is typically 10-100x lower than the initial pretraining LR. The schedule usually follows cosine annealing:

$$\eta(t) = \eta_{\text{min}} + \frac{1}{2}(\eta_{\text{CPT}} - \eta_{\text{min}})\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

where $T$ is the total number of CPT steps and $\eta_{\text{min}}$ is the minimum learning rate (often $0.1 \times \eta_{\text{CPT}}$).

Recent work by Gupta et al. (2023) suggests that re-warming the learning rate (briefly increasing it at the start of CPT) can help when the initial pretraining used cosine decay to near-zero. The warmup phase is typically short -- 1-5% of total CPT steps.
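The cosine schedule with a short re-warming phase can be written as a plain function of the step. This is a sketch under stated assumptions: warmup rises linearly from $\eta_{\text{min}}$ to $\eta_{\text{CPT}}$, the cosine then runs over the remaining $T - \text{warmup}$ steps, and $\eta_{\text{min}} = 0.1 \times \eta_{\text{CPT}}$. The function name and the example values (3e-5 peak, 50k steps, 1% warmup) are illustrative:

```python
import math


def cpt_learning_rate(step: int, total_steps: int, peak_lr: float,
                      warmup_steps: int, min_ratio: float = 0.1) -> float:
    """Linear re-warming to peak_lr, then cosine annealing down to min_ratio * peak_lr."""
    min_lr = min_ratio * peak_lr
    if step < warmup_steps:
        # Re-warm linearly from min_lr up to the CPT peak learning rate
        return min_lr + (peak_lr - min_lr) * step / warmup_steps
    # Cosine decay over the post-warmup steps; t goes from 0 to 1
    t = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * t))


# Peak CPT LR of 3e-5 (0.1x a 3e-4 pretraining peak), 50k steps, 500 warmup steps
for s in (0, 250, 500, 25_250, 50_000):
    print(s, f"{cpt_learning_rate(s, 50_000, 3e-5, 500):.2e}")
```

Warming up from $\eta_{\text{min}}$ rather than from zero is a design choice here; starting from zero is equally common.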

Compute Cost Estimation

The compute required for CPT scales linearly with corpus size and model parameters:

$$\text{FLOPs} \approx 6 \times N \times D$$

where $N$ is the number of model parameters and $D$ is the number of training tokens. For an 8B-parameter model trained on 50B domain tokens:

$$\text{FLOPs} \approx 6 \times 8 \times 10^9 \times 50 \times 10^9 = 2.4 \times 10^{21}$$

On NVIDIA A100 80GB GPUs, assuming an optimistic sustained ~300 TFLOPS in bf16 (close to the card's 312 TFLOPS dense bf16 peak), this takes approximately:

$$\text{GPU-hours} = \frac{2.4 \times 10^{21}}{300 \times 10^{12} \times 3600} \approx 2{,}222 \text{ GPU-hours}$$

On a cluster of 32 A100 GPUs: approximately 69 hours (about 3 days). At cloud rates of USD 2/GPU-hour (~INR 168/GPU-hour), that's roughly USD 4,444 (~INR 3.7 lakh). Real training runs usually sustain well below peak throughput, so treat these figures as a lower bound: at 50% of peak, GPU-hours and cost roughly double.
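The same arithmetic as a reusable helper makes it easy to plug in other model sizes, token budgets, or a more realistic sustained throughput. A sketch using the $6ND$ rule of thumb above; the function name is hypothetical, and the 300 TFLOPS and USD 2/GPU-hour inputs are the assumptions from the text:

```python
def cpt_compute_estimate(n_params: float, n_tokens: float, flops_per_gpu: float,
                         n_gpus: int, usd_per_gpu_hour: float) -> dict:
    """Estimate CPT compute and cost with the FLOPs ~= 6 * N * D rule of thumb."""
    total_flops = 6 * n_params * n_tokens
    gpu_hours = total_flops / (flops_per_gpu * 3600)  # sustained FLOP/s per GPU
    return {
        "flops": total_flops,
        "gpu_hours": gpu_hours,
        "wall_clock_hours": gpu_hours / n_gpus,
        "cost_usd": gpu_hours * usd_per_gpu_hour,
    }


est = cpt_compute_estimate(8e9, 50e9, flops_per_gpu=300e12, n_gpus=32, usd_per_gpu_hour=2.0)
print(f"{est['flops']:.1e} FLOPs, {est['gpu_hours']:,.0f} GPU-hours, "
      f"{est['wall_clock_hours']:.0f} h on 32 GPUs, ~USD {est['cost_usd']:,.0f}")
# -> 2.4e+21 FLOPs, 2,222 GPU-hours, 69 h on 32 GPUs, ~USD 4,444
```

Passing `flops_per_gpu=150e12` (about half of peak) doubles every time and cost figure, which is a more conservative budget.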

Catastrophic Forgetting Formalization

Catastrophic forgetting can be formalized as performance degradation on a general benchmark $\mathcal{B}_{\text{gen}}$ after CPT:

$$\Delta_{\text{forget}} = \text{Score}(\theta_0, \mathcal{B}_{\text{gen}}) - \text{Score}(\theta_{\text{CPT}}, \mathcal{B}_{\text{gen}})$$

Strategies to minimize $\Delta_{\text{forget}}$ include:

  1. Data replay: Including general data in the training mix (controlled by $\alpha$)
  2. Elastic Weight Consolidation (EWC): Adding a regularization term $\frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{0,i})^2$ to the loss, where $F_i$ is the Fisher information for parameter $i$
  3. Low learning rate: Using $\beta \ll 1$ to constrain parameter drift
  4. Gradient projection: Projecting gradients to avoid interfering with important directions for general capabilities
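The EWC term in item 2 is simple to compute once a per-parameter Fisher estimate is available; it is added to the language-modeling loss. A framework-free sketch over flat parameter lists, assuming the diagonal Fisher values have already been estimated (squared gradients on general data are a common approximation, but that choice is an assumption here); `ewc_penalty` is a hypothetical helper name:

```python
def ewc_penalty(theta: list[float], theta0: list[float], fisher: list[float], lam: float) -> float:
    """Compute (lam / 2) * sum_i F_i * (theta_i - theta0_i)^2.

    Parameters with high Fisher information (important for general capabilities)
    are penalized more strongly for drifting away from their pretrained values.
    """
    if not (len(theta) == len(theta0) == len(fisher)):
        raise ValueError("theta, theta0 and fisher must have the same length")
    return 0.5 * lam * sum(f * (t - t0) ** 2 for t, t0, f in zip(theta, theta0, fisher))


# Both parameters drift by 0.5, but only the first is "important" (F = 4.0)
print(ewc_penalty([1.5, 1.5], [1.0, 1.0], [4.0, 0.0], lam=1.0))  # 0.5
```

The total CPT objective would then be the next-token loss plus this penalty, with $\lambda$ trading off domain learning against retention.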

Practical Rule: For most CPT setups, a mixing ratio of $\alpha = 0.7$ (70% domain, 30% general replay) with a learning rate of $0.05 \times \eta_0$ provides a good starting point. Monitor both domain perplexity (should decrease) and general benchmark scores (should not decrease by more than 2-3%).

Internal Architecture

The architecture of a continued pretraining pipeline involves several interconnected stages, from corpus preparation through training execution to validation. Unlike initial pretraining, CPT requires careful management of the relationship between the pretrained model's existing knowledge and the new domain data.

The following diagram illustrates the end-to-end CPT pipeline, showing how raw domain data flows through curation, tokenization, mixing with general replay data, and into the training loop. The key architectural decision points are the data mixing ratio, learning rate schedule, and checkpoint validation strategy.

The validation stage is particularly critical for CPT: you need to track both domain-specific improvement (decreasing perplexity on held-out domain text) and general capability retention (stable scores on benchmarks like MMLU, HellaSwag). The best checkpoint is typically not the final one -- it's the one that maximizes domain performance while keeping general regression below a threshold.

Key Components

Domain Corpus Curator

Responsible for sourcing, cleaning, deduplicating, and filtering the domain-specific training data. This is the most labor-intensive component of the CPT pipeline. Key operations include MinHash deduplication (removing near-duplicate documents that inflate the corpus without adding information), perplexity-based quality filtering (using a small reference model to filter out noisy, machine-generated, or off-topic text), and toxicity filtering plus PII removal (keeping harmful content and personal data out of the training corpus). For Indian language corpora, this also involves script normalization and transliteration handling.

Tokenizer Adapter

Handles tokenizer adaptation when the domain vocabulary differs significantly from the pretraining vocabulary. For example, a model pretrained on English web text will tokenize Hindi text into many small subword fragments (often 3-4 tokens per Hindi word vs. roughly one token per common English word). The tokenizer adapter can extend the vocabulary with domain-specific tokens and initialize their embeddings via averaging or interpolation from existing token embeddings. This component is optional but critical for cross-lingual CPT and specialized domains like chemistry (SMILES notation) or law (statute references).
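The averaging strategy can be sketched in a few lines: each new token's embedding starts as the mean of the embeddings of the subword pieces the original tokenizer used to split it into. This version works on plain lists so it runs anywhere; the helper name and the toy table are hypothetical. With Hugging Face models you would typically obtain the piece ids by encoding the new string with the original tokenizer, call `model.resize_token_embeddings(len(tokenizer))`, and write the computed rows into the new embedding slots.

```python
def init_new_token_embeddings(
    embeddings: list[list[float]],
    new_token_pieces: dict[str, list[int]],
) -> dict[str, list[float]]:
    """Initialize each new token's embedding as the mean of the embeddings of
    the existing subword tokens it was previously split into."""
    new_rows = {}
    for token, piece_ids in new_token_pieces.items():
        if not piece_ids:
            raise ValueError(f"no existing pieces for new token {token!r}")
        vectors = [embeddings[i] for i in piece_ids]
        # Column-wise mean across the piece vectors
        new_rows[token] = [sum(col) / len(vectors) for col in zip(*vectors)]
    return new_rows


# Toy 2-dimensional embedding table with 4 existing tokens
table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
# Hypothetical: the new token "498A" was previously split into tokens 2 and 3
print(init_new_token_embeddings(table, {"498A": [2, 3]}))  # {'498A': [3.0, 1.0]}
```

Mean initialization places the new token near the region the model already associates with its pieces, which usually gives a lower initial loss than random initialization.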

Data Mixer

Combines domain-specific data with general replay data according to a configurable mixing ratio $\alpha$. The mixer must handle different data formats, ensure balanced sampling across domain sub-categories (e.g., for medical CPT: clinical notes, research papers, drug labels, patient education material), and optionally implement curriculum learning -- starting with a higher proportion of general data and gradually increasing domain data concentration. The DoReMi algorithm can be used to automatically optimize these mixing weights.
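A simple version of the curriculum idea is a mixing ratio that ramps linearly over training, with each sequence's source drawn at random from the current ratio. The sketch below is a plain linear curriculum, not DoReMi (which learns the weights with a proxy model); the function names and the 0.5 to 0.9 ramp are hypothetical choices:

```python
import random
from collections import Counter


def curriculum_alpha(step: int, total_steps: int, alpha_start: float, alpha_end: float) -> float:
    """Domain mixing ratio that ramps linearly from alpha_start to alpha_end."""
    frac = min(1.0, step / max(1, total_steps))
    return alpha_start + (alpha_end - alpha_start) * frac


def sample_sources(total_steps: int, alpha_start: float, alpha_end: float,
                   seed: int = 0) -> list[str]:
    """Choose 'domain' or 'general' for each training sequence under the curriculum."""
    rng = random.Random(seed)  # seeded for reproducible data order
    return [
        "domain" if rng.random() < curriculum_alpha(s, total_steps, alpha_start, alpha_end)
        else "general"
        for s in range(total_steps)
    ]


# Hypothetical curriculum: start at 50% domain data, end at 90% domain data
picks = sample_sources(10_000, alpha_start=0.5, alpha_end=0.9, seed=42)
print(Counter(picks[:1_000]), Counter(picks[-1_000:]))
```

The first thousand picks should be roughly half domain data and the last thousand close to 90%, so the model sees more general text early, while its representations adjust.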

Learning Rate Scheduler

Controls the learning rate trajectory during CPT. Unlike initial pretraining (which uses a standard warmup + cosine decay from a high LR), CPT typically starts from a much lower LR (1-10% of the initial pretraining peak) with optional brief re-warming. The scheduler must handle the transition from the pretrained model's final learning rate state to the CPT regime. Some work suggests that cosine annealing may be preferable to WSD (Warmup-Stable-Decay) for CPT when the base model was itself trained with cosine annealing.

Distributed Training Engine

Manages the actual training computation across multiple GPUs/nodes. For CPT of 7B+ parameter models, this requires data parallelism (FSDP or DeepSpeed ZeRO) and optionally tensor or pipeline parallelism. Key frameworks include Megatron-LM (NVIDIA's high-performance training framework), nanotron (Hugging Face's minimalist parallelism library), and torchtitan (PyTorch-native distributed training). The engine handles gradient accumulation, mixed-precision training (bf16), and checkpoint saving.
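A back-of-the-envelope memory estimate shows why sharding is required at this scale. The sketch follows the accounting in the ZeRO paper for mixed-precision Adam, roughly 16 bytes per parameter (2 for bf16/fp16 weights, 2 for gradients, 12 for fp32 master weights plus the two Adam moments). It ignores activations, communication buffers, and fragmentation, so treat the result as a lower bound; the function name is hypothetical:

```python
def training_state_gb(n_params: float, n_gpus: int, zero_stage: int) -> float:
    """Per-GPU memory (GB) for weights, gradients and Adam states under ZeRO sharding.

    Assumes mixed-precision Adam: 2 B weights + 2 B grads + 12 B optimizer states
    per parameter. Activations are NOT included.
    """
    weights, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if zero_stage >= 1:
        optim /= n_gpus    # Stage 1 shards optimizer states
    if zero_stage >= 2:
        grads /= n_gpus    # Stage 2 additionally shards gradients
    if zero_stage >= 3:
        weights /= n_gpus  # Stage 3 additionally shards the parameters
    return (weights + grads + optim) / 1e9


for stage in (0, 1, 2, 3):
    print(f"ZeRO-{stage}: {training_state_gb(8e9, n_gpus=8, zero_stage=stage):.0f} GB per GPU")
# -> 128, 44, 30 and 16 GB per GPU for ZeRO-0 through ZeRO-3
```

Without sharding, an 8B model's training state alone exceeds an 80GB A100; with ZeRO-3 across 8 GPUs it leaves most of each GPU's memory for activations.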

Forgetting Monitor

Continuously evaluates the model on general-capability benchmarks during training to detect catastrophic forgetting early. Runs evaluation on a subset of benchmarks (e.g., MMLU, ARC, HellaSwag) every few thousand steps. If general performance drops by more than a set threshold (typically 2-3%), the monitor can trigger corrective actions: increasing the general data mixing ratio, reducing the learning rate, or applying EWC regularization. This early warning system prevents wasting compute on a forgetting trajectory.
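The threshold check at the heart of such a monitor is a per-benchmark version of $\Delta_{\text{forget}}$. A minimal sketch, assuming higher-is-better scores and interpreting the 2-3% threshold as a relative drop from the base model; the function name and the scores below are hypothetical toy values, not measured results:

```python
def check_forgetting(baseline: dict[str, float], current: dict[str, float],
                     max_relative_drop: float = 0.03) -> dict[str, float]:
    """Return the benchmarks whose score fell by more than max_relative_drop.

    Scores are assumed to be 'higher is better' (e.g. accuracy on MMLU, ARC, HellaSwag).
    """
    flagged = {}
    for name, base in baseline.items():
        drop = (base - current[name]) / base  # relative version of Delta_forget
        if drop > max_relative_drop:
            flagged[name] = round(drop, 4)
    return flagged


# Hypothetical scores for the base model and a mid-training checkpoint
base_scores = {"mmlu": 0.66, "arc": 0.55, "hellaswag": 0.80}
ckpt_scores = {"mmlu": 0.65, "arc": 0.52, "hellaswag": 0.79}
print(check_forgetting(base_scores, ckpt_scores))  # {'arc': 0.0545}
```

A non-empty result is the trigger for the corrective actions listed above, such as raising the share of general replay data.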

Checkpoint Selector

Evaluates saved checkpoints on both domain-specific and general benchmarks to identify the optimal stopping point. CPT training loss alone is not sufficient for checkpoint selection because a model can achieve low domain perplexity while having catastrophically forgotten general knowledge. The selector computes a composite score that balances domain improvement against general regression, enabling principled selection of the best checkpoint for downstream use.
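One way to build such a composite score is to reward relative domain-perplexity reduction, subtract a weighted penalty for general regression, and exclude checkpoints that exceed a hard regression limit. The weighting below (penalty of 2.0, 3% limit), the dataclass, and the evaluation numbers are hypothetical illustrations, not a standard recipe:

```python
from dataclasses import dataclass


@dataclass
class CheckpointEval:
    step: int
    domain_ppl: float     # held-out domain perplexity (lower is better)
    general_score: float  # mean general benchmark score (higher is better)


def select_checkpoint(evals: list[CheckpointEval], base_domain_ppl: float,
                      base_general_score: float, max_regression: float = 0.03,
                      penalty: float = 2.0) -> CheckpointEval:
    """Pick the best composite-score checkpoint among those within max_regression."""
    def regression(e: CheckpointEval) -> float:
        return max(0.0, (base_general_score - e.general_score) / base_general_score)

    def composite(e: CheckpointEval) -> float:
        domain_gain = (base_domain_ppl - e.domain_ppl) / base_domain_ppl  # relative ppl drop
        return domain_gain - penalty * regression(e)

    eligible = [e for e in evals if regression(e) <= max_regression]
    if not eligible:
        raise ValueError("every checkpoint exceeds the allowed general regression")
    return max(eligible, key=composite)


# Hypothetical evaluations (base model: domain ppl 12.0, general score 0.65)
history = [
    CheckpointEval(step=10_000, domain_ppl=9.5, general_score=0.645),
    CheckpointEval(step=30_000, domain_ppl=8.1, general_score=0.638),
    CheckpointEval(step=50_000, domain_ppl=7.6, general_score=0.610),
]
print(select_checkpoint(history, base_domain_ppl=12.0, base_general_score=0.65).step)  # 30000
```

In this toy history the final checkpoint has the lowest domain perplexity but regresses about 6% on general benchmarks, so the selector prefers step 30,000.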

Data Flow

Corpus Preparation Phase: Raw domain text is collected from domain-specific sources (academic papers, legal documents, clinical notes, code repositories, etc.). The text passes through deduplication (MinHash for near-duplicate removal), quality filtering (perplexity scoring against a reference model, heuristic rules for formatting), and optionally toxicity/PII filtering. The cleaned text is tokenized -- either with the original tokenizer or an adapted tokenizer with domain-specific vocabulary extensions.

Training Phase: The tokenized domain corpus is combined with a general replay corpus via the data mixer. Each training sequence is drawn from the domain corpus with probability $\alpha$, so batches contain domain and general tokens at roughly the configured ratio. The pretrained model processes these batches through the standard autoregressive forward pass, computing cross-entropy loss on next-token predictions. Gradients are computed, and the optimizer (typically AdamW) updates all model parameters using the learning rate from the cosine schedule. Gradient accumulation is used when the desired effective batch size is larger than what fits in GPU memory in a single pass. Checkpoints are saved periodically (every 1,000-5,000 steps).
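The link between batch settings, gradient accumulation, and the token budget is simple arithmetic worth checking before launch. A sketch assuming fully packed sequences (no padding); the helper names are hypothetical, and the example values match the Hugging Face configuration used in this guide (2 sequences per GPU, 32 accumulation steps, 8 GPUs, 2,048-token sequences):

```python
import math


def tokens_per_step(micro_batch: int, grad_accum: int, world_size: int, seq_len: int) -> int:
    """Tokens consumed per optimizer step, assuming fully packed sequences."""
    return micro_batch * grad_accum * world_size * seq_len


def steps_for_budget(token_budget: float, micro_batch: int, grad_accum: int,
                     world_size: int, seq_len: int) -> int:
    """Optimizer steps needed to consume a given token budget (rounded up)."""
    return math.ceil(token_budget / tokens_per_step(micro_batch, grad_accum, world_size, seq_len))


print(tokens_per_step(2, 32, 8, 2048))         # 1048576
print(steps_for_budget(50e9, 2, 32, 8, 2048))  # 47684
```

About 1M tokens per step means roughly 48k optimizer steps for a 50B-token run; if sequences are padded rather than packed, the real token count per step is lower.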

Validation Phase: At each checkpoint, the forgetting monitor evaluates domain perplexity (lower is better) and general benchmark scores (should remain stable). The checkpoint selector aggregates these metrics to rank checkpoints. The best checkpoint -- typically from the middle-to-late phase of training -- is selected as the final domain-adapted model. This model then proceeds to fine-tuning (SFT, LoRA, RLHF) for task-specific adaptation.

Figure: CPT pipeline flowchart with four stages: (1) Data Preparation -- raw domain corpus flows through deduplication, quality filtering, and tokenization. (2) Data Mixing -- tokenized domain data is combined with general replay data at a configurable ratio. (3) CPT Training Loop -- the pretrained base model is trained on the mixed corpus with a cosine-decaying learning rate, producing periodic checkpoints. (4) Validation -- each checkpoint is evaluated on domain perplexity and general benchmarks, with the best checkpoint selected as the final domain-adapted model.

How to Implement

Two Primary Implementation Approaches

There are two main ways to implement continued pretraining in practice:

Approach 1: Hugging Face Transformers + DeepSpeed/FSDP -- The most accessible option for teams already using the Hugging Face ecosystem. You load a pretrained model, configure a Trainer with a domain dataset and appropriate hyperparameters, and leverage DeepSpeed ZeRO or PyTorch FSDP for multi-GPU training. This approach works well for models up to 13B parameters on 2-8 GPUs.

Approach 2: Megatron-LM / nanotron -- For larger models (30B+) or when you need maximum training throughput, dedicated pretraining frameworks offer optimized kernels, tensor parallelism, and pipeline parallelism. NVIDIA's Megatron-LM is the industry standard for high-performance pretraining, while Hugging Face's nanotron offers a more minimal, Pythonic alternative for the same kind of large-scale parallel training.

The critical implementation decisions for CPT are:

  1. Data mixing: How to combine domain and general data, and at what ratio
  2. Learning rate: Starting LR, warmup strategy, and decay schedule
  3. Tokenizer: Whether to adapt the tokenizer for domain vocabulary
  4. Evaluation: How to monitor for catastrophic forgetting during training

Cost Note: Continued pretraining of Llama 3 8B on 50B tokens takes approximately 2,200 GPU-hours on A100 80GB. At cloud rates, that's ~USD 4,400 (~INR 3.7 lakh) on AWS or ~USD 3,300 (~INR 2.8 lakh) on Indian cloud providers like E2E Networks. For comparison, the initial pretraining of Llama 3 8B on 15T tokens cost an estimated ~USD 2M. Because compute scales with token count, CPT on 50B tokens needs only about 0.3% (50B / 15T) of the original pretraining compute for meaningful domain adaptation.

Continued Pretraining with Hugging Face Transformers + DeepSpeed
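import os
from itertools import chain

from datasets import Features, Sequence, Value, interleave_datasets, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

model_name = "meta-llama/Llama-3.1-8B"
block_size = 2048  # Packed sequence length (tokens per training example)

# Training arguments for CPT. With DeepSpeed ZeRO-3, create TrainingArguments
# BEFORE from_pretrained() so weights are partitioned across GPUs at load time.
training_args = TrainingArguments(
    output_dir="./llama3-8b-medical-cpt",
    max_steps=50000,                # 50k steps x ~1M tokens/step ~= 50B tokens
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32, # Effective batch = 2 * 32 * 8 GPUs = 512 sequences (~1M tokens)
    learning_rate=3e-5,             # ~10x lower than Llama 3 pretraining LR (~3e-4)
    lr_scheduler_type="cosine",
    warmup_steps=500,               # Brief re-warming (1% of total steps)
    weight_decay=0.1,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=2500,
    eval_strategy="steps",          # Requires an eval_dataset (held-out domain split below)
    eval_steps=2500,
    dataloader_num_workers=4,
    deepspeed="ds_config_zero3.json",  # DeepSpeed ZeRO Stage 3
)

# Load base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama 3 has no pad token by default

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # Requires the flash-attn package
)

# Tokenize without truncation and append EOS so document boundaries survive packing
def tokenize_fn(examples):
    out = tokenizer(examples["text"])
    out["input_ids"] = [ids + [tokenizer.eos_token_id] for ids in out["input_ids"]]
    out["attention_mask"] = [mask + [1] for mask in out["attention_mask"]]
    return out

# Pack documents into fixed-length blocks so no compute is wasted on padding
def group_texts(examples):
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in ("input_ids", "attention_mask")}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [v[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, v in concatenated.items()
    }
    result["labels"] = [ids.copy() for ids in result["input_ids"]]  # Causal LM: labels = inputs
    return result

# interleave_datasets requires identical features across sources
packed_features = Features({
    "input_ids": Sequence(Value("int32")),
    "attention_mask": Sequence(Value("int8")),
    "labels": Sequence(Value("int32")),
})

# Load and prepare domain corpus (map-style, so tokenization can use all CPU cores)
domain_dataset = load_dataset(
    "json",
    data_files="/data/medical_corpus_cleaned/*.jsonl",
    split="train",
)
domain_packed = (
    domain_dataset
    .map(tokenize_fn, batched=True, remove_columns=domain_dataset.column_names, num_proc=os.cpu_count())
    .map(group_texts, batched=True, num_proc=os.cpu_count())
    .cast(packed_features)
)
# Hold out a small domain split for perplexity evaluation
domain_splits = domain_packed.train_test_split(test_size=0.001, seed=42)

# Load general replay data (for catastrophic forgetting mitigation), streamed
general_packed = (
    load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)
    .select_columns(["text"])
    .map(tokenize_fn, batched=True, remove_columns=["text"])
    .map(group_texts, batched=True)
    .cast(packed_features)
)

# Interleave domain (70%) and general (30%) data; both inputs must be iterable datasets
mixed_dataset = interleave_datasets(
    [domain_splits["train"].to_iterable_dataset(num_shards=64), general_packed],
    probabilities=[0.7, 0.3],  # 70% domain, 30% general replay (sampled per sequence)
    seed=42,
)

# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=mixed_dataset,
    eval_dataset=domain_splits["test"],
    data_collator=default_data_collator,  # Sequences are already packed and labeled
    processing_class=tokenizer,           # Use tokenizer= on transformers < 4.46
)
trainer.train()

# Save the domain-adapted model
trainer.save_model("./llama3-8b-medical-cpt-final")
tokenizer.save_pretrained("./llama3-8b-medical-cpt-final")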

This is a starting-point recipe for continued pretraining using Hugging Face Transformers with DeepSpeed. Key decisions:

  • learning_rate=3e-5: Llama 3 was pretrained with a peak LR of ~3e-4. We use 10x lower (3e-5) for CPT to avoid destructive updates to the pretrained representations.
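  • 70/30 domain/general mix: interleave_datasets with probabilities=[0.7, 0.3] draws each example from the domain corpus with probability 0.7 and from general replay data with probability 0.3, so on average about 70% of each batch is domain data and 30% is general replay that helps prevent catastrophic forgetting.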
  • warmup_steps=500: Brief re-warming (1% of total steps) helps the optimizer adapt to the new data distribution without large initial loss spikes.
  • DeepSpeed ZeRO Stage 3: Shards model parameters, gradients, and optimizer states across GPUs, enabling 8B model CPT on 8x A100 80GB GPUs with reasonable batch sizes.
  • gradient_checkpointing=True: Reduces activation memory at the cost of ~30% slower training -- essential for fitting long sequences.
  • flash_attention_2: Uses Flash Attention for memory-efficient and faster attention computation.
Domain Corpus Curation Pipeline
import hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
import json

from datasketch import MinHash, MinHashLSH
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np


@dataclass
class Document:
    text: str
    source: str
    language: str
    quality_score: float = 0.0


class CorpusCurator:
    """Pipeline for curating domain corpus for continued pretraining."""

    def __init__(
        self,
        reference_model_name: str = "gpt2",  # Small model for perplexity scoring
        quality_threshold: float = 50.0,       # Max perplexity for quality filter
        min_doc_length: int = 100,             # Minimum tokens per document
        max_doc_length: int = 100000,          # Maximum tokens per document
        dedup_threshold: float = 0.8,          # MinHash Jaccard similarity threshold
    ):
        self.quality_threshold = quality_threshold
        self.min_doc_length = min_doc_length
        self.max_doc_length = max_doc_length
        self.dedup_threshold = dedup_threshold

        # Load reference model for perplexity scoring
        self.tokenizer = AutoTokenizer.from_pretrained(reference_model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            reference_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

        # MinHash LSH for deduplication
        self.lsh = MinHashLSH(threshold=dedup_threshold, num_perm=128)
        self.seen_hashes = set()

    def compute_perplexity(self, text: str, max_length: int = 512) -> float:
        """Compute perplexity of text using reference model."""
        encodings = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=max_length
        )
        input_ids = encodings.input_ids.to(self.model.device)

        with torch.no_grad():
            outputs = self.model(input_ids, labels=input_ids)
            loss = outputs.loss.item()

        return np.exp(loss)

    def compute_minhash(self, text: str) -> MinHash:
        """Compute MinHash signature for near-duplicate detection."""
        mh = MinHash(num_perm=128)
        # Use 5-gram shingles
        words = text.lower().split()
        for i in range(len(words) - 4):
            shingle = " ".join(words[i : i + 5])
            mh.update(shingle.encode("utf-8"))
        return mh

    def is_duplicate(self, doc_id: str, minhash: MinHash) -> bool:
        """Check if document is a near-duplicate of any seen document."""
        result = self.lsh.query(minhash)
        if result:
            return True
        self.lsh.insert(doc_id, minhash)
        return False

    def curate(self, documents: Iterator[Document]) -> Iterator[Document]:
        """Full curation pipeline: length filter -> dedup -> quality filter."""
        stats = {"total": 0, "length_filtered": 0, "deduped": 0, "quality_filtered": 0, "kept": 0}

        for doc in documents:
            stats["total"] += 1

            # Step 1: Length filter
            token_count = len(self.tokenizer.encode(doc.text))
            if token_count < self.min_doc_length or token_count > self.max_doc_length:
                stats["length_filtered"] += 1
                continue

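            # Step 2a: Exact-duplicate removal (cheap full-text hash)
            text_hash = hashlib.sha1(doc.text.encode("utf-8")).hexdigest()
            if text_hash in self.seen_hashes:
                stats["deduped"] += 1
                continue
            self.seen_hashes.add(text_hash)

            # Step 2b: Near-duplicate removal via MinHash LSH.
            # Key by a running index: MinHashLSH.insert raises ValueError on a repeated
            # key, which a hash of only the first 1,000 characters could produce.
            doc_id = f"doc-{stats['total']}"
            minhash = self.compute_minhash(doc.text)
            if self.is_duplicate(doc_id, minhash):
                stats["deduped"] += 1
                continue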

            # Step 3: Quality filtering via perplexity
            perplexity = self.compute_perplexity(doc.text)
            if perplexity > self.quality_threshold:
                stats["quality_filtered"] += 1
                continue

            doc.quality_score = 1.0 / perplexity  # Higher is better
            stats["kept"] += 1
            yield doc

        print(f"Curation stats: {json.dumps(stats, indent=2)}")


# Usage
curator = CorpusCurator(
    reference_model_name="gpt2",
    quality_threshold=80.0,
    dedup_threshold=0.8,
)

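# Example: curate medical documents (one document per line in the raw file)
with open("/data/raw_medical_text.txt", encoding="utf-8") as raw_file, \
        open("/data/medical_corpus_curated.jsonl", "w", encoding="utf-8") as f:
    raw_docs = (
        Document(text=line.strip(), source="pubmed", language="en")
        for line in raw_file
        if line.strip()
    )
    for doc in curator.curate(raw_docs):
        # ensure_ascii=False keeps non-Latin scripts (e.g. Devanagari) readable in the output
        f.write(json.dumps({"text": doc.text, "source": doc.source}, ensure_ascii=False) + "\n")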

This corpus curation pipeline implements the three critical stages of data preparation for CPT:

  1. Length filtering: Removes documents that are too short (likely noise) or too long (likely data dumps). The thresholds (100-100K tokens) should be adjusted per domain.

  2. Near-duplicate removal via MinHash LSH: Uses locality-sensitive hashing to efficiently detect near-duplicate documents. This is essential because web-scraped domain corpora often contain many copies of the same content (e.g., the same medical guideline published on 50 different websites). Deduplication can remove 20-40% of a raw corpus while improving training efficiency.

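  3. Perplexity-based quality filtering: Uses a small reference model (GPT-2) to score document quality. Documents with very high perplexity are likely noisy (OCR errors, markup debris, garbled spam) or not genuine domain text. This is similar in spirit to CCNet, which filters web text by its perplexity under a KenLM model trained on Wikipedia. Because a general-purpose reference model can also assign high perplexity to legitimate, jargon-heavy domain text, calibrate the threshold on a hand-checked sample.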

The pipeline is designed for streaming (Iterator-based) to handle corpora that don't fit in memory; the exception is the MinHash LSH index, which grows with the number of unique documents seen.
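To see why MinHash signatures catch near-duplicates, here is a minimal, dependency-free sketch of the underlying idea (not the datasketch implementation the curator relies on). It compares word 3-gram shingle sets and estimates their Jaccard similarity from 128 salted `hashlib.blake2b` hashes, which stand in for random permutations; the shingle size, signature length, and example sentences are illustrative assumptions.

```python
import hashlib


def shingles(text: str, size: int = 3) -> set[str]:
    """Return the set of overlapping word n-grams ("shingles") in a document."""
    words = text.lower().split()
    if len(words) < size:
        return {" ".join(words)}
    return {" ".join(words[i : i + size]) for i in range(len(words) - size + 1)}


def minhash_signature(items: set[str], num_perm: int = 128) -> list[int]:
    """Keep the minimum hash per salted hash function; each salt acts like one random permutation."""
    signature = []
    for seed in range(num_perm):
        salt = seed.to_bytes(8, "little")
        signature.append(
            min(
                int.from_bytes(
                    hashlib.blake2b(item.encode("utf-8"), digest_size=8, salt=salt).digest(),
                    "little",
                )
                for item in items
            )
        )
    return signature


def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """The fraction of positions where the minima agree estimates the Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def exact_jaccard(a: set[str], b: set[str]) -> float:
    return len(a & b) / len(a | b)


doc_a = "Patients with atrial fibrillation should be assessed for stroke risk before starting anticoagulation therapy"
doc_b = "Patients with atrial fibrillation should be assessed for stroke risk before starting oral anticoagulation therapy"
doc_c = "The court held that the contract was void for lack of consideration"

sh_a, sh_b, sh_c = shingles(doc_a), shingles(doc_b), shingles(doc_c)
sig_a, sig_b, sig_c = (minhash_signature(s) for s in (sh_a, sh_b, sh_c))

print(exact_jaccard(sh_a, sh_b))        # 10 shared / 15 distinct shingles, about 0.667
print(estimated_jaccard(sig_a, sig_b))  # should land near 0.667
print(estimated_jaccard(sig_a, sig_c))  # unrelated documents: near 0
```

The two anticoagulation sentences differ by a single word yet share only two-thirds of their shingles, so with the `dedup_threshold=0.8` used above they would not be flagged. Short documents are very sensitive to small edits, which is one reason the threshold should be tuned on real samples.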

Tokenizer Vocabulary Extension for Domain CPT
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import Counter
import torch
import json


def analyze_tokenizer_efficiency(tokenizer, domain_texts: list[str]) -> dict:
    """Measure how efficiently the tokenizer handles domain text."""
    total_chars = 0
    total_words = 0
    total_tokens = 0
    oov_fragments = Counter()  # Track frequently split words

    for text in domain_texts:
        total_chars += len(text)
        words = text.split()
        total_words += len(words)
        # Exclude BOS/EOS so special tokens don't inflate the counts
        total_tokens += len(tokenizer.encode(text, add_special_tokens=False))

        # Identify words that get split into many subwords
        # (tokenized in isolation, without a leading space, so counts are approximate)
        for word in words:
            subtokens = tokenizer.encode(word, add_special_tokens=False)
            if len(subtokens) > 3:  # Word split into 4+ tokens
                oov_fragments[word] += 1

    return {
        "chars_per_token": total_chars / total_tokens,
        "fertility": total_tokens / total_words,  # Tokens per whitespace-delimited word
        "top_fragmented_words": oov_fragments.most_common(50),
    }


def extend_tokenizer(
    base_tokenizer_name: str,
    domain_texts: list[str],
    new_vocab_size: int = 1000,
    min_count: int = 10,
    output_dir: str = "./extended_tokenizer",
) -> tuple[AutoTokenizer, int]:
    """Add frequent, heavily fragmented domain words to the tokenizer as whole tokens.

    Returns the extended tokenizer and the number of tokens actually added.
    """
    base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)

    # Count words that the base tokenizer splits into 3+ subword pieces
    word_freq = Counter()
    for text in domain_texts:
        for raw_word in text.split():
            word = raw_word.strip(".,;:!?()[]{}\"'")  # Drop attached punctuation
            if not word:
                continue
            subtokens = base_tokenizer.encode(word, add_special_tokens=False)
            if len(subtokens) >= 3:  # Only consider heavily fragmented words
                word_freq[word] += 1

    # Keep the most frequent candidates that occur at least min_count times
    new_tokens = [word for word, count in word_freq.most_common(new_vocab_size) if count >= min_count]
    num_added = base_tokenizer.add_tokens(new_tokens)  # Tokens already in the vocab are skipped
    print(f"Added {num_added} domain-specific tokens to vocabulary")

    base_tokenizer.save_pretrained(output_dir)
    return base_tokenizer, num_added


def resize_model_embeddings(
    model: AutoModelForCausalLM,
    new_vocab_size: int,
    init_strategy: str = "mean",  # 'mean' or 'subword'
    base_tokenizer=None,  # Original (un-extended) tokenizer; required for 'subword'
    new_tokens=None,  # List of added token strings in ID order; required for 'subword'
) -> AutoModelForCausalLM:
    """Resize model embeddings and initialize new token embeddings."""
    # Snapshot existing rows first; input and output matrices may be untied (as in Llama 3.1 8B)
    old_input = model.get_input_embeddings().weight.detach().clone()
    old_output = model.get_output_embeddings().weight.detach().clone()
    old_vocab_size = old_input.shape[0]

    # Resize the input embeddings and the lm_head together
    model.resize_token_embeddings(new_vocab_size)
    input_weight = model.get_input_embeddings().weight
    output_weight = model.get_output_embeddings().weight

    with torch.no_grad():
        if init_strategy == "mean":
            # Each new row starts at the mean of the existing rows of the same matrix
            input_weight[old_vocab_size:] = old_input.mean(dim=0)
            output_weight[old_vocab_size:] = old_output.mean(dim=0)
        elif init_strategy == "subword":
            # Each new token starts at the mean of the pieces the original tokenizer split it into
            if base_tokenizer is None or new_tokens is None:
                raise ValueError("'subword' init requires base_tokenizer and new_tokens")
            for offset, token in enumerate(new_tokens):
                piece_ids = base_tokenizer.encode(token, add_special_tokens=False)
                input_weight[old_vocab_size + offset] = old_input[piece_ids].mean(dim=0)
                output_weight[old_vocab_size + offset] = old_output[piece_ids].mean(dim=0)
        else:
            raise ValueError(f"Unknown init_strategy: {init_strategy!r}")

    print(f"Resized embeddings: {old_vocab_size} -> {new_vocab_size}")
    return model


# Example usage
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# Load domain texts
with open("/data/medical_sample.txt") as f:
    medical_texts = f.readlines()

# Analyze current tokenizer efficiency
stats = analyze_tokenizer_efficiency(tokenizer, medical_texts)
print(f"Current chars/token: {stats['chars_per_token']:.2f}")
print(f"Top fragmented words: {stats['top_fragmented_words'][:10]}")

# Extend tokenizer with medical vocabulary
extended_tokenizer, n_new = extend_tokenizer(
    "meta-llama/Llama-3.1-8B",
    medical_texts,
    new_vocab_size=500,
)

# Resize model embeddings
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)
model = resize_model_embeddings(
    model,
    len(extended_tokenizer),
    init_strategy="mean",
)

Tokenizer adaptation is often overlooked in CPT but can significantly impact both training efficiency and model quality. This code demonstrates three critical operations:

  1. Efficiency analysis: Measures how well the base tokenizer handles domain text. A general English tokenizer applied to Hindi text might produce 3-4 tokens per word (fertility > 2.0), compared to 1.2-1.5 for English. High fertility means the model needs more tokens to process the same information, increasing compute cost and reducing effective context length.

  2. Vocabulary extension: Identifies frequently fragmented domain terms (e.g., medical terms like "electrocardiography" or Hindi words in Latin script) and adds them as single tokens. Adding 500-1000 domain tokens can reduce token count by 10-20%, directly reducing training cost.

  3. Embedding initialization: New token embeddings must be initialized carefully. Mean initialization (averaging all existing embeddings) is the safest default, and the model will learn proper embeddings during CPT. Avoid plain random initialization: randomly initialized rows, especially in the output layer, can shift the model's predicted distribution unpredictably and trigger large gradient spikes that destabilize early training.

Configuration Example
# Continued pretraining configuration (YAML format)
model:
  name: meta-llama/Llama-3.1-8B
  dtype: bfloat16
  attention: flash_attention_2

data:
  domain_corpus: /data/medical_corpus_curated/
  general_replay: HuggingFaceFW/fineweb-edu
  mixing_ratio: 0.7  # 70% domain, 30% general
  max_seq_length: 4096
  num_workers: 8

tokenizer:
  extend_vocab: false  # Set true if domain needs new tokens
  new_tokens: 0

training:
  total_tokens: 50_000_000_000  # 50B tokens
  per_device_batch_size: 2
  gradient_accumulation_steps: 32
  effective_batch_size: 512  # per_device * grad_accum * num_gpus
  learning_rate: 3.0e-5  # 10x lower than pretraining LR; the ".0" stops PyYAML from reading it as a string
  min_learning_rate: 3.0e-6
  lr_scheduler: cosine
  warmup_steps: 500
  weight_decay: 0.1
  gradient_checkpointing: true
  bf16: true

distributed:
  backend: deepspeed
  zero_stage: 3
  num_gpus: 8

evaluation:
  eval_steps: 2500
  domain_metrics:
    - domain_perplexity
    - medical_qa_accuracy
  general_metrics:
    - mmlu_5shot
    - hellaswag
    - arc_challenge
  forgetting_threshold: 0.03  # Alert if general drops > 3%

checkpointing:
  save_steps: 2500
  keep_last_n: 5
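The batch and token settings in this config determine how many optimizer steps the run takes, which in turn fixes what `warmup_steps`, `eval_steps`, and `save_steps` mean in practice. A quick sketch of that arithmetic, assuming every sequence is packed to the full `max_seq_length` with no padding (the helper name is hypothetical):

```python
import math


def training_plan(
    total_tokens: int,
    per_device_batch_size: int,
    grad_accum_steps: int,
    num_gpus: int,
    max_seq_length: int,
    eval_steps: int,
) -> dict[str, int]:
    """Derive optimizer-step counts from a token budget, assuming fully packed sequences."""
    sequences_per_step = per_device_batch_size * grad_accum_steps * num_gpus
    tokens_per_step = sequences_per_step * max_seq_length
    total_steps = math.ceil(total_tokens / tokens_per_step)
    return {
        "effective_batch_size": sequences_per_step,
        "tokens_per_step": tokens_per_step,
        "total_steps": total_steps,
        "num_evals": total_steps // eval_steps,
    }


plan = training_plan(
    total_tokens=50_000_000_000,
    per_device_batch_size=2,
    grad_accum_steps=32,
    num_gpus=8,
    max_seq_length=4096,
    eval_steps=2500,
)
print(plan)
# {'effective_batch_size': 512, 'tokens_per_step': 2097152, 'total_steps': 23842, 'num_evals': 9}
```

At about 2.1M tokens per step, the 500 warmup steps cover roughly 2% of the run and `eval_steps: 2500` yields nine evaluations; padding or shorter sequences would raise the step count.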

Common Implementation Mistakes

  • Using the initial pretraining learning rate: The most common and destructive mistake. If Llama 3 was pretrained with LR=3e-4, using the same LR for CPT will catastrophically overwrite the pretrained representations. Use 10-100x lower LR (3e-6 to 3e-5) for CPT. When in doubt, start lower and increase.

  • Training on pure domain data without replay: Without mixing in 20-30% general data, the model will catastrophically forget its general capabilities within a few thousand steps. The model becomes a domain savant that can't perform basic tasks. Always include a general replay buffer.

  • Insufficient deduplication of the domain corpus: Web-scraped domain corpora are notoriously duplicated. Training on a corpus where 30% of documents are near-duplicates wastes compute and can cause the model to memorize specific passages rather than learn generalizable domain knowledge. Always run MinHash deduplication before CPT.

  • Extending the tokenizer without sufficient CPT: If you add 1,000 new tokens to the vocabulary, those tokens start with random or mean-initialized embeddings. You need enough continued pretraining (at minimum 1B tokens) for the model to learn proper representations for the new tokens. Adding tokens and then only fine-tuning on 50K examples will produce garbage for those tokens.

  • Ignoring evaluation during training: CPT without periodic evaluation on general benchmarks is flying blind. By the time you notice catastrophic forgetting post-training, you've wasted all the compute. Evaluate every 2,000-5,000 steps on both domain and general metrics.

  • Using too small a corpus: CPT needs substantial data to be effective. For an 8B model, a minimum of 5-10B domain tokens is recommended. With less than 1B tokens, you're better off using LoRA fine-tuning on domain instruction data. CPT with insufficient data overfits to the small corpus and provides minimal benefit.

When Should You Use This?

Use When

  • You need the model to deeply understand domain-specific terminology and concepts -- not just pattern-match on keywords but reason about domain relationships (e.g., understanding that drug interactions affect dosage recommendations in medical contexts)

  • Your domain has a large unlabeled corpus (10B+ tokens) that is substantially different from general web text -- legal documents, scientific papers, financial filings, code in specific frameworks, multilingual text

  • You need better performance than RAG for tasks requiring deep domain reasoning, where the context window is insufficient to include all relevant knowledge at inference time

  • You're building a foundation model for a specific vertical that will be fine-tuned for multiple downstream tasks within that domain (e.g., a medical foundation model for diagnosis, drug interaction, clinical note summarization)

  • The target domain has specialized vocabulary that the base tokenizer handles poorly -- Indian languages, chemical formulas, legal citations, code in niche programming languages

  • You have the compute budget for training on billions of tokens -- CPT is not cheap, but it's 100-1000x cheaper than pretraining from scratch

  • Your task requires the model to generate domain-fluent text rather than just answer questions -- domain fluency requires the kind of deep representation alignment that only CPT provides

Avoid When

  • Your domain corpus is small (<1B tokens) -- LoRA fine-tuning or RAG will be more cost-effective and less risky than CPT on insufficient data

  • You need factual recall of specific documents rather than general domain understanding -- RAG is better suited for retrieving and referencing specific passages from a corpus

  • Your budget is tightly constrained and you need results quickly -- CPT requires significant compute (thousands of GPU-hours) and careful hyperparameter tuning, while RAG or LoRA can be deployed in hours

  • The domain is rapidly changing (e.g., current news, stock prices) -- CPT bakes knowledge into weights as a static snapshot, while RAG can incorporate real-time information

  • The base model already performs well on your domain tasks -- test the base model first; if it achieves 90%+ of target performance with good prompting, CPT may offer diminishing returns

  • You lack infrastructure for distributed training -- CPT of 7B+ models requires multi-GPU setups with FSDP or DeepSpeed, which adds operational complexity

Key Tradeoffs

The CPT Investment Calculus

Continued pretraining is the most compute-intensive model adaptation technique short of pretraining from scratch. The core tradeoff is upfront investment vs. downstream returns.

Approach                 | Compute Cost | Knowledge Depth           | Forgetting Risk | Time to Deploy | Best For
Prompting                | ~$0          | Surface level             | None            | Hours          | Quick prototyping
RAG                      | ~$100-1K     | Access to documents       | None            | Days           | Factual recall, dynamic knowledge
LoRA Fine-tuning         | ~$50-500     | Task-specific behavior    | Low             | Days           | Instruction following, formatting
CPT                      | ~$2K-50K+    | Deep domain understanding | Medium          | Weeks          | Domain foundation models
Pretraining from scratch | ~$100K-10M+  | Full control              | N/A             | Months         | Novel architectures, unique data

Data Quality vs. Quantity

A critical tradeoff in CPT is data quality versus quantity. BloombergGPT trained on 363B financial tokens -- a massive corpus. But research shows that curated smaller corpora can outperform raw larger ones. The AWS team found that training on just 10% of a financial corpus (selected via domain-adaptive data selection) outperformed training on the full corpus.

The practical implication: invest more time in corpus curation (deduplication, quality filtering, balanced sampling) rather than simply collecting more data.

Forgetting vs. Adaptation

Every CPT run involves a tension between adapting to the new domain and retaining general capabilities. The mixing ratio α (the share of the training mix drawn from the domain corpus) is the primary control:

Mixing Ratio (domain %) | Domain Gain | General Forgetting | Use Case
50%                     | Moderate    | Very low           | Conservative adaptation
70%                     | High        | Low-moderate       | Standard recommendation
90%                     | Very high   | Moderate-high      | Deep domain specialization
100%                    | Maximum     | High               | Domain-only model (rare)

Practitioner's Note: If you're unsure about the right mixing ratio, start with 70% domain / 30% general. This has consistently produced good results across medical, legal, financial, and code domains. Only go above 80% domain if you've validated that general capability retention meets your requirements.
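One simple way to apply α is to draw each training document from the domain stream with probability α and from the general replay stream otherwise. This is a minimal sketch with plain Python iterators standing in for the two corpora; `mix_streams` is a hypothetical helper, not part of any training framework.

```python
import itertools
import random
from typing import Iterator, TypeVar

T = TypeVar("T")


def mix_streams(domain: Iterator[T], general: Iterator[T], alpha: float, seed: int = 0) -> Iterator[T]:
    """Yield from `domain` with probability alpha and from `general` otherwise.

    Stops when the chosen stream runs out, rather than silently drifting toward
    whichever source still has data.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    rng = random.Random(seed)
    while True:
        source = domain if rng.random() < alpha else general
        try:
            item = next(source)
        except StopIteration:
            return
        yield item


mixed = mix_streams(itertools.repeat("domain"), itertools.repeat("general"), alpha=0.7)
sample = list(itertools.islice(mixed, 10_000))
domain_fraction = sample.count("domain") / len(sample)
print(f"domain fraction: {domain_fraction:.3f}")  # close to 0.700
```

Because the ratio is applied per document, the token-level share of domain text will differ from α if domain and general documents have different average lengths; weighting by length, or mixing already-packed sequences, avoids that skew.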

Alternatives & Comparisons

Domain adaptation at the feature level (e.g., adversarial domain adaptation, domain-invariant representations) aligns source- and target-domain representations, typically by training a feature extractor against a domain discriminator, rather than continuing the self-supervised pretraining objective on a large domain corpus. It's faster but provides shallower adaptation than CPT. Choose feature-level domain adaptation for quick deployment with pretrained embeddings; choose CPT when you need the model to fundamentally understand domain concepts and generate domain-fluent text.

LoRA fine-tuning adapts the model to specific tasks by training small low-rank matrices, typically on labeled instruction data. It changes model behavior but not its fundamental knowledge. LoRA is 10-100x cheaper than CPT and ideal for task-specific adaptation. Choose LoRA when the base model already has sufficient domain knowledge; choose CPT when the model needs to learn new domain concepts before any task-specific training.

Full fine-tuning updates all model parameters on task-specific data. While it can inject some domain knowledge, it's typically done on smaller labeled datasets (10K-500K examples) and uses higher learning rates than CPT. CPT uses billions of unlabeled tokens with the pretraining objective, providing broader knowledge injection. Choose full fine-tuning for task-specific performance with labeled data; choose CPT for broad domain knowledge with unlabeled corpora.

Instruction tuning teaches the model to follow instructions in a specific format, typically using supervised fine-tuning on instruction-response pairs. It's a downstream step that usually follows CPT. The standard pipeline for domain LLMs is: base model -> CPT (domain knowledge) -> instruction tuning (task behavior). Choose instruction tuning alone when the base model has sufficient knowledge; use CPT + instruction tuning when domain knowledge needs enhancement.

Knowledge distillation transfers capabilities from a larger teacher model to a smaller student model. It can be used in conjunction with CPT -- for example, distilling a 70B domain-adapted model into a 7B model. Choose distillation when you need a smaller, faster model with domain expertise; choose CPT when you're building the initial domain-adapted model that will serve as the teacher.

Pros, Cons & Tradeoffs

Advantages

  • Deep domain knowledge injection: CPT fundamentally rewires the model's internal representations, encoding domain concepts as first-class knowledge rather than surface-level pattern matching. This produces models that reason about domain concepts more accurately and generate more fluent domain text.

  • Improved downstream task performance: Models that undergo CPT before fine-tuning typically outperform models that are fine-tuned directly on domain tasks. The DAPT paper (Gururangan et al., 2020) reported gains on most of its eight classification tasks across four domains, with the size of the improvement varying considerably by domain, and later work has reported similar benefits for LLMs.

  • Better tokenizer efficiency for domain text: Combined with tokenizer adaptation (which needs CPT so the new embeddings can be learned), the adapted model can represent domain text in 10-20% fewer tokens, effectively increasing the model's context window for domain content and reducing both training and inference costs.

  • Foundation for multiple downstream tasks: A single CPT run produces a domain-adapted base model that can be fine-tuned for many different tasks within the domain. This amortizes the CPT cost across all downstream applications -- a medical CPT model serves diagnosis, summarization, coding, and Q&A.

  • Reduced hallucination in domain contexts: Models with CPT-injected domain knowledge tend to hallucinate less on domain topics because more of the relevant information is encoded in the weights, reducing reliance on surface pattern-matching or retrieval.

  • Cost-effective compared to pretraining from scratch: CPT achieves comparable domain performance to training a domain-specific model from scratch (like BloombergGPT) at 0.1-1% of the compute cost, because it starts from a model that already has strong general capabilities.

  • Compatible with all downstream adaptation methods: The output of CPT is a standard pretrained model checkpoint that works with LoRA, QLoRA, full fine-tuning, instruction tuning, RLHF, DPO -- any adaptation technique in the standard toolkit.

Disadvantages

  • High compute cost: CPT requires training on billions of tokens, costing $2,000-50,000+ (~INR 1.7 lakh - 42 lakh) depending on model size and corpus size. This is significantly more expensive than LoRA ($50-500) or RAG (primarily infrastructure costs).

  • Risk of catastrophic forgetting: Without careful data mixing and learning rate tuning, CPT can degrade the model's general capabilities. A model that forgets basic arithmetic or loses its ability to follow instructions is worse than the original.

  • Requires large domain corpus: Effective CPT needs 5-100B+ tokens of high-quality domain text. Many specialized domains (rare diseases, niche engineering fields) simply don't have enough publicly available text. Corpus curation is labor-intensive.

  • Long iteration cycles: Each CPT experiment takes days to weeks, making hyperparameter tuning slow and expensive. If the learning rate is wrong or the data mix is bad, you discover this after spending thousands of dollars in compute.

  • Infrastructure complexity: CPT of 7B+ models requires multi-GPU distributed training with FSDP or DeepSpeed, which demands specialized engineering expertise and infrastructure. This is beyond the capability of many smaller teams.

  • Static knowledge snapshot: Knowledge injected via CPT is frozen at training time. Unlike RAG, CPT cannot incorporate new information without retraining. For rapidly evolving domains, this creates a staleness problem.

Failure Modes & Debugging

Catastrophic Forgetting

Cause

Training on pure domain data or using too high a learning rate, causing the model to overwrite pretrained representations that encode general knowledge. The model's loss on domain data decreases while its ability to perform general tasks (math, reasoning, instruction following) degrades severely.

Symptoms

The model produces fluent domain text but fails at basic tasks: it can't do simple arithmetic, loses its instruction-following ability, or generates incoherent responses to general questions. MMLU scores may drop by 10-20+ points. The model may also lose its ability to switch between domain and general contexts.

Mitigation

Always include 20-30% general data in the training mix as a replay buffer. Use a learning rate 10-100x lower than the original pretraining LR. Monitor general benchmarks (MMLU, HellaSwag, ARC) every 2,000-5,000 training steps. If general scores drop more than 3%, reduce the LR or increase the general data ratio. For additional protection, consider Elastic Weight Consolidation (EWC), which adds a penalty for moving weights that were important to the original model.
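The monitoring rule can be expressed as a check run at every evaluation step. This sketch reads "drop more than 3%" as a relative drop from the base model's score (an absolute drop in accuracy points is the other reasonable reading), and the scores in the usage are made-up placeholders, not measured results.

```python
def forgetting_alerts(
    baseline: dict[str, float],
    current: dict[str, float],
    threshold: float = 0.03,
) -> dict[str, float]:
    """Return {benchmark: relative_drop} for benchmarks that fell more than `threshold`."""
    alerts = {}
    for name, base_score in baseline.items():
        if base_score <= 0:
            raise ValueError(f"baseline score for {name} must be positive")
        relative_drop = (base_score - current[name]) / base_score
        if relative_drop > threshold:
            alerts[name] = relative_drop
    return alerts


# Hypothetical accuracies in [0, 1] for the base model and a CPT checkpoint
baseline_scores = {"mmlu_5shot": 0.66, "hellaswag": 0.82, "arc_challenge": 0.55}
checkpoint_scores = {"mmlu_5shot": 0.62, "hellaswag": 0.815, "arc_challenge": 0.55}

alerts = forgetting_alerts(baseline_scores, checkpoint_scores)
for name, drop in alerts.items():
    print(f"{name}: down {drop:.1%} from baseline -> lower the LR or raise general replay")
# mmlu_5shot: down 6.1% from baseline -> lower the LR or raise general replay
```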

Domain Corpus Contamination

Cause

The domain corpus contains significant amounts of noise: machine-translated text, OCR errors, boilerplate content, duplicated passages, or text from the wrong domain. The model learns these artifacts instead of genuine domain knowledge.

Symptoms

Domain perplexity decreases during training (the model is learning), but downstream task performance doesn't improve or even degrades. The model may generate text with OCR-like artifacts, unnatural repetition, or boilerplate phrases. Quality on clean domain test sets is poor despite good training metrics.

Mitigation

Invest heavily in corpus curation: run MinHash deduplication (remove 20-40% near-duplicates), perplexity-based quality filtering (remove text that's too noisy or too templated), and manual spot-checking of random samples. Use a staged approach: clean a small high-quality subset first, validate CPT works on it, then scale up.
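For the manual spot-checks, reservoir sampling (Algorithm R) draws a uniform random sample from a stream of unknown length in a single pass, which fits the streaming curation pipeline. A minimal sketch; the sample size and the document IDs in the usage are illustrative.

```python
import random
from typing import Iterable, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], k: int, seed: int = 0) -> list[T]:
    """Uniformly sample k items from a stream of unknown length (Algorithm R)."""
    rng = random.Random(seed)
    reservoir: list[T] = []
    for index, item in enumerate(stream):
        if index < k:
            reservoir.append(item)
        else:
            # Keep the new item with probability k / (index + 1), evicting a random slot
            slot = rng.randint(0, index)
            if slot < k:
                reservoir[slot] = item
    return reservoir


# Pick 5 of 1,000 hypothetical document IDs for manual review
picked = reservoir_sample((f"doc-{i}" for i in range(1000)), k=5, seed=42)
print(picked)
```

On the real corpus, the stream would be something like `(json.loads(line)["text"] for line in f)` over the curated JSONL file, so only `k` documents are ever held in memory.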

Tokenizer Mismatch Degradation

Cause

Using the original tokenizer on text with significantly different vocabulary (e.g., Hindi text with an English tokenizer, SMILES chemical notation, or dense medical abbreviations). The tokenizer fragments domain-specific terms into many subwords, reducing effective context length and making it harder for the model to learn domain-specific patterns.

Symptoms

Training is abnormally slow (more tokens needed to represent the same content). The model struggles with domain-specific terminology even after extensive CPT. Inference is expensive because domain text requires 2-3x more tokens than English text of equivalent content. Token-level perplexity may look good, but word-level or character-level perplexity is poor.

Mitigation

Analyze tokenizer efficiency before CPT: compute chars-per-token and fertility metrics on domain text. If fertility is > 2.0 (compared to ~1.3 for English), consider tokenizer adaptation: extend the vocabulary with 500-2000 domain-specific tokens and initialize their embeddings via mean pooling. Budget extra CPT tokens (1B+) for the model to learn the new token embeddings.
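Token-level perplexity depends on how many tokens the tokenizer emits, so it cannot be compared across tokenizers or before and after vocabulary extension. A common remedy is to renormalize the total negative log-likelihood by characters (bits per character) or by words (word-level perplexity). The sketch assumes you already have the mean per-token cross-entropy in nats, as most training loops report; the loss and length figures in the usage are invented to illustrate the effect.

```python
import math


def normalized_metrics(mean_token_loss: float, num_tokens: int, num_chars: int, num_words: int) -> dict[str, float]:
    """Convert a mean per-token cross-entropy (natural log) into tokenizer-independent metrics."""
    total_nll = mean_token_loss * num_tokens  # total nats over the whole evaluation text
    return {
        "token_perplexity": math.exp(mean_token_loss),
        "bits_per_char": total_nll / (num_chars * math.log(2)),
        "word_perplexity": math.exp(total_nll / num_words),
    }


# Same hypothetical 10,000-character, 2,000-word text under two tokenizers:
# a fragmenting tokenizer (5,000 tokens) and an adapted one (2,500 tokens).
fragmenting = normalized_metrics(mean_token_loss=1.2, num_tokens=5000, num_chars=10_000, num_words=2000)
adapted = normalized_metrics(mean_token_loss=2.2, num_tokens=2500, num_chars=10_000, num_words=2000)
print(fragmenting)  # lower token perplexity...
print(adapted)      # ...but lower word perplexity and bits per character
```

The fragmenting tokenizer looks better on token perplexity only because each of its tokens carries less information; per character and per word, the adapted tokenizer's model is the stronger one.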

Learning Rate Mismatch

Cause

Using a learning rate that's too high (causes forgetting) or too low (wastes compute without meaningful adaptation). This is particularly common when teams copy hyperparameters from fine-tuning recipes (which use higher LR) or from pretraining-from-scratch setups.

Symptoms

Too high LR: Training loss spikes initially and general benchmark scores drop rapidly within the first few thousand steps. The model may become incoherent on both domain and general text. Too low LR: Training loss decreases very slowly, domain benchmarks show minimal improvement even after many thousands of steps, and the model's domain capabilities remain similar to the base model.

Mitigation

For an 8B model, start with LR = 3e-5 (assuming pretraining peak was 3e-4). Run a short pilot (1,000 steps) and check: (1) training loss should decrease smoothly, (2) general benchmarks should not drop > 1%. If loss spikes, halve the LR. If loss barely moves after 1,000 steps, double the LR. Use cosine annealing that matches the pretraining schedule.
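A linear-warmup, cosine-decay schedule using the values from the configuration example (peak 3e-5, floor 3e-6, 500 warmup steps) can be written without any framework. The `total_steps` default assumes the 50B-token run with 512 sequences of 4,096 tokens per step; treat all defaults as placeholders for your own run.

```python
import math


def cpt_learning_rate(
    step: int,
    peak_lr: float = 3e-5,
    min_lr: float = 3e-6,
    warmup_steps: int = 500,
    total_steps: int = 23_842,  # 50B tokens / (512 x 4096 tokens per step), rounded up
) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 250, 500, 12_171, 23_842):
    print(step, f"{cpt_learning_rate(step):.2e}")
```

For the 1,000-step pilot, rerun with `peak_lr` halved or doubled depending on whether loss spikes or stalls, keeping the config's 10:1 ratio between peak and floor.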

Insufficient Training Duration

Cause

Stopping CPT too early, before the model has processed enough domain tokens to meaningfully shift its representations. This is common when compute budget is tight and teams try to do CPT with 1-2B tokens on a 7B+ model.

Symptoms

Domain perplexity decreases marginally but downstream task improvements are negligible compared to the base model. The model shows surface-level familiarity with domain terms but lacks deeper understanding -- it uses the right vocabulary but makes conceptual errors. The CPT appears to have been wasted.

Mitigation

As a rule of thumb, plan for at least 5-10B tokens for an 8B model and 20-50B tokens for a 70B model. If budget is constrained, consider using data selection methods (like the AWS ETS-DACP approach) to select the most informative 10% of the corpus, or use LoRA fine-tuning instead of CPT. Monitor the domain perplexity curve -- if it's still decreasing significantly, training hasn't saturated.
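Whether the domain perplexity curve is "still decreasing significantly" can be made concrete by comparing the latest evaluation to one a few evaluations earlier. In this sketch, the 3-evaluation window and 1% relative-improvement cutoff are arbitrary assumptions, and the perplexity histories are illustrative.

```python
def still_improving(perplexities: list[float], window: int = 3, min_rel_improvement: float = 0.01) -> bool:
    """True if domain perplexity fell by more than `min_rel_improvement` over the last `window` evals."""
    if len(perplexities) <= window:
        return True  # too few evaluations to judge; keep training
    earlier, latest = perplexities[-window - 1], perplexities[-1]
    return (earlier - latest) / earlier > min_rel_improvement


still_falling = [12.0, 9.5, 8.4, 7.9]
flattening = [12.0, 9.5, 8.4, 7.9, 7.5, 7.47, 7.45, 7.44]
print(still_improving(still_falling))  # True: 12.0 -> 7.9 over the last three evals
print(still_improving(flattening))     # False: 7.5 -> 7.44 is under 1%
```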

Data Mixing Ratio Imbalance

Cause

The domain/general data mixing ratio is poorly calibrated. Too much domain data (>90%) causes forgetting. Too much general data (<50% domain) wastes compute on data the model has already learned, producing minimal domain adaptation.

Symptoms

Too much domain data: General benchmark regression similar to catastrophic forgetting, but more gradual. The model becomes a domain specialist at the expense of general capability. Too much general data: Domain perplexity barely improves, downstream domain tasks show marginal gains, and most of the compute budget is spent re-learning general knowledge the model already knew.

Mitigation

Start with a 70/30 domain/general split. Validate with a pilot run of 5,000 steps, measuring both domain perplexity and general benchmarks. If general benchmarks drop > 3%, increase general replay to 40%. If domain perplexity barely moves, increase domain ratio to 80%. The DoReMi algorithm can automatically optimize these ratios, but requires first training a small reference model and then a small proxy model that learns the domain weights.
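The pilot-run rule above is simple enough to encode directly, which keeps successive pilots consistent. This sketch hard-codes the 3% general-drop trigger and the 60% and 80% domain ratios from the text; because "barely moves" is not quantified, the 2% minimum relative improvement in domain perplexity is an assumption. It is a heuristic and does not implement DoReMi.

```python
def adjust_domain_ratio(
    current_ratio: float,
    general_drop: float,
    domain_ppl_improvement: float,
    max_general_drop: float = 0.03,
    min_domain_improvement: float = 0.02,
) -> float:
    """Suggest the domain fraction for the next pilot run.

    general_drop: relative drop in general benchmarks vs. the base model (0.05 = 5%).
    domain_ppl_improvement: relative decrease in domain perplexity over the pilot.
    """
    if general_drop > max_general_drop:
        return min(current_ratio, 0.6)  # shift to 40% general replay
    if domain_ppl_improvement < min_domain_improvement:
        return max(current_ratio, 0.8)  # domain perplexity barely moved: go to 80% domain
    return current_ratio


print(adjust_domain_ratio(0.7, general_drop=0.05, domain_ppl_improvement=0.10))   # 0.6
print(adjust_domain_ratio(0.7, general_drop=0.01, domain_ppl_improvement=0.005))  # 0.8
print(adjust_domain_ratio(0.7, general_drop=0.01, domain_ppl_improvement=0.10))   # 0.7
```

Forgetting is checked first, so a pilot that both regresses on general benchmarks and stalls on domain perplexity is pushed toward more replay rather than more domain data.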

Placement in an ML System

Where CPT Fits in the Model Development Lifecycle

Continued pretraining occupies a specific position in the LLM development pipeline: after base model selection and before any task-specific adaptation. The standard pipeline for building a production domain LLM is:

  1. Base model selection: Choose a pretrained foundation model (e.g., Llama 3 8B, Mistral 7B) based on the target domain, model size, and licensing constraints.
  2. Corpus curation: Collect, clean, deduplicate, and filter domain-specific text into a high-quality training corpus.
  3. Continued pretraining: Train the base model on the domain corpus (mixed with general replay data) using the autoregressive language modeling objective.
  4. Instruction tuning / SFT: Fine-tune the CPT model on domain-specific instruction-response pairs to teach it task-specific behavior.
  5. Alignment (RLHF/DPO): Optionally align the model with human preferences for safety, helpfulness, and domain-appropriate responses.
  6. Evaluation and deployment: Benchmark on domain-specific and general evaluations, then deploy to production.

In organizations like Bloomberg, Google (Med-PaLM), or Indian AI companies like Sarvam AI, CPT is managed as a dedicated infrastructure workload with its own compute allocation, data pipelines, and evaluation frameworks. The CPT output -- a domain-adapted base model -- is treated as a shared asset that multiple downstream teams can fine-tune for specific applications.

Organizational Pattern: Large organizations often maintain a "model kitchen" where the CPT team produces domain-adapted base models on a regular cadence (e.g., quarterly), incorporating new domain data. Downstream teams consume these base models and apply task-specific fine-tuning. This separation of concerns between "knowledge injection" (CPT) and "behavior shaping" (fine-tuning) is a key architectural pattern for scalable domain AI development.

Pipeline Stage

Training / Domain Adaptation

Upstream

  • Base Model Selection (pretrained checkpoint from model hub)
  • Domain Corpus Curation (cleaned, deduplicated, filtered text)
  • Tokenizer (original or domain-adapted vocabulary)

Downstream

  • Instruction Tuning / SFT
  • LoRA or QLoRA Fine-tuning
  • RLHF / DPO Alignment
  • Model Evaluation & Benchmarking
  • Model Registry / Version Control

Scaling Bottlenecks

Training Compute

The primary bottleneck for CPT is raw training compute. For a 70B model on 100B tokens, training needs roughly 6 × 70B × 100B ≈ 4.2 × 10^22 FLOPs. At a realistic 40-50% utilization of an A100 80GB's 312 TFLOPS bf16 peak, that is about 75,000-95,000 GPU-hours, or roughly $150,000-190,000 (~INR 1.25-1.6 crore) at ~$2/GPU-hour. This makes CPT infeasible for most Indian startups unless they have access to subsidized GPU clusters (e.g., through government programs like India AI Mission or through academic partnerships).
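The GPU-hour figure in this paragraph can be sanity-checked with the widely used 6 × parameters × tokens approximation for training FLOPs. In this sketch, 312 TFLOPS is the published dense bf16 peak for the A100, while the 45% model FLOPs utilization and the $2/GPU-hour price are assumptions to replace with your own measurements and quotes.

```python
def cpt_compute_estimate(
    num_params: float,
    num_tokens: float,
    peak_flops_per_gpu: float = 312e12,  # A100 dense bf16 peak
    mfu: float = 0.45,  # assumed model FLOPs utilization
    usd_per_gpu_hour: float = 2.0,  # assumed cloud price
) -> dict[str, float]:
    """Estimate training compute with the common 6 * N * D FLOPs rule of thumb."""
    total_flops = 6 * num_params * num_tokens
    gpu_hours = total_flops / (peak_flops_per_gpu * mfu) / 3600
    return {"total_flops": total_flops, "gpu_hours": gpu_hours, "cost_usd": gpu_hours * usd_per_gpu_hour}


big = cpt_compute_estimate(num_params=70e9, num_tokens=100e9)
small = cpt_compute_estimate(num_params=8e9, num_tokens=50e9)
print(f"70B on 100B tokens: {big['gpu_hours']:,.0f} GPU-hours, ${big['cost_usd']:,.0f}")  # ~83,000 h, ~$166,000
print(f"8B on 50B tokens:  {small['gpu_hours']:,.0f} GPU-hours, ${small['cost_usd']:,.0f}")  # ~4,700 h, ~$9,500
```

Gradient checkpointing, which the configuration example enables, recomputes activations during the backward pass; that extra work is not counted in 6ND and shows up as lower effective utilization.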

Data Pipeline Throughput

With large-scale CPT, the data pipeline (tokenization, shuffling, batching) can become a bottleneck if not properly parallelized. A training throughput of 50,000 tokens/second per GPU requires the data pipeline to sustain ~400,000 tokens/second for an 8-GPU setup. Pre-tokenizing the corpus and storing it in a memory-mapped format (e.g., using datasets library's memory mapping) is essential.
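At 50,000 tokens/second per GPU, each GPU consumes about 12 sequences of 4,096 tokens every second, so tokenizing on the fly is easy to fall behind. The sketch below shows the pre-tokenize-then-memory-map pattern with only the standard library: documents are tokenized once, joined with an EOS id, packed into fixed-length sequences, written as raw 32-bit integers, and read back through `mmap`. The `tokenize` callable, EOS id, and file layout are assumptions; production pipelines usually rely on the `datasets` library's Arrow-backed memory mapping or `numpy.memmap` instead.

```python
import mmap
import os
import tempfile
from array import array
from typing import Callable, Iterable

TOKEN_TYPECODE = "I"  # unsigned 32-bit on mainstream platforms; Llama 3's 128K vocab overflows uint16
assert array(TOKEN_TYPECODE).itemsize == 4


def write_packed_tokens(
    texts: Iterable[str],
    tokenize: Callable[[str], list[int]],
    eos_id: int,
    seq_len: int,
    path: str,
) -> int:
    """Tokenize once, join documents with EOS, and write fixed-length sequences as raw uint32."""
    buffer = array(TOKEN_TYPECODE)
    num_sequences = 0
    with open(path, "wb") as out:
        for text in texts:
            buffer.extend(tokenize(text))
            buffer.append(eos_id)
            while len(buffer) >= seq_len:
                buffer[:seq_len].tofile(out)
                del buffer[:seq_len]
                num_sequences += 1
    return num_sequences  # a trailing partial sequence is dropped


class PackedTokenDataset:
    """Random access to packed sequences through a read-only memory map."""

    def __init__(self, path: str, seq_len: int):
        self.seq_len = seq_len
        self._bytes_per_seq = seq_len * array(TOKEN_TYPECODE).itemsize
        self._file = open(path, "rb")
        self._mm = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)

    def __len__(self) -> int:
        return len(self._mm) // self._bytes_per_seq

    def __getitem__(self, index: int) -> list[int]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self._bytes_per_seq
        chunk = array(TOKEN_TYPECODE)
        chunk.frombytes(self._mm[start : start + self._bytes_per_seq])  # copies one sequence only
        return chunk.tolist()

    def close(self) -> None:
        self._mm.close()
        self._file.close()


# Toy round trip with a hypothetical whitespace "tokenizer" that grows its vocabulary
vocab: dict[str, int] = {}


def toy_tokenize(text: str) -> list[int]:
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


with tempfile.TemporaryDirectory() as tmp_dir:
    token_path = os.path.join(tmp_dir, "tokens.bin")
    written = write_packed_tokens(["a b c", "d e", "f g h i"], toy_tokenize, eos_id=999, seq_len=4, path=token_path)
    dataset = PackedTokenDataset(token_path, seq_len=4)
    round_trip = [dataset[i] for i in range(len(dataset))]
    dataset.close()

print(written, round_trip)  # 3 [[0, 1, 2, 999], [3, 4, 999, 5], [6, 7, 8, 999]]
```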

Evaluation Overhead

Running comprehensive evaluation (MMLU, domain benchmarks, perplexity) at every checkpoint adds significant overhead. For a 70B model, a full MMLU evaluation takes 2-4 hours. Evaluate on a subset (e.g., 1,000 MMLU questions instead of 14,000) during training, with full evaluation only at key milestones.
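To keep a reduced MMLU subset representative, sample it proportionally per subject and fix the random seed so every checkpoint is scored on exactly the same questions. This sketch assumes each question is a dict with a `subject` key; that field name, and the toy pool in the usage, are assumptions about your evaluation data.

```python
import random
from collections import Counter, defaultdict


def stratified_subset(questions: list[dict], total: int, seed: int = 0) -> list[dict]:
    """Sample about `total` questions, proportionally per subject, reproducibly."""
    by_subject: dict[str, list[dict]] = defaultdict(list)
    for question in questions:
        by_subject[question["subject"]].append(question)
    rng = random.Random(seed)  # fixed seed -> every checkpoint sees the same subset
    subset = []
    for subject in sorted(by_subject):
        pool = by_subject[subject]
        share = max(1, round(total * len(pool) / len(questions)))  # at least one per subject
        subset.extend(rng.sample(pool, min(share, len(pool))))
    return subset


# Hypothetical pool: 400 questions spread evenly over 4 subjects
questions_pool = [{"subject": f"subject_{i % 4}", "id": i} for i in range(400)]
subset = stratified_subset(questions_pool, total=40)
print(len(subset), Counter(q["subject"] for q in subset))
```

Rounding and the one-question floor per subject mean the result can differ slightly from `total`; with MMLU's 57 subjects and `total=1000`, expect a subset close to, but not exactly, 1,000 questions.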

Checkpoint Storage

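Saving checkpoints for a 70B model produces ~140 GB files for the bf16 weights alone; resumable checkpoints that also include fp32 master weights and AdamW optimizer state are several times larger (on the order of 1 TB). Saving weights every 2,500 steps over 50,000 steps generates 20 checkpoints = 2.8 TB of checkpoint data. Budget for this storage and implement a rotation policy that keeps only the last N checkpoints plus milestone checkpoints.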
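A rotation policy that keeps the last N checkpoints plus milestones is easiest to test as a pure function over step numbers, before wiring it to real file deletion. Here `keep_last_n=5` matches the configuration example, while treating every 10,000th step as a milestone is an assumption.

```python
def checkpoints_to_delete(saved_steps: list[int], keep_last_n: int = 5, milestone_every: int = 10_000) -> list[int]:
    """Steps whose checkpoints can be deleted: not among the last N and not a milestone."""
    ordered = sorted(saved_steps)
    recent = set(ordered[-keep_last_n:]) if keep_last_n > 0 else set()  # [-0:] would keep everything
    return [step for step in ordered if step not in recent and step % milestone_every != 0]


saved = list(range(2_500, 50_001, 2_500))  # 20 checkpoints, as in the 50,000-step example
to_delete = checkpoints_to_delete(saved)
kept = sorted(set(saved) - set(to_delete))
print(kept)  # [10000, 20000, 30000, 40000, 42500, 45000, 47500, 50000]
print(f"~{len(kept) * 140 / 1000:.2f} TB of bf16 weights instead of {len(saved) * 140 / 1000:.1f} TB")
```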

Production Case Studies

Bloomberg (BloombergGPT)Financial Technology

Bloomberg trained a 50B-parameter LLM from scratch on a mixed corpus of 363B tokens of financial data (Bloomberg's proprietary archive of financial documents, news, filings, and reports) and 345B tokens of general web data. While technically a from-scratch pretraining rather than continued pretraining, the data mixing strategy -- roughly 51% domain + 49% general -- has become the reference point for CPT mixing ratios in the financial domain.

Outcome:

BloombergGPT significantly outperformed GPT-3 and BLOOM on financial NLP tasks (sentiment analysis, NER, QA on financial documents) while maintaining competitive performance on general benchmarks. The 51/49 mixing ratio demonstrated that domain performance doesn't require sacrificing general capabilities.

Meta (Code Llama)Technology / Open-Source AI

Meta produced Code Llama by continued pretraining of Llama 2 on 500B tokens of primarily code data (85% code from GitHub, 8% natural language about code, 7% general text). This is one of the largest and most successful CPT efforts, transforming a general-purpose language model into a state-of-the-art code generation model. Code Llama was then further specialized into Code Llama-Python (additional 100B Python tokens) and Code Llama-Instruct (instruction tuning).

Outcome:

Code Llama 34B achieved 48.8% on HumanEval and 55.0% on MBPP, significantly outperforming Llama 2 34B's baseline code capabilities. The 7B variant was competitive with much larger general models, demonstrating that CPT can effectively specialize a model while keeping it relatively small.

Sarvam AI (OpenHathi / Sarvam 1) · Indian AI / Multilingual NLP

Sarvam AI, an Indian AI startup, used continued pretraining to adapt Llama 2 7B for Hindi and other Indic languages, producing OpenHathi. The model was trained on Hindi, English, and Hinglish data, extending Llama's capabilities to Indian languages at a fraction of the cost of training from scratch. They later released Sarvam 1, a 2B-parameter model supporting 10 Indic languages with a custom tokenizer optimized for Indic scripts.

Outcome:

Sarvam reported GPT-3.5-like performance for OpenHathi on a range of Hindi language tasks, which suggests that continued pretraining can effectively extend English-centric LLMs to low-resource languages. The approach proved far more cost-effective than training an Indic LLM from scratch, enabling a startup with limited resources to compete with much larger organizations.

Princeton / EleutherAI (Llemma) · AI Research / Mathematics

Azerbayev et al. continued pretraining Code Llama on the Proof-Pile-2, a mixture of scientific papers, web data containing mathematics, and mathematical code. The resulting model, Llemma, was released in two sizes (7B and 34B). On the MATH benchmark it outperformed all open base models known at the time, as well as Google's Minerva on an equi-parameter basis.

Outcome:

Llemma 34B approached the MATH performance of Minerva 62B with nearly half the parameters, demonstrating the power of targeted continued pretraining on high-quality domain data. Llemma was also capable of tool use (Python, Sage) and formal theorem proving in Lean/Isabelle without additional fine-tuning.

Amazon (Financial Domain LLMs) · Cloud / Financial Services

AWS's applied science team demonstrated efficient data selection for continued pretraining in the financial domain. They showed that training on just 10% of a financial corpus -- selected using domain-adaptive data selection methods (ETS-DACP) -- outperformed training on the full corpus. This addressed the common concern that CPT requires massive domain corpora.

Outcome:

Efficient data selection methods achieved better downstream financial task performance than standard CPT with 10x more data, while also showing that continual pretraining does not adversely affect non-domain performance. This validated the importance of data quality over quantity in CPT.
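The AWS paper's exact selection criteria are not reproduced here. As a dependency-free stand-in for task-similar selection, the sketch below does two things:

- It scores each candidate document by bag-of-words cosine similarity to the task examples.
- It keeps the top 10% of documents by that score.

A production version would more likely use learned embeddings. The function names are hypothetical.

```python
import math
from collections import Counter


def bow(text):
    """Lower-cased bag-of-words vector as a Counter."""
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_task_similar(corpus, task_examples, fraction=0.10):
    """Keep the `fraction` of `corpus` most similar to the task examples."""
    centroid = Counter()
    for ex in task_examples:
        # Summed counts act as an unnormalized centroid; cosine is scale-invariant.
        centroid.update(bow(ex))
    scored = sorted(corpus, key=lambda doc: cosine(bow(doc), centroid), reverse=True)
    k = max(1, int(len(corpus) * fraction))
    return scored[:k]
```

The design choice this illustrates is that the selection signal comes from the downstream task data rather than from the domain corpus alone. That is how a small, targeted slice can beat the full corpus.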

Tooling & Ecosystem

Hugging Face Transformers
Python · Open Source

The most accessible framework for CPT. The Trainer class supports autoregressive language modeling with DeepSpeed and FSDP integration. Combined with datasets for data loading and trl for training recipes, this is the standard starting point for CPT projects. Supports gradient checkpointing, mixed precision, and multi-GPU training out of the box.

NVIDIA Megatron-LM
Python / CUDA · Open Source

High-performance framework for pretraining and continued pretraining of large transformer models. Supports 3D parallelism (data + tensor + pipeline), optimized attention kernels, and efficient data loading. The standard choice for CPT of 30B+ parameter models where training throughput is critical. Used by many organizations building production domain LLMs.
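Megatron-LM sets the tensor- and pipeline-parallel degrees with `--tensor-model-parallel-size` and `--pipeline-model-parallel-size`. The remaining factor of the GPU count becomes the data-parallel degree. The sketch below shows only that arithmetic and ignores additional parallelism dimensions that newer releases add. The example layouts are illustrative, not tuned recommendations.

```python
def data_parallel_size(world_size, tensor_parallel, pipeline_parallel):
    """Data-parallel replicas implied by a 3D-parallel layout.

    GPUs factor as world_size = tensor_parallel * pipeline_parallel * data_parallel,
    so TP * PP must divide the world size evenly.
    """
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by TP*PP = {model_parallel}"
        )
    return world_size // model_parallel
```

For example, a 70B model on 64 GPUs with TP=8 (one 8-GPU node) and PP=4 leaves `data_parallel_size(64, 8, 4) == 2` replicas. Tensor parallelism is usually kept within a node because it communicates on every layer.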

nanotron
Python · Open Source

Hugging Face's minimalistic large language model training library with 3D parallelism support. More Pythonic and easier to customize than Megatron-LM, while achieving comparable training throughput. Good for teams that need custom data pipelines or training logic without diving into Megatron-LM's codebase.

LLaMA-Factory
Python · Open Source

Unified fine-tuning and continued pretraining framework with a web UI. Supports pretraining (CLM), SFT, RLHF, DPO, and parameter-efficient methods such as LoRA and QLoRA. Popular for its low barrier to entry: you can configure and launch a CPT run through the browser. Supports 100+ model architectures.

datasketch (MinHash LSH)
Python · Open Source

Python library for probabilistic data structures, including MinHash and MinHash LSH for near-duplicate detection. Essential for corpus deduplication in the CPT data preparation pipeline. Can process millions of documents efficiently using locality-sensitive hashing.
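To make the technique concrete without relying on datasketch's API, the sketch below implements MinHash signatures and LSH banding using only the standard library. datasketch's `MinHash` and `MinHashLSH` classes provide optimized versions of the same idea. The parameters are illustrative assumptions:

- 5-character shingles,
- 128 salted blake2b hashes standing in for random permutations,
- 32 bands of 4 rows.

```python
import hashlib
from collections import defaultdict
from itertools import combinations


def shingles(text, n=5):
    """Set of character n-grams from lower-cased, whitespace-normalized text."""
    t = " ".join(text.lower().split())
    return {t[i:i + n] for i in range(max(1, len(t) - n + 1))}


def minhash_signature(shingle_set, num_perm=128):
    """MinHash signature: the minimum hash of the set under each seeded hash function."""
    signature = []
    for seed in range(num_perm):
        salt = seed.to_bytes(4, "little")  # a distinct salt per simulated permutation
        signature.append(min(
            int.from_bytes(
                hashlib.blake2b(s.encode(), digest_size=8, salt=salt).digest(), "little"
            )
            for s in shingle_set
        ))
    return signature


def estimated_jaccard(sig_a, sig_b):
    """Fraction of matching signature positions estimates the sets' Jaccard similarity."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)


def lsh_candidate_pairs(signatures, bands=32):
    """LSH banding: documents sharing any identical band become candidate pairs."""
    rows = len(next(iter(signatures.values()))) // bands
    candidates = set()
    for b in range(bands):
        buckets = defaultdict(list)
        for doc_id, sig in signatures.items():
            buckets[tuple(sig[b * rows:(b + 1) * rows])].append(doc_id)
        for ids in buckets.values():
            candidates.update(tuple(sorted(pair)) for pair in combinations(ids, 2))
    return candidates
```

Banding avoids comparing every pair of documents, because only documents that collide in some band are ever compared. Candidate pairs should then be checked with `estimated_jaccard` (or exact Jaccard) against a threshold before one copy is dropped. With 32 bands of 4 rows, pairs become likely candidates once similarity exceeds roughly (1/32)^(1/4) ≈ 0.42. Using more rows per band raises that threshold.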

Dolma Toolkit
Python / Rust · Open Source

Allen AI's data curation toolkit used to build the Dolma dataset (3 trillion tokens). Provides tools for deduplication, quality filtering, PII removal, and content classification. Can be adapted for domain corpus curation in CPT pipelines. Used in the creation of OLMo's training data.

Research & References

Don't Stop Pretraining: Adapt Language Models to Domains and Tasks

Gururangan, Marasovic, Swayamdipta, Lo, Beltagy, Downey & Smith (2020) · ACL 2020

The foundational paper for continued pretraining. Introduced DAPT (Domain-Adaptive Pretraining) and TAPT (Task-Adaptive Pretraining), showing consistent improvements across four domains and eight tasks when further pretraining RoBERTa on domain-specific text. Established that domain-adaptive pretraining is a simple, effective first step for domain specialization.

BloombergGPT: A Large Language Model for Finance

Wu, Irsoy, Lu, Dabravolski, Dredze, Gehrmann, Kambadur, Rosenberg & Mann (2023) · arXiv 2023

Trained a 50B-parameter LLM on a mixed corpus of 363B financial tokens and 345B general tokens. Demonstrated that a balanced domain/general data mix achieves strong domain performance without sacrificing general capabilities. The data mixing strategy (51% domain) has become a reference point for CPT practitioners.

Code Llama: Open Foundation Models for Code

Roziere, Gehring, Gloeckle, Sootla, Gat, Tan, et al. (2023) · arXiv 2023

Produced state-of-the-art code generation models by continued pretraining of Llama 2 on 500B tokens (85% code, 8% code-adjacent, 7% general). Demonstrated multi-stage CPT: base Code Llama -> Code Llama Python (additional Python CPT) -> Code Llama Instruct (instruction tuning). A masterclass in staged domain adaptation.

Llemma: An Open Language Model for Mathematics

Azerbayev, Schoelkopf, Paster, Dos Santos, McAleer, Jiang, Deng, Biderman & Welleck (2023) · ICLR 2024

Continued pretraining of Code Llama on the Proof-Pile-2 (a mixture of scientific papers, math-heavy web data, and mathematical code). Llemma outperformed Minerva models of comparable size on MATH, and Llemma 34B approached the larger Minerva 62B with roughly half the parameters. This demonstrated that targeted CPT on high-quality domain data can close much of the gap that would otherwise require more scale.

DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining

Xie, Pham, Dong, Du, Liu, Lu, Liang, Le, Ma & Yu (2023) · NeurIPS 2023

Proposed an automatic method for optimizing data mixing ratios: a small proxy model is trained with group distributionally robust optimization (Group DRO) against a reference model. On The Pile, the optimized weights improved an 8B model's average few-shot downstream accuracy by 6.5 percentage points over the default weights and reached baseline accuracy with 2.6x fewer training steps. Directly applicable to optimizing the domain/general data mix in CPT.
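The core of DoReMi's reweighting is an exponentiated-gradient step on per-domain excess loss, followed by mixing with the uniform distribution. Excess loss is the proxy loss minus the reference loss, clipped at zero. The sketch below shows one such step on per-domain average losses. The step-size and smoothing defaults are illustrative, and training the proxy and reference models is out of scope.

```python
import math


def doremi_update(weights, proxy_losses, ref_losses, step_size=1.0, smoothing=1e-3):
    """One DoReMi-style domain-weight update (exponentiated gradient ascent).

    weights:      current domain weights (summing to 1)
    proxy_losses: per-domain average loss of the proxy model being trained
    ref_losses:   per-domain average loss of the fixed reference model
    """
    k = len(weights)
    # Excess loss, clipped at zero: domains the proxy still finds hard are upweighted.
    excess = [max(p - r, 0.0) for p, r in zip(proxy_losses, ref_losses)]
    unnorm = [w * math.exp(step_size * e) for w, e in zip(weights, excess)]
    total = sum(unnorm)
    # Mix with the uniform distribution so no domain weight collapses to zero.
    return [(1 - smoothing) * u / total + smoothing / k for u in unnorm]
```

DoReMi's final mixture is the average of these weights over all proxy-training steps, not the last iterate. The averaged weights are then used to train the full-size model.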

Towards Effective and Efficient Continual Pre-training of Large Language Models

Chen, Wu, Ling, Liang & Shen (2024) · arXiv 2024

Presented a comprehensive recipe for continued pretraining of Llama 3 8B, achieving +8.81 on C-Eval and +12.00 on MATH through data mixture and curriculum strategies. Showed that synthetic data (scientific QA pairs) combined with web data can significantly enhance specific reasoning abilities during CPT.

Interview & Evaluation Perspective

Common Interview Questions

  • What is continued pretraining and how does it differ from fine-tuning?

  • How do you choose the data mixing ratio between domain and general data for CPT?

  • What strategies would you use to mitigate catastrophic forgetting during continued pretraining?

  • Walk me through the end-to-end pipeline for building a domain-specific LLM using continued pretraining.

  • How would you estimate the compute cost for continued pretraining of a 70B model on 100B tokens?

  • Compare CPT vs RAG vs fine-tuning for injecting domain knowledge into an LLM.

  • How should the learning rate schedule differ between initial pretraining and continued pretraining?

  • When would you choose to adapt the tokenizer for continued pretraining, and how?

Key Points to Mention

  • CPT uses the same self-supervised objective as pretraining (next-token prediction) but starts from pretrained weights and trains on domain-specific data. It changes what the model knows, while fine-tuning changes how it behaves.

  • The data mixing ratio is critical: 70% domain / 30% general is a common starting point. Too much domain data causes forgetting; too little dilutes the domain signal and wastes compute. DoReMi can optimize the mix automatically using a small proxy model.

  • Learning rate for CPT should be 10-100x lower than the initial pretraining peak LR. For Llama 3 8B (pretrained with a 3e-4 peak), use 3e-6 to 3e-5, with a short warmup followed by cosine annealing, mirroring the shape of the pretraining schedule.

  • Catastrophic forgetting mitigation has four levers: data replay (general data in training mix), lower learning rate, EWC regularization, and continuous evaluation on general benchmarks.

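  • Compute cost formula: FLOPs = 6 * N * D. For an 8B model on 50B tokens, that is 6 * 8e9 * 50e9 = 2.4e21 FLOPs. At the A100's 312 TFLOPS bf16 peak this is ~2,100-2,200 GPU-hours, or ~$4,400 (~INR 3.7 lakh) at ~$2/GPU-hour. Real runs at 40-50% model FLOPs utilization (MFU) need roughly 2-2.5x that.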

  • Real-world examples: Code Llama (500B code tokens on Llama 2), BloombergGPT (363B financial tokens), Llemma (math CPT that outperforms Minerva at equal parameter count), Sarvam OpenHathi (Indic language CPT).

  • Corpus curation is 50% of the work: deduplication (MinHash), quality filtering (perplexity scoring), balanced sampling across sub-domains, and PII removal.
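The compute-cost rule of thumb is easy to turn into a small estimator. The sketch below assumes two things: the A100's dense bf16 peak of 312 TFLOPS, and a $2/GPU-hour rental price, which is an assumption rather than a quote. MFU (model FLOPs utilization) is the fraction of peak a run actually achieves.

```python
A100_BF16_PEAK_FLOPS = 312e12  # dense bf16 tensor-core peak of one A100


def cpt_cost(num_params, num_tokens, mfu=1.0, usd_per_gpu_hour=2.0,
             peak_flops=A100_BF16_PEAK_FLOPS):
    """Estimate training FLOPs, GPU-hours, and dollar cost using FLOPs = 6 * N * D.

    mfu: fraction of peak throughput actually achieved (1.0 = theoretical peak).
    usd_per_gpu_hour: assumed rental price.
    """
    flops = 6 * num_params * num_tokens
    gpu_hours = flops / (peak_flops * mfu) / 3600
    return flops, gpu_hours, gpu_hours * usd_per_gpu_hour
```

Some worked figures from this estimator:

| Call | Throughput assumption | Estimate |
|---|---|---|
| `cpt_cost(8e9, 50e9)` | peak (100% MFU) | ~2,137 GPU-hours, ~$4,270 |
| `cpt_cost(8e9, 50e9, mfu=0.4)` | 40% MFU | ~5,340 GPU-hours |
| `cpt_cost(70e9, 100e9)` | peak (100% MFU) | ~37,400 A100-hours |

The first row matches the ~2,200 GPU-hour figure. The last row answers the interview question about a 70B model on 100B tokens, before any utilization discount.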

Pitfalls to Avoid

  • Conflating continued pretraining with fine-tuning. CPT applies the self-supervised pretraining objective to large unlabeled corpora; fine-tuning uses supervised objectives on smaller labeled datasets. They operate at different layers of the adaptation stack.

  • Ignoring catastrophic forgetting. Any answer about CPT that doesn't discuss forgetting mitigation strategies shows a lack of practical experience.

  • Claiming that more data is always better. Research (AWS ETS-DACP) shows that a curated 10% of the corpus can outperform the full corpus: data quality trumps quantity.

  • Not discussing cost and compute requirements. CPT is expensive, and a senior engineer should be able to estimate GPU-hours and cost for a given configuration.

  • Forgetting to mention the tokenizer. For cross-lingual CPT (e.g., English model to Hindi), tokenizer adaptation can be as important as the training itself.
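One common way to judge whether tokenizer adaptation is needed is to measure fertility, the average number of tokens per word. Compute it for the base tokenizer on a sample of domain or target-language text and compare it with English. The helper below accepts any tokenization callable (for instance, a Hugging Face tokenizer's `tokenize` method), so it makes no assumption about a specific library. The threshold at which adaptation pays off remains a judgment call.

```python
def fertility(texts, tokenize):
    """Average number of tokens per whitespace-separated word.

    `tokenize` is any callable mapping a string to a list of tokens.
    Whitespace word splitting suits Hindi or English, but not unsegmented scripts.
    """
    num_tokens = sum(len(tokenize(t)) for t in texts)
    num_words = sum(len(t.split()) for t in texts)
    return num_tokens / num_words if num_words else 0.0
```

A much higher fertility on Hindi than on English signals two costs for Hindi text: shorter effective context and more compute per word. That is the usual case for extending the vocabulary before CPT.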

Senior-Level Expectation

A senior/staff engineer should discuss CPT at three levels: (1) Technical: articulate the training objective, learning rate scheduling (why 10-100x lower LR), data mixing strategies (with concrete ratios), and forgetting mitigation (replay, EWC, monitoring). (2) Engineering: cover the end-to-end pipeline from corpus curation (deduplication, quality filtering) through distributed training (FSDP/DeepSpeed configuration, checkpoint management) to evaluation (dual tracking of domain and general metrics). (3) Strategic: reason about when CPT is worth the investment vs. alternatives (RAG, LoRA), estimate compute costs in dollars/INR and GPU-hours for specific configurations, and design an organizational workflow where a CPT team produces domain base models consumed by multiple downstream fine-tuning teams. The ability to discuss real examples (Code Llama's 85/8/7 data split, BloombergGPT's 51/49 mix, Sarvam's Indic language approach) and connect them to specific design decisions demonstrates genuine depth.
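The CPT learning-rate schedule can be sketched as a plain function: a short linear warmup, then cosine decay, with a peak 10x below the base model's pretraining peak. The defaults assume a Llama 3 8B-style base pretrained at a 3e-4 peak. The 1,000-step warmup and the 3e-6 floor are illustrative assumptions. Hugging Face's `get_cosine_schedule_with_warmup` implements a similar shape that decays to zero instead of to a floor.

```python
import math


def cpt_learning_rate(step, total_steps, peak_lr=3e-5, min_lr=3e-6, warmup_steps=1000):
    """Linear warmup to `peak_lr`, then cosine decay to `min_lr` by `total_steps`.

    Defaults: peak 10x and floor 100x below a 3e-4 pretraining peak (illustrative).
    Steps past `total_steps` stay at `min_lr`.
    """
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

The warmup matters more in CPT than it might seem. The optimizer state is usually fresh even though the weights are not, and a sudden jump to the peak LR is a common trigger for early loss spikes and forgetting.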

Summary

What We Covered

Continued pretraining (CPT) is the practice of further training a pretrained language model on domain-specific data using the self-supervised language modeling objective, bridging the gap between general pretraining and task-specific fine-tuning. It is the primary method for injecting deep domain knowledge into LLMs -- fundamentally rewiring the model's internal representations so that domain concepts, terminology, and reasoning patterns become natively encoded in the weights.

The key technical decisions in CPT are:

- Data mixing ratio: 70% domain / 30% general replay is a common default.
- Learning rate: 10-100x lower than initial pretraining, with cosine annealing.
- Corpus curation: deduplication, quality filtering, and balanced sampling matter more than raw corpus size.
- Catastrophic forgetting mitigation: data replay, a low learning rate, continuous evaluation on general benchmarks, and optionally EWC regularization.

Tokenizer adaptation is critical for cross-lingual CPT (e.g., English to Indian languages) and for specialized domains where the base tokenizer fragments domain text excessively.
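EWC regularization adds a quadratic penalty that anchors parameters the base model relied on. The penalty is (λ/2) · Σᵢ Fᵢ (θᵢ − θ*ᵢ)², where:

- θ* are the pre-CPT weights,
- F is a diagonal Fisher information estimate.

The sketch below uses flat Python lists as a hypothetical parameter layout. The Fisher estimate is the empirical approximation: the mean of squared per-example gradients, computed on general-domain data. In a real training framework the same formula is applied tensor by tensor and added to the language-modeling loss.

```python
def ewc_penalty(params, anchor_params, fisher, strength):
    """EWC penalty: (strength / 2) * sum_i F_i * (theta_i - theta*_i) ** 2.

    params:        current parameter values (flat list, hypothetical layout)
    anchor_params: parameter values of the base model before CPT (theta*)
    fisher:        diagonal Fisher information estimate, one value per parameter
    strength:      lambda, trading off new-domain loss against drift
    """
    return 0.5 * strength * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )


def diagonal_fisher(per_example_grads):
    """Empirical diagonal Fisher: mean of squared per-example gradients."""
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    return [sum(g[i] ** 2 for g in per_example_grads) / n for i in range(dim)]
```

The total CPT loss would be the domain language-modeling loss plus `ewc_penalty(...)`. Parameters with high Fisher values, which matter most for general data, are pulled back most strongly toward their original values.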

Real-world CPT success stories span multiple domains: Code Llama (500B code tokens, state-of-the-art code generation), BloombergGPT (363B financial tokens, domain-leading financial NLP), Llemma (mathematical reasoning from Proof-Pile-2), and Indian language models like Sarvam AI's OpenHathi and Ola's Krutrim (Indic language adaptation of Llama-family models).

The compute cost for CPT is substantial. An 8B model on 50B tokens costs approximately $4,400 (~INR 3.7 lakh) at peak A100 throughput, and roughly 2-2.5x that at realistic utilization. Still, 50B tokens is a small fraction of the 2T-15T tokens used to pretrain Llama 2 and Llama 3, so CPT costs tens to hundreds of times less than pretraining from scratch. It produces a domain-adapted base model that serves as a foundation for multiple downstream applications.

For teams building serious domain AI, CPT is not optional: it's the foundation that makes everything downstream work better.

ML System Design Reference · Built by QnA Lab