Continued Pretraining in Machine Learning
Continued pretraining (CPT) is the practice of further training a pretrained language model on a domain-specific corpus using the same self-supervised objective (typically next-token prediction) before any task-specific fine-tuning. It occupies a critical middle layer in the model adaptation stack: pretraining gives you general language understanding, continued pretraining injects domain knowledge, and fine-tuning teaches the model to follow instructions or perform specific tasks.
The intuition is straightforward. A model like Llama 3 has seen trillions of tokens of internet text, but its exposure to, say, Indian legal judgments, radiology reports, or semiconductor datasheets is sparse and shallow. Continued pretraining on a curated corpus of 10-100 billion domain tokens rewires the model's internal representations so that domain-specific concepts, terminology, and reasoning patterns become first-class citizens in the weight space -- not distant memories retrieved through clever prompting.
Since Gururangan et al.'s seminal 2020 paper "Don't Stop Pretraining," CPT has become a standard first step for building domain-specific LLMs. BloombergGPT, although trained from scratch rather than from an existing checkpoint, showed the value of data mixing by combining 363 billion financial tokens with 345 billion general tokens. Code Llama continued pretraining Llama 2 on 500 billion tokens of code. Llemma adapted Code Llama on mathematical text from the Proof-Pile-2. In India, Sarvam AI's OpenHathi and Ola's Krutrim used continued pretraining to extend Llama-family models to Hindi and other Indic languages.
What makes CPT both powerful and tricky is the balancing act: you want the model to deeply absorb domain knowledge without forgetting its general capabilities. Get the data mix wrong, the learning rate too high, or the tokenizer mismatched, and you end up with a model that speaks fluent radiology but can't handle basic arithmetic. This guide covers every aspect of that balancing act -- from corpus curation to compute cost estimation.
Concept Snapshot
- What It Is
- A training phase where a pretrained language model undergoes additional self-supervised learning on a domain-specific corpus, adapting its internal representations to encode domain knowledge before any task-specific fine-tuning.
- Category
- Model Training
- Complexity
- Advanced
- Inputs / Outputs
- Inputs: pretrained base model + curated domain corpus + optional general data for replay. Outputs: domain-adapted pretrained model ready for downstream fine-tuning or direct prompting.
- System Placement
- Sits between initial pretraining and fine-tuning (SFT/RLHF/LoRA) in the ML pipeline. Applied after base model selection and corpus curation, before instruction tuning or task-specific adaptation.
- Also Known As
- Continual Pretraining, Domain-Adaptive Pretraining (DAPT), Second-Phase Pretraining, Adaptive Pretraining, Domain Continued Pretraining
- Typical Users
- ML Engineers, NLP Engineers, Applied Scientists, Foundation Model Teams, Domain AI Specialists
- Prerequisites
- Language model pretraining fundamentals (next-token prediction, causal LM), Transformer architecture (attention, MLP layers, positional encoding), Learning rate scheduling (cosine annealing, warmup), Data preprocessing for LLM training (tokenization, deduplication), Distributed training basics (FSDP, DeepSpeed, data parallelism)
- Key Terms
- DAPT (Domain-Adaptive Pretraining), TAPT (Task-Adaptive Pretraining), data mixing ratio, catastrophic forgetting, replay buffer, tokenizer adaptation, learning rate re-warming, domain corpus, perplexity
Why This Concept Exists
The Domain Knowledge Gap
Large language models are trained on massive, general-purpose web corpora -- Common Crawl, Wikipedia, books, code repositories. This gives them impressive breadth: they can discuss philosophy, write Python, and summarize news articles. But breadth is not depth.
Consider a model deployed for Indian legal tech. The general pretraining corpus might contain a few hundred thousand tokens of Indian court judgments, scattered across millions of web pages. But the Indian legal system has its own vocabulary ("writ petition," "Section 498A," "anticipatory bail"), its own reasoning patterns (precedent-based argumentation citing specific IPC sections), and its own document structures (FIR format, judgment headings). A general model has seen these patterns, but they occupy a tiny fraction of its parameter space. It's like asking someone who read one chapter on Indian law in a 10,000-chapter encyclopedia to draft a bail application.
Continued pretraining solves this by dedicating substantial compute to domain-specific text, effectively "re-weighting" the model's knowledge distribution toward the target domain.
From DAPT to Modern CPT
The formal study of continued pretraining began with Gururangan et al. (2020), who introduced the terms DAPT (Domain-Adaptive Pretraining) and TAPT (Task-Adaptive Pretraining). Their key finding was elegant: taking RoBERTa and further pretraining it on domain text (biomedical papers, computer science papers, news, reviews) consistently improved downstream task performance, sometimes by 3-8 F1 points. Moreover, TAPT -- pretraining on just the unlabeled task data -- provided additional gains on top of DAPT.
This observation scaled dramatically with the LLM era. As models grew from the hundreds of millions of parameters of RoBERTa to 7B-70B parameters, domain-specific continued pretraining kept paying off. Bloomberg trained a 50B-parameter model from scratch on mixed financial and general data (BloombergGPT, 2023). Meta took Llama 2 and continued pretraining on 500B code tokens to produce Code Llama (2023). Azerbayev et al. continued pretraining Code Llama on mathematical text to produce Llemma (2023), which outperformed Google's Minerva at comparable parameter counts.
Why Not Just Fine-Tune or Use RAG?
A natural question: if you want domain knowledge, why not just fine-tune on domain data or use retrieval-augmented generation (RAG)?
Fine-tuning (SFT, LoRA) teaches the model to follow specific instruction formats and perform specific tasks, but it doesn't fundamentally alter the model's world knowledge. You can fine-tune a model on 10,000 radiology report Q&A pairs, but if the base model doesn't deeply understand radiological concepts, the fine-tuned model will hallucinate confidently.
RAG retrieves relevant documents at inference time and includes them in the context. It's powerful for factual recall but limited by context window size, retrieval quality, and the model's ability to reason over retrieved text. If the model's internal representations don't align with the domain, it struggles to synthesize information from retrieved passages.
Continued pretraining operates at a deeper level: it rewires the model's internal representations so that domain concepts are natively encoded in the weights. After CPT, the model doesn't need to retrieve basic domain knowledge -- it's already there, enabling better reasoning, fewer hallucinations, and more natural domain fluency. The three approaches are complementary: CPT provides the knowledge foundation, fine-tuning adds task-specific behavior, and RAG provides access to dynamic or rare information.
Key Insight: Continued pretraining is the only approach that changes what the model knows. Fine-tuning changes how it behaves. RAG changes what it can access. For deep domain adaptation, you typically need all three.
Core Intuition & Mental Model
The Analogy: Becoming a Domain Expert
Imagine hiring a brilliant generalist consultant (the pretrained LLM) to work in your hospital's radiology department. This consultant has read widely -- they know some medical terminology, can discuss imaging modalities at a cocktail party level, and can follow basic clinical reasoning. But they're not a radiologist.
You have three options:
- Give them a reference library (RAG): They can look things up when asked, but they don't deeply understand what they're reading. They'll miss subtleties and can't synthesize across sources effectively.
- Train them on Q&A flashcards (fine-tuning): They learn to pattern-match specific question types, but their understanding is shallow. Ask something slightly outside the flashcard distribution and they're lost.
- Send them to radiology residency (continued pretraining): They spend months reading thousands of radiology textbooks, case reports, and imaging studies. The knowledge becomes part of their mental model. They develop intuition for what's normal and abnormal, learn the vocabulary natively, and can reason about novel cases.
Continued pretraining is option 3. It's expensive -- residency takes years (or in ML terms, significant GPU hours) -- but it produces a fundamentally more capable domain specialist.
The Representation Geometry Intuition
At a more technical level, think about what continued pretraining does to the model's internal representation space. In a general pretrained model, concepts like "myocardial infarction," "ST elevation," and "troponin levels" are scattered across the embedding space, loosely associated but not tightly clustered. The model knows these words exist but hasn't deeply encoded their relationships.
Continued pretraining on medical text forces the model to repeatedly predict medical text -- to predict that "ST elevation" is often followed by discussion of "myocardial infarction" and "troponin levels." This repetition tightens the relevant representations, creating dense clusters of domain knowledge where related concepts are geometrically close in the embedding space and the attention patterns learn domain-specific reasoning paths.
The result: domain concepts move from the periphery of the model's representation space to become well-organized, easily accessible regions. This is why continued pretraining improves not just domain-specific tasks but also the model's ability to follow complex domain reasoning chains -- the building blocks are better organized.
Mental Model: Continued pretraining is like renovating a library. The books (knowledge) were already somewhere in the building, but they were scattered across random shelves. CPT reorganizes them into a dedicated domain section with proper indexing, making retrieval and reasoning dramatically more efficient.
Technical Foundations
The Training Objective
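Continued pretraining uses the same autoregressive language modeling objective as initial pretraining. Given a domain corpus $\mathcal{D}_{\text{domain}} = \{x^{(1)}, \dots, x^{(M)}\}$ where each $x^{(i)}$ is a document of $T_i$ tokens, the model parameters $\theta$ (initialized from pretrained weights $\theta_0$) are optimized to minimize the negative log-likelihood:
$$\mathcal{L}_{\text{CPT}}(\theta) = -\sum_{i=1}^{M} \sum_{t=1}^{T_i} \log p_\theta\left(x^{(i)}_t \mid x^{(i)}_{<t}\right)$$
The key difference from initial pretraining is that $\theta$ starts from $\theta_0$ (pretrained weights) rather than random initialization, and the training corpus is domain-specific rather than general.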
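As a minimal sketch of the objective, the snippet below sums the negative log-probabilities of the true next tokens and converts the result to perplexity, the metric used later to track domain improvement. The per-token probabilities are made-up illustration values, not model outputs, and the function names are hypothetical.

```python
import math


def negative_log_likelihood(token_log_probs):
    """Sum of -log p(token | prefix) over the tokens of one document."""
    return -sum(token_log_probs)


def corpus_perplexity(documents_log_probs):
    """Perplexity = exp(total NLL / total number of tokens) over all documents."""
    total_nll = sum(negative_log_likelihood(lp) for lp in documents_log_probs)
    total_tokens = sum(len(lp) for lp in documents_log_probs)
    return math.exp(total_nll / total_tokens)


# Hypothetical probabilities the model assigned to each true next token
doc_a = [math.log(0.5), math.log(0.25)]
doc_b = [math.log(0.5), math.log(0.5)]

print(f"NLL of doc_a: {negative_log_likelihood(doc_a):.4f}")
print(f"Corpus perplexity: {corpus_perplexity([doc_a, doc_b]):.4f}")
```

Lower perplexity on held-out domain text after CPT indicates that the model assigns higher probability to domain tokens than it did before.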
Data Mixing Formulation
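In practice, CPT rarely uses pure domain data. Instead, a mixed corpus is constructed:
$$\mathcal{D}_{\text{mix}} = \alpha \cdot \mathcal{D}_{\text{domain}} + (1 - \alpha) \cdot \mathcal{D}_{\text{general}}$$
where $\alpha \in [0, 1]$ is the domain data mixing ratio, i.e. the fraction of training tokens drawn from the domain corpus. Reported values range from about 0.5 to above 0.9. The general data component acts as a replay buffer that mitigates catastrophic forgetting.
BloombergGPT used $\alpha \approx 0.51$ (363B financial tokens / 708B total tokens). Code Llama used $\alpha \approx 0.93$ when code-related natural language is counted as domain data (85% code, 8% code-about-code, 7% general). The optimal $\alpha$ depends on domain distance from the pretraining distribution and corpus size.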
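The arithmetic behind the mixing ratio can be sketched as follows. The helper names are hypothetical, and the ratio is written `alpha` here: the first function recovers the ratio from token counts, the second answers the planning question of how much general replay data a given domain corpus needs.

```python
def mixing_ratio(domain_tokens, general_tokens):
    """Fraction of the mixed corpus that comes from the domain corpus."""
    return domain_tokens / (domain_tokens + general_tokens)


def replay_tokens_needed(domain_tokens, alpha):
    """General tokens required so that domain data makes up fraction alpha."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    return domain_tokens * (1.0 - alpha) / alpha


# BloombergGPT token counts quoted in the text
bloomberg_alpha = mixing_ratio(363e9, 345e9)
print(f"BloombergGPT mixing ratio: {bloomberg_alpha:.2f}")

# A 70B-token domain corpus at a 70/30 mix (illustrative corpus size)
print(f"Replay tokens needed: {replay_tokens_needed(70e9, 0.7) / 1e9:.0f}B")
```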
Learning Rate Scheduling for CPT
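The learning rate schedule for CPT differs critically from initial pretraining. Let $\eta_{\text{pre}}$ be the peak learning rate used during initial pretraining. For CPT, the typical approach is:
$$\eta_{\text{CPT}} \in [0.01\,\eta_{\text{pre}},\ 0.1\,\eta_{\text{pre}}]$$
In other words, the CPT peak learning rate is typically 10-100x lower than the initial pretraining LR. The schedule usually follows cosine annealing:
$$\eta(t) = \eta_{\min} + \frac{1}{2}\left(\eta_{\text{CPT}} - \eta_{\min}\right)\left(1 + \cos\frac{\pi t}{T}\right)$$
where $t$ is the current step, $T$ is the total number of CPT steps, and $\eta_{\min}$ is the minimum learning rate (often a small fraction of $\eta_{\text{CPT}}$, such as $0.1\,\eta_{\text{CPT}}$).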
Recent work by Gupta et al. (2023) suggests that re-warming the learning rate (briefly increasing it at the start of CPT) can help when the initial pretraining used cosine decay to near-zero. The warmup phase is typically short -- 1-5% of total CPT steps.
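A minimal sketch of such a schedule, combining a short linear re-warming phase with cosine decay, is shown below. The function name is hypothetical, the cosine progress is measured from the end of warmup (one common convention, not necessarily the one used by any particular paper), and the minimum learning rate of 3e-6 is an illustrative assumption.

```python
import math


def cpt_learning_rate(step, total_steps, peak_lr, min_lr, warmup_steps):
    """Linear re-warming up to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Ramp linearly so the first update is small but non-zero
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# 50,000 steps with 1% warmup; peak 3e-5, assumed floor of 3e-6
for step in (0, 500, 25_250, 50_000):
    lr = cpt_learning_rate(step, 50_000, 3e-5, 3e-6, 500)
    print(f"step {step:>6}: lr = {lr:.2e}")
```

The learning rate peaks exactly when warmup ends and reaches the floor at the final step, so a checkpoint taken mid-run has been trained with a still-substantial learning rate.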
Compute Cost Estimation
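The compute required for CPT scales linearly with corpus size and model parameters:
$$C \approx 6 \cdot N \cdot D \ \text{FLOPs}$$
where $N$ is the number of model parameters and $D$ is the number of training tokens. For an 8B-parameter model trained on 50B domain tokens:
$$C \approx 6 \times (8 \times 10^{9}) \times (50 \times 10^{9}) = 2.4 \times 10^{21}\ \text{FLOPs}$$
On an NVIDIA A100 80GB (assuming ~300 TFLOPS sustained with bf16, which is close to the card's 312 TFLOPS peak, so treat the result as a lower bound), this takes approximately:
$$\frac{2.4 \times 10^{21}}{3 \times 10^{14}\ \text{FLOP/s}} = 8 \times 10^{6}\ \text{s} \approx 2{,}222\ \text{GPU-hours}$$
On a cluster of 32 A100 GPUs: approximately 69 hours (about 3 days). At a cloud rate of USD 2 per GPU-hour, that comes to approximately USD 4,444 (~INR 3.7 lakh).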
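The estimate above can be reproduced with a few lines of arithmetic. This is a rough sketch using the 6 x parameters x tokens rule; the function name is hypothetical, and the price of 2 USD per GPU-hour is the rate implied by dividing the quoted 4,444 total by the GPU-hours, not a quoted vendor price.

```python
def estimate_cpt_cost(params, tokens, flops_per_gpu, num_gpus, usd_per_gpu_hour):
    """Rough CPT cost estimate from the 6 * N * D FLOPs approximation."""
    total_flops = 6 * params * tokens
    gpu_hours = total_flops / flops_per_gpu / 3600
    return {
        "total_flops": total_flops,
        "gpu_hours": gpu_hours,
        "wall_clock_hours": gpu_hours / num_gpus,
        "cost_usd": gpu_hours * usd_per_gpu_hour,
    }


# 8B parameters, 50B tokens, 300 TFLOPS per GPU (optimistic), 32 GPUs
estimate = estimate_cpt_cost(8e9, 50e9, 300e12, 32, 2.0)
for key, value in estimate.items():
    print(f"{key}: {value:,.1f}")
```

Because sustained throughput on real clusters is usually well below 300 TFLOPS per A100, lowering `flops_per_gpu` gives a more conservative budget.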
Catastrophic Forgetting Formalization
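Catastrophic forgetting can be formalized as performance degradation on a general benchmark after CPT:
$$\Delta_{\text{forget}} = \text{Perf}_{\text{general}}(\theta_0) - \text{Perf}_{\text{general}}(\theta_{\text{CPT}})$$
where $\theta_0$ are the pretrained weights and $\theta_{\text{CPT}}$ are the weights after continued pretraining. Strategies to minimize $\Delta_{\text{forget}}$ include:
- Data replay: Including general data in the training mix (controlled by the mixing ratio $\alpha$)
- Elastic Weight Consolidation (EWC): Adding a regularization term $\frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta_{0,i}\right)^2$ where $F_i$ is the Fisher information for parameter $\theta_i$
- Low learning rate: Using $\eta_{\text{CPT}} \ll \eta_{\text{pre}}$ (a CPT learning rate far below the pretraining peak) to constrain parameter drift
- Gradient projection: Projecting gradients to avoid interfering with important directions for general capabilities
Practical Rule: For most CPT setups, a mixing ratio of $\alpha = 0.7$ (70% domain, 30% general replay) with a learning rate of roughly $0.1\,\eta_{\text{pre}}$ (about 10x below the pretraining peak) provides a good starting point. Monitor both domain perplexity (should decrease) and general benchmark scores (should not decrease by more than 2-3%).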
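To make the forgetting measure and the EWC penalty concrete, here is a small sketch in plain Python. The parameter vectors, Fisher values, benchmark scores, and the strength `lam` are all illustrative assumptions; a real implementation would operate on model tensors and estimate the Fisher information from gradients on general data.

```python
def forgetting(general_score_before, general_score_after):
    """Drop in general benchmark performance caused by CPT (positive = forgetting)."""
    return general_score_before - general_score_after


def ewc_penalty(theta, theta_pretrained, fisher, lam):
    """EWC term: (lam / 2) * sum_i F_i * (theta_i - theta0_i)^2."""
    return 0.5 * lam * sum(
        f * (t - t0) ** 2 for t, t0, f in zip(theta, theta_pretrained, fisher)
    )


# Hypothetical benchmark scores before and after CPT
print(f"Forgetting: {forgetting(0.65, 0.63):.3f}")

# Two toy parameters: the second has high Fisher information and has drifted
penalty = ewc_penalty(
    theta=[1.0, 2.0], theta_pretrained=[1.0, 1.0], fisher=[0.5, 2.0], lam=10.0
)
print(f"EWC penalty: {penalty}")
```

The penalty grows only for parameters that both moved and matter for general capabilities (high Fisher information), which is what lets EWC constrain drift selectively.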
Internal Architecture
The architecture of a continued pretraining pipeline involves several interconnected stages, from corpus preparation through training execution to validation. Unlike initial pretraining, CPT requires careful management of the relationship between the pretrained model's existing knowledge and the new domain data.
The following diagram illustrates the end-to-end CPT pipeline, showing how raw domain data flows through curation, tokenization, mixing with general replay data, and into the training loop. The key architectural decision points are the data mixing ratio, learning rate schedule, and checkpoint validation strategy.

The validation stage is particularly critical for CPT: you need to track both domain-specific improvement (decreasing perplexity on held-out domain text) and general capability retention (stable scores on benchmarks like MMLU, HellaSwag). The best checkpoint is typically not the final one -- it's the one that maximizes domain performance while keeping general regression below a threshold.
Key Components
Domain Corpus Curator
Responsible for sourcing, cleaning, deduplicating, and filtering the domain-specific training data. This is the most labor-intensive component of the CPT pipeline. Key operations include MinHash deduplication (removing near-duplicate documents that inflate the corpus without adding information), perplexity-based quality filtering (using a small reference model to filter out noisy, machine-generated, or off-topic text), and toxicity/PII removal (ensuring the corpus doesn't inject harmful biases into the model). For Indian language corpora, this also involves script normalization and transliteration handling.
Tokenizer Adapter
Handles tokenizer adaptation when the domain vocabulary differs significantly from the pretraining vocabulary. For example, a model pretrained on English web text will tokenize Hindi text into many small subword fragments (3-4 tokens per Hindi word vs. 1 token per English word). The tokenizer adapter can extend the vocabulary with domain-specific tokens and initialize their embeddings via averaging or interpolation from existing token embeddings. This component is optional but critical for cross-lingual CPT and specialized domains like chemistry (SMILES notation) or law (statute references).
Data Mixer
Combines domain-specific data with general replay data according to a configurable mixing ratio $\alpha$ (the fraction of training data drawn from the domain corpus). The mixer must handle different data formats, ensure balanced sampling across domain sub-categories (e.g., for medical CPT: clinical notes, research papers, drug labels, patient education material), and optionally implement curriculum learning -- starting with a higher proportion of general data and gradually increasing domain data concentration. The DoReMi algorithm can be used to automatically optimize these mixing weights.
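The curriculum idea can be sketched as a schedule for the domain share plus a per-example sampler. The start and end fractions (0.3 and 0.7) and the ramp length (first 20% of training) are illustrative assumptions rather than recommended values, and the function names are hypothetical.

```python
import random


def domain_fraction(step, total_steps, start=0.3, end=0.7, ramp_fraction=0.2):
    """Linearly ramp the domain share from start to end, then hold it at end."""
    ramp_steps = max(1, int(total_steps * ramp_fraction))
    if step >= ramp_steps:
        return end
    return start + (end - start) * step / ramp_steps


def sample_source(step, total_steps, rng):
    """Pick which corpus the next training example is drawn from."""
    if rng.random() < domain_fraction(step, total_steps):
        return "domain"
    return "general"


rng = random.Random(0)
for step in (0, 100, 200, 999):
    print(step, round(domain_fraction(step, 1000), 3), sample_source(step, 1000, rng))
```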
Learning Rate Scheduler
Controls the learning rate trajectory during CPT. Unlike initial pretraining (which uses a standard warmup + cosine decay from a high LR), CPT typically starts from a much lower LR (1-10% of the initial pretraining peak) with optional brief re-warming. The scheduler must handle the transition from the pretrained model's final learning rate state to the CPT regime. Research suggests that cosine annealing is preferred over WSD (Warmup-Stable-Decay) for CPT when the base model was trained with cosine annealing.
Distributed Training Engine
Manages the actual training computation across multiple GPUs/nodes. For CPT of 7B+ parameter models, this requires data parallelism (FSDP or DeepSpeed ZeRO) and optionally tensor or pipeline parallelism. Key frameworks include Megatron-LM (NVIDIA's high-performance training framework), nanotron (Hugging Face's minimalist parallelism library), and torchtitan (PyTorch-native distributed training). The engine handles gradient accumulation, mixed-precision training (bf16), and checkpoint saving.
Forgetting Monitor
Continuously evaluates the model on general-capability benchmarks during training to detect catastrophic forgetting early. Runs evaluation on a subset of benchmarks (e.g., MMLU, ARC, HellaSwag) every few thousand steps. If general performance drops by more than a set threshold (typically 2-3%), the monitor can trigger corrective actions: increasing the general data mixing ratio, reducing the learning rate, or applying EWC regularization. This early warning system prevents wasting compute on a forgetting trajectory.
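The core check of such a monitor fits in a few lines. In this sketch the threshold is interpreted as a relative drop against the base model's score, which is an assumption (the text does not say whether 2-3% is relative or absolute), and the benchmark scores are hypothetical.

```python
def check_forgetting(baseline, current, max_relative_drop=0.03):
    """Return benchmarks whose score fell by more than max_relative_drop."""
    regressions = {}
    for name, base_score in baseline.items():
        drop = (base_score - current[name]) / base_score
        if drop > max_relative_drop:
            regressions[name] = drop
    return regressions


# Hypothetical scores: base model vs. a mid-training CPT checkpoint
baseline = {"mmlu": 0.650, "arc": 0.800, "hellaswag": 0.790}
current = {"mmlu": 0.620, "arc": 0.795, "hellaswag": 0.788}

regressions = check_forgetting(baseline, current)
if regressions:
    print(f"Forgetting detected: {regressions}")
    print("Consider raising the general data share or lowering the learning rate.")
```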
Checkpoint Selector
Evaluates saved checkpoints on both domain-specific and general benchmarks to identify the optimal stopping point. CPT training loss alone is not sufficient for checkpoint selection because a model can achieve low domain perplexity while having catastrophically forgotten general knowledge. The selector computes a composite score that balances domain improvement against general regression, enabling principled selection of the best checkpoint for downstream use.
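One possible composite score is sketched below: relative domain perplexity improvement minus a weighted penalty for general regression, with checkpoints beyond a hard regression limit disqualified. The weighting, the 3% limit, and all metric values are illustrative assumptions, not a standard formula.

```python
def composite_score(domain_ppl_base, domain_ppl_ckpt, general_base, general_ckpt,
                    penalty_weight=2.0, max_regression=0.03):
    """Relative perplexity gain minus a penalty for general-benchmark regression."""
    domain_gain = (domain_ppl_base - domain_ppl_ckpt) / domain_ppl_base
    regression = max(0.0, (general_base - general_ckpt) / general_base)
    if regression > max_regression:
        return float("-inf")  # Disqualify checkpoints that forgot too much
    return domain_gain - penalty_weight * regression


def select_checkpoint(checkpoints, domain_ppl_base, general_base):
    """Return the step of the checkpoint with the highest composite score."""
    best = max(
        checkpoints,
        key=lambda c: composite_score(
            domain_ppl_base, c["domain_ppl"], general_base, c["general"]
        ),
    )
    return best["step"]


# Hypothetical checkpoint metrics (base model: perplexity 12.0, general score 0.65)
checkpoints = [
    {"step": 10_000, "domain_ppl": 9.0, "general": 0.648},
    {"step": 20_000, "domain_ppl": 8.0, "general": 0.640},
    {"step": 30_000, "domain_ppl": 7.6, "general": 0.615},
]
print(select_checkpoint(checkpoints, domain_ppl_base=12.0, general_base=0.65))
```

In this toy data the final checkpoint has the lowest domain perplexity but exceeds the regression limit, so an earlier checkpoint wins, which mirrors the point that the last checkpoint is often not the best one.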
Data Flow
Corpus Preparation Phase: Raw domain text is collected from domain-specific sources (academic papers, legal documents, clinical notes, code repositories, etc.). The text passes through deduplication (MinHash for near-duplicate removal), quality filtering (perplexity scoring against a reference model, heuristic rules for formatting), and optionally toxicity/PII filtering. The cleaned text is tokenized -- either with the original tokenizer or an adapted tokenizer with domain-specific vocabulary extensions.
Training Phase: The tokenized domain corpus is combined with a general replay corpus via the data mixer. Each training batch contains a mix of domain and general tokens at the configured ratio $\alpha$. The pretrained model processes these batches through the standard autoregressive forward pass, computing cross-entropy loss on next-token predictions. Gradients are computed, scaled by the learning rate from the cosine schedule, and applied to update all model parameters. Gradient accumulation is used when the target effective batch size does not fit in GPU memory in a single forward-backward pass. Checkpoints are saved periodically (every 1,000-5,000 steps).
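The relationship between micro-batch size, gradient accumulation, and the token budget is simple arithmetic, sketched here with hypothetical helper names. The numbers mirror the training configuration used in this guide's Hugging Face example (2 sequences per GPU, 32 accumulation steps, 8 GPUs, 2,048-token sequences), and the calculation assumes every sequence is filled to its full length, which overstates the count when documents are shorter and not packed.

```python
import math


def tokens_per_step(per_device_batch, grad_accum_steps, num_gpus, seq_len):
    """Tokens consumed by one optimizer step across all GPUs."""
    return per_device_batch * grad_accum_steps * num_gpus * seq_len


def steps_for_budget(token_budget, tokens_per_optimizer_step):
    """Optimizer steps needed to consume a given token budget."""
    return math.ceil(token_budget / tokens_per_optimizer_step)


step_tokens = tokens_per_step(
    per_device_batch=2, grad_accum_steps=32, num_gpus=8, seq_len=2048
)
print(f"Tokens per optimizer step: {step_tokens:,}")
print(f"Steps for 50B tokens: {steps_for_budget(50_000_000_000, step_tokens):,}")
```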
Validation Phase: At each checkpoint, the forgetting monitor evaluates domain perplexity (lower is better) and general benchmark scores (should remain stable). The checkpoint selector aggregates these metrics to rank checkpoints. The best checkpoint -- typically from the middle-to-late phase of training -- is selected as the final domain-adapted model. This model then proceeds to fine-tuning (SFT, LoRA, RLHF) for task-specific adaptation.
A flowchart showing four stages: (1) Data Preparation -- raw domain corpus flows through deduplication, quality filtering, and tokenization. (2) Data Mixing -- tokenized domain data is combined with general replay data at a configurable ratio. (3) CPT Training Loop -- the pretrained base model is trained on the mixed corpus with a cosine-decaying learning rate, producing periodic checkpoints. (4) Validation -- each checkpoint is evaluated on domain perplexity and general benchmarks, with the best checkpoint selected as the final domain-adapted model.
How to Implement
Two Primary Implementation Approaches
There are two main ways to implement continued pretraining in practice:
Approach 1: Hugging Face Transformers + DeepSpeed/FSDP -- The most accessible option for teams already using the Hugging Face ecosystem. You load a pretrained model, configure a Trainer with a domain dataset and appropriate hyperparameters, and leverage DeepSpeed ZeRO or PyTorch FSDP for multi-GPU training. This approach works well for models up to 13B parameters on 2-8 GPUs.
Approach 2: Megatron-LM / nanotron -- For larger models (30B+) or when you need maximum training throughput, dedicated pretraining frameworks offer optimized kernels, tensor parallelism, and pipeline parallelism. NVIDIA's Megatron-LM is the industry standard for high-performance pretraining, while Hugging Face's nanotron provides a more Pythonic alternative with similar performance.
The critical implementation decisions for CPT are:
- Data mixing: How to combine domain and general data, and at what ratio
- Learning rate: Starting LR, warmup strategy, and decay schedule
- Tokenizer: Whether to adapt the tokenizer for domain vocabulary
- Evaluation: How to monitor for catastrophic forgetting during training
Cost Note: Continued pretraining of Llama 3 8B on 50B tokens takes approximately 2,200 GPU-hours on A100 80GB. At cloud rates, that's ~USD 3,300 (~INR 2.8 lakh) on Indian cloud providers like E2E Networks. For comparison, the initial pretraining of Llama 3 8B on 15T tokens cost an estimated ~USD 2M. CPT at 50B tokens is roughly 0.2% of the original pretraining cost for meaningful domain adaptation.
```python
from datasets import interleave_datasets, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Load base model and tokenizer
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The base tokenizer has no pad token; reuse EOS so batches can be padded
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # Requires the flash-attn package
)

# Load and prepare domain corpus
domain_dataset = load_dataset(
    "json",
    data_files="/data/medical_corpus_cleaned/*.jsonl",
    split="train",
)

# Load general replay data (for catastrophic forgetting mitigation)
general_dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True,
)


# Tokenization function
def tokenize_fn(examples, max_length=2048):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
        return_special_tokens_mask=True,
    )


# Hold out a small slice of domain text so eval_strategy="steps" has data to score
domain_split = domain_dataset.train_test_split(test_size=0.005, seed=42)
eval_tokenized = domain_split["test"].map(
    tokenize_fn,
    batched=True,
    remove_columns=domain_dataset.column_names,
)

# interleave_datasets cannot mix a map-style Dataset with a streaming one,
# so convert the domain training split to an iterable dataset and tokenize lazily
domain_tokenized = (
    domain_split["train"]
    .to_iterable_dataset(num_shards=8)
    .map(tokenize_fn, batched=True, remove_columns=domain_dataset.column_names)
)

# Keep only the text column of the replay stream, then tokenize it the same way
general_tokenized = general_dataset.select_columns(["text"]).map(
    tokenize_fn,
    batched=True,
    remove_columns=["text"],
)

# Interleave domain (70%) and general (30%) data
mixed_dataset = interleave_datasets(
    [domain_tokenized, general_tokenized],
    probabilities=[0.7, 0.3],  # 70% domain, 30% general replay
    seed=42,
)

# Data collator for causal LM (next-token prediction)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM, not masked LM
)

# Training arguments for CPT
training_args = TrainingArguments(
    output_dir="./llama3-8b-medical-cpt",
    max_steps=50000,  # ~50B tokens with batch size ~1M tokens/step
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,  # Effective batch = 2 * 32 * 8 GPUs = 512
    learning_rate=3e-5,  # ~10x lower than Llama 3 pretraining LR (3e-4)
    lr_scheduler_type="cosine",
    warmup_steps=500,  # Brief re-warming (1% of total steps)
    weight_decay=0.1,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=2500,
    eval_strategy="steps",
    eval_steps=2500,
    dataloader_num_workers=4,
    deepspeed="ds_config_zero3.json",  # DeepSpeed ZeRO Stage 3
)

# Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=mixed_dataset,
    eval_dataset=eval_tokenized,
    data_collator=data_collator,
    processing_class=tokenizer,  # Use tokenizer=tokenizer on transformers < 4.46
)
trainer.train()

# Save the domain-adapted model
trainer.save_model("./llama3-8b-medical-cpt-final")
tokenizer.save_pretrained("./llama3-8b-medical-cpt-final")
```
This is a starting recipe for continued pretraining using Hugging Face Transformers with DeepSpeed; validate it on a short run before committing a full compute budget. Key decisions:
- learning_rate=3e-5: Llama 3 was pretrained with a peak LR of ~3e-4. We use 10x lower (3e-5) for CPT to avoid destructive updates to the pretrained representations.
- 70/30 domain/general mix: interleave_datasets with probabilities=[0.7, 0.3] draws each training example from the domain corpus with probability 0.7 and from the general replay data with probability 0.3, so batches are about 70% domain on average. The replay share is there to prevent catastrophic forgetting.
- warmup_steps=500: Brief re-warming (1% of total steps) helps the optimizer adapt to the new data distribution without large initial loss spikes.
- DeepSpeed ZeRO Stage 3: Shards model parameters, gradients, and optimizer states across GPUs, enabling 8B model CPT on 8x A100 80GB GPUs with reasonable batch sizes.
- gradient_checkpointing=True: Reduces activation memory at the cost of ~30% slower training -- essential for fitting long sequences.
- flash_attention_2: Uses Flash Attention for memory-efficient and faster attention computation.
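The training script references a `ds_config_zero3.json` file that is not shown. A minimal sketch of what such a file might contain is generated below; the specific option values are assumptions chosen to match the settings above (bf16, ZeRO Stage 3), and the "auto" entries rely on the Hugging Face Trainer integration filling them in from `TrainingArguments`.

```python
import json

# Minimal ZeRO Stage 3 configuration (illustrative, not a tuned setup)
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,  # Shard parameters, gradients, and optimizer states
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Gather full 16-bit weights when saving so the checkpoint loads normally
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    # "auto" lets the Trainer copy these values from TrainingArguments
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_config_zero3.json", "w") as f:
    json.dump(ds_config, f, indent=2)

print(json.dumps(ds_config["zero_optimization"], indent=2))
```

Offloading optimizer state or parameters to CPU is also possible with ZeRO Stage 3 when GPU memory is tight, at the cost of throughput; it is left out here to keep the sketch small.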
```python
import hashlib
import json
from dataclasses import dataclass
from typing import Iterator

import numpy as np
import torch
from datasketch import MinHash, MinHashLSH
from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class Document:
    text: str
    source: str
    language: str
    quality_score: float = 0.0


class CorpusCurator:
    """Pipeline for curating domain corpus for continued pretraining."""

    def __init__(
        self,
        reference_model_name: str = "gpt2",  # Small model for perplexity scoring
        quality_threshold: float = 50.0,  # Max perplexity for quality filter
        min_doc_length: int = 100,  # Minimum tokens per document
        max_doc_length: int = 100000,  # Maximum tokens per document
        dedup_threshold: float = 0.8,  # MinHash Jaccard similarity threshold
    ):
        self.quality_threshold = quality_threshold
        self.min_doc_length = min_doc_length
        self.max_doc_length = max_doc_length
        self.dedup_threshold = dedup_threshold
        # Load reference model for perplexity scoring
        self.tokenizer = AutoTokenizer.from_pretrained(reference_model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            reference_model_name,
            torch_dtype=torch.float16,
            device_map="auto",  # Requires the accelerate package
        )
        self.model.eval()
        # MinHash LSH index for deduplication
        self.lsh = MinHashLSH(threshold=dedup_threshold, num_perm=128)

    def compute_perplexity(self, text: str, max_length: int = 512) -> float:
        """Compute perplexity of text (first max_length tokens) using reference model."""
        encodings = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=max_length
        )
        input_ids = encodings.input_ids.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(input_ids, labels=input_ids)
        return float(np.exp(outputs.loss.item()))

    def compute_minhash(self, text: str) -> MinHash:
        """Compute MinHash signature for near-duplicate detection."""
        mh = MinHash(num_perm=128)
        # Use 5-gram word shingles (a single shingle if the text has fewer than 5 words)
        words = text.lower().split()
        for i in range(max(1, len(words) - 4)):
            shingle = " ".join(words[i : i + 5])
            mh.update(shingle.encode("utf-8"))
        return mh

    def is_duplicate(self, doc_id: str, minhash: MinHash) -> bool:
        """Check if document is a near-duplicate of any seen document."""
        if self.lsh.query(minhash):
            return True
        self.lsh.insert(doc_id, minhash)
        return False

    def curate(self, documents: Iterator[Document]) -> Iterator[Document]:
        """Full curation pipeline: length filter -> dedup -> quality filter."""
        stats = {"total": 0, "length_filtered": 0, "deduped": 0, "quality_filtered": 0, "kept": 0}
        for doc in documents:
            stats["total"] += 1
            # Step 1: Length filter
            token_count = len(self.tokenizer.encode(doc.text))
            if token_count < self.min_doc_length or token_count > self.max_doc_length:
                stats["length_filtered"] += 1
                continue
            # Step 2: Near-duplicate removal
            # Hash the full text so distinct documents sharing a prefix get distinct LSH keys
            doc_id = hashlib.md5(doc.text.encode("utf-8")).hexdigest()
            minhash = self.compute_minhash(doc.text)
            if self.is_duplicate(doc_id, minhash):
                stats["deduped"] += 1
                continue
            # Step 3: Quality filtering via perplexity
            perplexity = self.compute_perplexity(doc.text)
            if perplexity > self.quality_threshold:
                stats["quality_filtered"] += 1
                continue
            doc.quality_score = 1.0 / perplexity  # Higher is better
            stats["kept"] += 1
            yield doc
        print(f"Curation stats: {json.dumps(stats, indent=2)}")


def read_documents(path: str) -> Iterator[Document]:
    """Yield one Document per non-empty line of a text file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield Document(text=line.strip(), source="pubmed", language="en")


# Usage
curator = CorpusCurator(
    reference_model_name="gpt2",
    quality_threshold=80.0,
    dedup_threshold=0.8,
)

# Example: curate medical documents
raw_docs = read_documents("/data/raw_medical_text.txt")
with open("/data/medical_corpus_curated.jsonl", "w", encoding="utf-8") as f:
    for doc in curator.curate(raw_docs):
        f.write(json.dumps({"text": doc.text, "source": doc.source}) + "\n")
```
This corpus curation pipeline implements the three critical stages of data preparation for CPT:
- Length filtering: Removes documents that are too short (likely noise) or too long (likely data dumps). The thresholds (100-100K tokens) should be adjusted per domain.
- Near-duplicate removal via MinHash LSH: Uses locality-sensitive hashing to efficiently detect near-duplicate documents. This is essential because web-scraped domain corpora often contain many copies of the same content (e.g., the same medical guideline published on 50 different websites). Deduplication can remove 20-40% of a raw corpus while improving training efficiency.
- Perplexity-based quality filtering: Uses a small reference model (GPT-2) to score document quality. Documents with very high perplexity are likely noisy, machine-generated, or not genuine domain text. Perplexity filtering against a reference language model was popularized by the CCNet pipeline. Because a general-purpose reference model can also assign high perplexity to legitimate jargon-heavy text, calibrate the threshold on a sample of known-good domain documents.
The pipeline is designed for streaming (Iterator-based) to handle corpora that don't fit in memory.
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import Counter
import torch
import json
def analyze_tokenizer_efficiency(tokenizer, domain_texts: list[str]) -> dict:
"""Measure how efficiently the tokenizer handles domain text."""
total_chars = 0
total_tokens = 0
oov_fragments = Counter() # Track frequently split words
for text in domain_texts:
total_chars += len(text)
tokens = tokenizer.encode(text)
total_tokens += len(tokens)
# Identify words that get split into many subwords
words = text.split()
for word in words:
subtokens = tokenizer.encode(word, add_special_tokens=False)
if len(subtokens) > 3: # Word split into 4+ tokens
oov_fragments[word] += 1
return {
"chars_per_token": total_chars / total_tokens,
"fertility": total_tokens / (total_chars / 5), # Approx words
"top_fragmented_words": oov_fragments.most_common(50),
}


def extend_tokenizer(
    base_tokenizer_name: str,
    domain_texts: list[str],
    new_vocab_size: int = 1000,
    output_dir: str = "./extended_tokenizer",
) -> tuple[AutoTokenizer, int]:
    """Extend the tokenizer with frequent, heavily fragmented domain words.

    Returns the extended tokenizer and the number of tokens actually added.
    """
    base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)
    original_vocab_size = len(base_tokenizer)

    # Count words that the base tokenizer splits into 3+ subword tokens
    word_freq = Counter()
    for text in domain_texts:
        for word in text.split():
            subtokens = base_tokenizer.encode(word, add_special_tokens=False)
            if len(subtokens) >= 3:  # Only consider heavily fragmented words
                word_freq[word] += 1

    # Add the most frequent candidates (seen at least 10 times) as whole tokens
    new_tokens = [word for word, count in word_freq.most_common(new_vocab_size) if count >= 10]
    num_added = base_tokenizer.add_tokens(new_tokens)
    print(
        f"Added {num_added} domain-specific tokens to vocabulary "
        f"({original_vocab_size} -> {len(base_tokenizer)})"
    )
    base_tokenizer.save_pretrained(output_dir)
    return base_tokenizer, num_added


def resize_model_embeddings(
    model: AutoModelForCausalLM,
    new_vocab_size: int,
    init_strategy: str = "mean",  # 'mean', 'random', or 'nearest'
) -> AutoModelForCausalLM:
    """Resize model embeddings and initialize new token embeddings."""
    old_vocab_size = model.get_input_embeddings().weight.shape[0]

    # Resize the input embeddings and, if present, the output head
    model.resize_token_embeddings(new_vocab_size)
    input_weight = model.get_input_embeddings().weight.data
    output_layer = model.get_output_embeddings()

    with torch.no_grad():
        if init_strategy == "mean":
            # Initialize new rows as the mean of the pretrained rows
            input_weight[old_vocab_size:] = input_weight[:old_vocab_size].mean(dim=0)
            # An untied output head gets the mean of its own pretrained rows
            if output_layer is not None and output_layer.weight.data_ptr() != input_weight.data_ptr():
                output_weight = output_layer.weight.data
                output_weight[old_vocab_size:] = output_weight[:old_vocab_size].mean(dim=0)
        elif init_strategy == "nearest":
            # Placeholder: averages 5 randomly sampled pretrained rows. A true
            # nearest-neighbour init would need each new token's subword pieces.
            for i in range(old_vocab_size, new_vocab_size):
                random_idx = torch.randint(0, old_vocab_size, (5,))
                input_weight[i] = input_weight[random_idx].mean(dim=0)
        elif init_strategy != "random":
            raise ValueError(f"Unknown init_strategy: {init_strategy!r}")
        # 'random' keeps whatever resize_token_embeddings() assigned

    print(f"Resized embeddings: {old_vocab_size} -> {new_vocab_size}")
    return model


# Example usage
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# Load domain texts
with open("/data/medical_sample.txt") as f:
    medical_texts = f.readlines()

# Analyze current tokenizer efficiency
stats = analyze_tokenizer_efficiency(tokenizer, medical_texts)
print(f"Current chars/token: {stats['chars_per_token']:.2f}")
print(f"Top fragmented words: {stats['top_fragmented_words'][:10]}")

# Extend tokenizer with medical vocabulary
extended_tokenizer, n_new = extend_tokenizer(
    "meta-llama/Llama-3.1-8B",
    medical_texts,
    new_vocab_size=500,
)

# Resize model embeddings
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = resize_model_embeddings(
    model,
    len(extended_tokenizer),
    init_strategy="mean",
)

Tokenizer adaptation is often overlooked in CPT but can significantly impact both training efficiency and model quality. This code demonstrates three critical operations:
- Efficiency analysis: Measures how well the base tokenizer handles domain text. A general English tokenizer applied to Hindi text might produce 3-4 tokens per word (fertility > 2.0), compared to 1.2-1.5 for English. High fertility means the model needs more tokens to process the same information, increasing compute cost and reducing effective context length.
- Vocabulary extension: Identifies frequently fragmented domain terms (e.g., medical terms like "electrocardiography" or Hindi words in Latin script) and adds them as single tokens. Adding 500-1000 domain tokens can reduce token count by 10-20%, directly reducing training cost.
- Embedding initialization: New token embeddings must be initialized carefully. Mean initialization (averaging all existing embeddings) is the safest default. The model will learn proper embeddings during CPT. Never use random initialization -- it creates large gradient spikes that can destabilize early training.
# Continued pretraining configuration (YAML format)
model:
  name: meta-llama/Llama-3.1-8B
  dtype: bfloat16
  attention: flash_attention_2

data:
  domain_corpus: /data/medical_corpus_curated/
  general_replay: HuggingFaceFW/fineweb-edu
  mixing_ratio: 0.7  # 70% domain, 30% general
  max_seq_length: 4096
  num_workers: 8

tokenizer:
  extend_vocab: false  # Set true if domain needs new tokens
  new_tokens: 0

training:
  total_tokens: 50_000_000_000  # 50B tokens
  per_device_batch_size: 2
  gradient_accumulation_steps: 32
  effective_batch_size: 512  # per_device * grad_accum * num_gpus
  learning_rate: 3e-5  # 10x lower than pretraining LR
  min_learning_rate: 3e-6
  lr_scheduler: cosine
  warmup_steps: 500
  weight_decay: 0.1
  gradient_checkpointing: true
  bf16: true

distributed:
  backend: deepspeed
  zero_stage: 3
  num_gpus: 8

evaluation:
  eval_steps: 2500
  domain_metrics:
    - domain_perplexity
    - medical_qa_accuracy
  general_metrics:
    - mmlu_5shot
    - hellaswag
    - arc_challenge
  forgetting_threshold: 0.03  # Alert if general drops > 3%

checkpointing:
  save_steps: 2500
  keep_last_n: 5

Common Implementation Mistakes
- Using the initial pretraining learning rate: The most common and destructive mistake. If Llama 3 was pretrained with LR=3e-4, using the same LR for CPT will catastrophically overwrite the pretrained representations. Use 10-100x lower LR (3e-6 to 3e-5) for CPT. When in doubt, start lower and increase.
- Training on pure domain data without replay: Without mixing in 20-30% general data, the model will catastrophically forget its general capabilities within a few thousand steps. The model becomes a domain savant that can't perform basic tasks. Always include a general replay buffer.
- Insufficient deduplication of the domain corpus: Web-scraped domain corpora are notoriously duplicated. Training on a corpus where 30% of documents are near-duplicates wastes compute and can cause the model to memorize specific passages rather than learn generalizable domain knowledge. Always run MinHash deduplication before CPT.
- Extending the tokenizer without sufficient CPT: If you add 1,000 new tokens to the vocabulary, those tokens start with random or mean-initialized embeddings. You need enough continued pretraining (at minimum 1B tokens) for the model to learn proper representations for the new tokens. Adding tokens and then only fine-tuning on 50K examples will produce garbage for those tokens.
- Ignoring evaluation during training: CPT without periodic evaluation on general benchmarks is flying blind. By the time you notice catastrophic forgetting post-training, you've wasted all the compute. Evaluate every 2,000-5,000 steps on both domain and general metrics.
- Using too small a corpus: CPT needs substantial data to be effective. For an 8B model, a minimum of 5-10B domain tokens is recommended. With less than 1B tokens, you're better off using LoRA fine-tuning on domain instruction data. CPT with insufficient data overfits to the small corpus and provides minimal benefit.
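The thresholds above are stated partly in tokens (1B, 5-10B) and partly in optimizer steps (2,000-5,000), so it helps to convert between the two. The sketch below does that with the batch settings from the configuration above; it assumes every sequence is packed to the full `max_seq_length`, and the function names are illustrative, not part of any library.

```python
import math

# Values copied from the configuration above
PER_DEVICE_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 32
NUM_GPUS = 8
MAX_SEQ_LENGTH = 4096
TOTAL_TOKENS = 50_000_000_000


def tokens_per_step(per_device: int, grad_accum: int, num_gpus: int, seq_len: int) -> int:
    """Tokens consumed by one optimizer step, assuming fully packed sequences."""
    return per_device * grad_accum * num_gpus * seq_len


def steps_for_tokens(token_budget: int, step_tokens: int) -> int:
    """Optimizer steps needed to consume a token budget (rounded up)."""
    return math.ceil(token_budget / step_tokens)


step_tokens = tokens_per_step(PER_DEVICE_BATCH_SIZE, GRAD_ACCUM_STEPS, NUM_GPUS, MAX_SEQ_LENGTH)
total_steps = steps_for_tokens(TOTAL_TOKENS, step_tokens)

print(f"Tokens per optimizer step: {step_tokens:,}")
print(f"Steps for the 50B-token run: {total_steps:,}")
print(f"Steps to cover 1B tokens: {steps_for_tokens(1_000_000_000, step_tokens):,}")
print(f"Tokens between evals (every 2,500 steps): {2_500 * step_tokens:,}")
```

With these settings one step covers about 2.1M tokens, so the 50B-token budget is roughly 23,800 steps and an evaluation every 2,500 steps corresponds to a little over 5B tokens of training.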
When Should You Use This?
Use When
You need the model to deeply understand domain-specific terminology and concepts -- not just pattern-match on keywords but reason about domain relationships (e.g., understanding that drug interactions affect dosage recommendations in medical contexts)
Your domain has a large unlabeled corpus (10B+ tokens) that is substantially different from general web text -- legal documents, scientific papers, financial filings, code in specific frameworks, multilingual text
You need better performance than RAG for tasks requiring deep domain reasoning, where the context window is insufficient to include all relevant knowledge at inference time
You're building a foundation model for a specific vertical that will be fine-tuned for multiple downstream tasks within that domain (e.g., a medical foundation model for diagnosis, drug interaction, clinical note summarization)
The target domain has specialized vocabulary that the base tokenizer handles poorly -- Indian languages, chemical formulas, legal citations, code in niche programming languages
You have the compute budget for training on billions of tokens -- CPT is not cheap, but it's 100-1000x cheaper than pretraining from scratch
Your task requires the model to generate domain-fluent text rather than just answer questions -- domain fluency requires the kind of deep representation alignment that only CPT provides
Avoid When
Your domain corpus is small (<1B tokens) -- LoRA fine-tuning or RAG will be more cost-effective and less risky than CPT on insufficient data
You need factual recall of specific documents rather than general domain understanding -- RAG is better suited for retrieving and referencing specific passages from a corpus
Your budget is tightly constrained and you need results quickly -- CPT requires significant compute (thousands of GPU-hours) and careful hyperparameter tuning, while RAG or LoRA can be deployed in hours
The domain is rapidly changing (e.g., current news, stock prices) -- CPT bakes knowledge into weights as a static snapshot, while RAG can incorporate real-time information
The base model already performs well on your domain tasks -- test the base model first; if it achieves 90%+ of target performance with good prompting, CPT may offer diminishing returns
You lack infrastructure for distributed training -- CPT of 7B+ models requires multi-GPU setups with FSDP or DeepSpeed, which adds operational complexity
Key Tradeoffs
The CPT Investment Calculus
Continued pretraining is the most compute-intensive model adaptation technique short of pretraining from scratch. The core tradeoff is upfront investment vs. downstream returns.
| Approach | Compute Cost | Knowledge Depth | Forgetting Risk | Time to Deploy | Best For |
|---|---|---|---|---|---|
| Prompting | ~$0 | Surface level | None | Hours | Quick prototyping |
| RAG | ~$100-1K | Access to documents | None | Days | Factual recall, dynamic knowledge |
| LoRA Fine-tuning | ~$50-500 | Task-specific behavior | Low | Days | Instruction following, formatting |
| CPT | ~$2K-50K | Deep domain understanding | Medium | Weeks | Domain foundation models |
| Pretraining from scratch | ~$100K-10M+ | Full control | N/A | Months | Novel architectures, unique data |
Data Quality vs. Quantity
A critical tradeoff in CPT is data quality versus quantity. BloombergGPT trained on 363B financial tokens -- a massive corpus. But research shows that curated smaller corpora can outperform raw larger ones. The AWS team found that training on just 10% of a financial corpus (selected via domain-adaptive data selection) outperformed training on the full corpus.
The practical implication: invest more time in corpus curation (deduplication, quality filtering, balanced sampling) rather than simply collecting more data.
Forgetting vs. Adaptation
Every CPT run involves a tension between adapting to the new domain and retaining general capabilities. The mixing ratio is the primary control:
| Mixing Ratio (domain %) | Domain Gain | General Forgetting | Use Case |
|---|---|---|---|
| 50% | Moderate | Very low | Conservative adaptation |
| 70% | High | Low-moderate | Standard recommendation |
| 90% | Very high | Moderate-high | Deep domain specialization |
| 100% | Maximum | High | Domain-only model (rare) |
Practitioner's Note: If you're unsure about the right mixing ratio, start with 70% domain / 30% general. This has consistently produced good results across medical, legal, financial, and code domains. Only go above 80% domain if you've validated that general capability retention meets your requirements.
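One simple way to realize a mixing ratio is to sample the source of each training document at random. The sketch below is a minimal illustration under stated assumptions: `mix_streams` is a hypothetical helper, the streams are toy generators, and the ratio is applied per document, whereas a production pipeline would usually balance by token count.

```python
import random
from typing import Iterable, Iterator


def mix_streams(
    domain_docs: Iterable[str],
    general_docs: Iterable[str],
    domain_ratio: float = 0.7,
    seed: int = 0,
) -> Iterator[tuple[str, str]]:
    """Yield (source, document) pairs, drawing from the domain stream with
    probability `domain_ratio` and from the general replay stream otherwise.
    Stops as soon as the chosen stream is exhausted."""
    rng = random.Random(seed)
    domain_iter, general_iter = iter(domain_docs), iter(general_docs)
    while True:
        use_domain = rng.random() < domain_ratio
        source = "domain" if use_domain else "general"
        try:
            doc = next(domain_iter if use_domain else general_iter)
        except StopIteration:
            return
        yield source, doc


# Toy streams standing in for real corpora
domain = (f"domain-doc-{i}" for i in range(100_000))
general = (f"general-doc-{i}" for i in range(100_000))

sample = [source for (source, _), _ in zip(mix_streams(domain, general), range(10_000))]
domain_share = sample.count("domain") / len(sample)
print(f"Domain share over {len(sample):,} draws: {domain_share:.3f}")
```

Because the choice is random per document, the realized share fluctuates around the target; over many draws it settles close to 70%.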
Alternatives & Comparisons
Domain adaptation at the feature level (e.g., adversarial domain adaptation, domain-invariant representations) operates on the representation space without modifying model weights through additional training. It's faster but provides shallower adaptation than CPT. Choose feature-level domain adaptation for quick deployment with pretrained embeddings; choose CPT when you need the model to fundamentally understand domain concepts and generate domain-fluent text.
LoRA fine-tuning adapts the model to specific tasks by training small low-rank matrices, typically on labeled instruction data. It changes model behavior but not its fundamental knowledge. LoRA is 10-100x cheaper than CPT and ideal for task-specific adaptation. Choose LoRA when the base model already has sufficient domain knowledge; choose CPT when the model needs to learn new domain concepts before any task-specific training.
Full fine-tuning updates all model parameters on task-specific data. While it can inject some domain knowledge, it's typically done on smaller labeled datasets (10K-500K examples) and uses higher learning rates than CPT. CPT uses billions of unlabeled tokens with the pretraining objective, providing broader knowledge injection. Choose full fine-tuning for task-specific performance with labeled data; choose CPT for broad domain knowledge with unlabeled corpora.
Instruction tuning teaches the model to follow instructions in a specific format, typically using supervised fine-tuning on instruction-response pairs. It's a downstream step that usually follows CPT. The standard pipeline for domain LLMs is: base model -> CPT (domain knowledge) -> instruction tuning (task behavior). Choose instruction tuning alone when the base model has sufficient knowledge; use CPT + instruction tuning when domain knowledge needs enhancement.
Knowledge distillation transfers capabilities from a larger teacher model to a smaller student model. It can be used in conjunction with CPT -- for example, distilling a 70B domain-adapted model into a 7B model. Choose distillation when you need a smaller, faster model with domain expertise; choose CPT when you're building the initial domain-adapted model that will serve as the teacher.
Pros, Cons & Tradeoffs
Advantages
Deep domain knowledge injection: CPT fundamentally rewires the model's internal representations, encoding domain concepts as first-class knowledge rather than surface-level pattern matching. This produces models that reason about domain concepts more accurately and generate more fluent domain text.
Improved downstream task performance: Models that undergo CPT before fine-tuning consistently outperform models that are directly fine-tuned on domain tasks. The DAPT paper showed 3-8 F1 point improvements, and subsequent work has shown even larger gains with LLMs.
Better tokenizer efficiency for domain text: When combined with tokenizer adaptation, CPT reduces the token count for domain text by 10-20%, effectively increasing the model's context window for domain content and reducing both training and inference costs.
Foundation for multiple downstream tasks: A single CPT run produces a domain-adapted base model that can be fine-tuned for many different tasks within the domain. This amortizes the CPT cost across all downstream applications -- a medical CPT model serves diagnosis, summarization, coding, and Q&A.
Reduced hallucination in domain contexts: Models with CPT-injected domain knowledge hallucinate less on domain topics because the correct information is encoded in the weights, reducing reliance on pattern-matching or retrieval.
Cost-effective compared to pretraining from scratch: CPT achieves comparable domain performance to training a domain-specific model from scratch (like BloombergGPT) at 0.1-1% of the compute cost, because it starts from a model that already has strong general capabilities.
Compatible with all downstream adaptation methods: The output of CPT is a standard pretrained model checkpoint that works with LoRA, QLoRA, full fine-tuning, instruction tuning, RLHF, DPO -- any adaptation technique in the standard toolkit.
Disadvantages
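High compute cost: CPT requires training on billions of tokens, costing roughly $2K-50K in compute (see the comparison table under Key Tradeoffs), far more than LoRA fine-tuning (~$50-500) or RAG (primarily infrastructure costs).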
Risk of catastrophic forgetting: Without careful data mixing and learning rate tuning, CPT can degrade the model's general capabilities. A model that forgets basic arithmetic or loses its ability to follow instructions is worse than the original.
Requires large domain corpus: Effective CPT needs 5-100B+ tokens of high-quality domain text. Many specialized domains (rare diseases, niche engineering fields) simply don't have enough publicly available text. Corpus curation is labor-intensive.
Long iteration cycles: Each CPT experiment takes days to weeks, making hyperparameter tuning slow and expensive. If the learning rate is wrong or the data mix is bad, you discover this after spending thousands of dollars in compute.
Infrastructure complexity: CPT of 7B+ models requires multi-GPU distributed training with FSDP or DeepSpeed, which demands specialized engineering expertise and infrastructure. This is beyond the capability of many smaller teams.
Static knowledge snapshot: Knowledge injected via CPT is frozen at training time. Unlike RAG, CPT cannot incorporate new information without retraining. For rapidly evolving domains, this creates a staleness problem.
Failure Modes & Debugging
Catastrophic Forgetting
Cause
Training on pure domain data or using too high a learning rate, causing the model to overwrite pretrained representations that encode general knowledge. The model's loss on domain data decreases while its ability to perform general tasks (math, reasoning, instruction following) degrades severely.
Symptoms
The model produces fluent domain text but fails at basic tasks: it can't do simple arithmetic, loses its instruction-following ability, or generates incoherent responses to general questions. MMLU scores may drop by 10-20+ points. The model may also lose its ability to switch between domain and general contexts.
Mitigation
Always include 20-30% general data in the training mix as a replay buffer. Use a learning rate 10-100x lower than the original pretraining LR. Monitor general benchmarks (MMLU, HellaSwag, ARC) every 2,000-5,000 training steps. If general scores drop more than 3%, reduce LR or increase the general data ratio. Consider Elastic Weight Consolidation (EWC) regularization for additional protection.
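The monitoring rule above can be reduced to a small check that runs after each evaluation. This is a sketch only: `forgetting_alerts` is a hypothetical helper, the scores are invented for illustration, and the 3% threshold is interpreted as an absolute drop on a 0-1 accuracy scale, which the text does not specify.

```python
def forgetting_alerts(
    baseline: dict[str, float],
    current: dict[str, float],
    threshold: float = 0.03,
) -> dict[str, float]:
    """Return the benchmarks whose score fell by more than `threshold`
    (absolute, on a 0-1 accuracy scale) relative to the base model."""
    alerts = {}
    for name, base_score in baseline.items():
        drop = base_score - current[name]
        if drop > threshold:
            alerts[name] = round(drop, 4)
    return alerts


# Hypothetical scores, for illustration only
baseline_scores = {"mmlu_5shot": 0.65, "hellaswag": 0.80, "arc_challenge": 0.55}
checkpoint_scores = {"mmlu_5shot": 0.60, "hellaswag": 0.79, "arc_challenge": 0.54}

alerts = forgetting_alerts(baseline_scores, checkpoint_scores)
if alerts:
    print(f"Forgetting alert: {alerts} -> consider lowering the LR or adding general replay")
else:
    print("General capabilities within tolerance")
```

In this made-up example only the MMLU drop exceeds the threshold, so it is the only benchmark flagged.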
Domain Corpus Contamination
Cause
The domain corpus contains significant amounts of noise: machine-translated text, OCR errors, boilerplate content, duplicated passages, or text from the wrong domain. The model learns these artifacts instead of genuine domain knowledge.
Symptoms
Domain perplexity decreases during training (the model is learning), but downstream task performance doesn't improve or even degrades. The model may generate text with OCR-like artifacts, unnatural repetition, or boilerplate phrases. Quality on clean domain test sets is poor despite good training metrics.
Mitigation
Invest heavily in corpus curation: run MinHash deduplication (remove 20-40% near-duplicates), perplexity-based quality filtering (remove text that's too noisy or too templated), and manual spot-checking of random samples. Use a staged approach: clean a small high-quality subset first, validate CPT works on it, then scale up.
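To make the deduplication step concrete, here is a dependency-free sketch of MinHash near-duplicate detection over word shingles. It is an illustration under assumptions, not a production pipeline: the helper names, the 3-word shingles, the 128 hash functions, and the 0.8 threshold are all choices made for this example, and the pairwise comparison is quadratic, where a real corpus would use LSH banding (for example via a library such as datasketch).

```python
import hashlib


def shingles(text: str, n: int = 3) -> set[str]:
    """Set of overlapping n-word shingles from lowercased text."""
    words = text.lower().split()
    if len(words) < n:
        return {" ".join(words)}
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}


def minhash_signature(items: set[str], num_perm: int = 128) -> list[int]:
    """For each of `num_perm` seeded hash functions, keep the minimum
    64-bit hash value over all shingles."""
    signature = []
    for seed in range(num_perm):
        signature.append(min(
            int.from_bytes(
                hashlib.blake2b(f"{seed}:{item}".encode("utf-8"), digest_size=8).digest(),
                "big",
            )
            for item in items
        ))
    return signature


def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of matching signature slots, an estimate of Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def deduplicate(docs: list[str], threshold: float = 0.8) -> list[str]:
    """Keep a document only if it is not a near-duplicate of one already kept."""
    kept, kept_sigs = [], []
    for doc in docs:
        sig = minhash_signature(shingles(doc))
        if all(estimated_jaccard(sig, other) < threshold for other in kept_sigs):
            kept.append(doc)
            kept_sigs.append(sig)
    return kept


docs = [
    "the patient presented with acute chest pain radiating to the left arm and was "
    "admitted for observation after an electrocardiogram showed st elevation in the "
    "anterior leads on arrival",
    # Same report with only the final word changed
    "the patient presented with acute chest pain radiating to the left arm and was "
    "admitted for observation after an electrocardiogram showed st elevation in the "
    "anterior leads on admission",
    "quarterly revenue rose eight percent on stronger demand for cloud services",
]
kept = deduplicate(docs)
print(f"Kept {len(kept)} of {len(docs)} documents")
```

The second document shares all but one shingle with the first, so its estimated similarity sits well above the threshold and it is dropped, while the unrelated third document is kept.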
Tokenizer Mismatch Degradation
Cause
Using the original tokenizer on text with significantly different vocabulary (e.g., Hindi text with an English tokenizer, SMILES chemical notation, or dense medical abbreviations). The tokenizer fragments domain-specific terms into many subwords, reducing effective context length and making it harder for the model to learn domain-specific patterns.
Symptoms
Training is abnormally slow (more tokens needed to represent the same content). The model struggles with domain-specific terminology even after extensive CPT. Inference is expensive because domain text requires 2-3x more tokens than English text of equivalent content. Token-level perplexity may look good, but word-level or character-level perplexity is poor.
Mitigation
Analyze tokenizer efficiency before CPT: compute chars-per-token and fertility metrics on domain text. If fertility is > 2.0 (compared to ~1.3 for English), consider tokenizer adaptation: extend the vocabulary with 500-2000 domain-specific tokens and initialize their embeddings via mean pooling. Budget extra CPT tokens (1B+) for the model to learn the new token embeddings.
Learning Rate Mismatch
Cause
Using a learning rate that's too high (causes forgetting) or too low (wastes compute without meaningful adaptation). This is particularly common when teams copy hyperparameters from fine-tuning recipes (which use higher LR) or from pretraining-from-scratch setups.
Symptoms
Too high LR: Training loss spikes initially and general benchmark scores drop rapidly within the first few thousand steps. The model may become incoherent on both domain and general text. Too low LR: Training loss decreases very slowly, domain benchmarks show minimal improvement after many epochs, and the model's domain capabilities remain similar to the base model.
Mitigation
For an 8B model, start with LR = 3e-5 (assuming pretraining peak was 3e-4). Run a short pilot (1,000 steps) and check: (1) training loss should decrease smoothly, (2) general benchmarks should not drop > 1%. If loss spikes, halve the LR. If loss barely moves after 1,000 steps, double the LR. Use cosine annealing that matches the pretraining schedule.
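The schedule and the pilot-run rule above can be sketched in a few lines. The defaults mirror the values used in this guide (peak 3e-5, floor 3e-6, 500 warmup steps); the function names are illustrative, the linear warmup shape is an assumption, and the 23,842-step run length is derived from a 50B-token budget at 512 sequences of 4,096 tokens per step.

```python
import math


def cpt_learning_rate(
    step: int,
    total_steps: int,
    peak_lr: float = 3e-5,
    min_lr: float = 3e-6,
    warmup_steps: int = 500,
) -> float:
    """Linear warmup to `peak_lr`, then cosine decay down to `min_lr`."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine


def adjust_after_pilot(lr: float, loss_spiked: bool, loss_barely_moved: bool) -> float:
    """Pilot-run rule from the text: halve on a spike, double if loss is flat."""
    if loss_spiked:
        return lr / 2
    if loss_barely_moved:
        return lr * 2
    return lr


TOTAL_STEPS = 23_842  # assumed run length, see lead-in
for step in (0, 499, 12_000, TOTAL_STEPS):
    print(f"step {step:>6}: lr = {cpt_learning_rate(step, TOTAL_STEPS):.2e}")

print(f"After a loss spike in the pilot: {adjust_after_pilot(3e-5, True, False):.1e}")
```

The learning rate reaches its peak at the end of warmup and ends at the configured floor rather than at zero, which keeps late-stage updates small but non-trivial.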
Insufficient Training Duration
Cause
Stopping CPT too early, before the model has processed enough domain tokens to meaningfully shift its representations. This is common when compute budget is tight and teams try to do CPT with 1-2B tokens on a 7B+ model.
Symptoms
Domain perplexity decreases marginally but downstream task improvements are negligible compared to the base model. The model shows surface-level familiarity with domain terms but lacks deeper understanding -- it uses the right vocabulary but makes conceptual errors. The CPT appears to have been wasted.
Mitigation
As a rule of thumb, plan for at least 5-10B tokens for an 8B model and 20-50B tokens for a 70B model. If budget is constrained, consider using data selection methods (like the AWS ETS-DACP approach) to select the most informative 10% of the corpus, or use LoRA fine-tuning instead of CPT. Monitor the domain perplexity curve -- if it's still decreasing significantly, training hasn't saturated.
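"Still decreasing significantly" can be turned into a rough numeric check on the logged perplexity curve. The sketch below is one possible formulation: `still_improving` is a hypothetical helper, and the window of 3 evaluations and the 1% relative-gain cutoff are arbitrary illustrative values that would need tuning per project.

```python
def still_improving(perplexities: list[float], window: int = 3, min_rel_gain: float = 0.01) -> bool:
    """True if perplexity fell by more than `min_rel_gain` (relative) over
    the last `window` evaluations."""
    if len(perplexities) <= window:
        return True  # Too few evaluations to call it saturated
    earlier, latest = perplexities[-window - 1], perplexities[-1]
    return (earlier - latest) / earlier > min_rel_gain


# Hypothetical domain-perplexity logs, one value per evaluation
improving_run = [12.0, 10.5, 9.6, 9.0, 8.6, 8.3]
saturated_run = [12.0, 9.0, 8.02, 8.01, 8.00, 8.00]

print("Improving run should continue:", still_improving(improving_run))
print("Saturated run should continue:", still_improving(saturated_run))
```

A check like this is only a signal to look closer: a flat perplexity curve can also come from a learning rate that has decayed too far, not just from saturation.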
Data Mixing Ratio Imbalance
Cause
The domain/general data mixing ratio is poorly calibrated. Too much domain data (>90%) causes forgetting. Too much general data (<50% domain) wastes compute on data the model has already learned, producing minimal domain adaptation.
Symptoms
Too much domain data: General benchmark regression similar to catastrophic forgetting, but more gradual. The model becomes a domain specialist at the expense of general capability. Too much general data: Domain perplexity barely improves, downstream domain tasks show marginal gains, and most of the compute budget is spent re-learning general knowledge the model already knew.
Mitigation
Start with a 70/30 domain/general split. Validate with a pilot run of 5,000 steps, measuring both domain perplexity and general benchmarks. If general benchmarks drop > 3%, increase general replay to 40%. If domain perplexity barely moves, increase domain ratio to 80%. The DoReMi algorithm can automatically optimize these ratios, but requires training a small proxy model first.
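The adjustment rule above (more replay when general benchmarks regress, more domain data when domain perplexity stalls) can be written as a small function applied after each pilot run. This is a sketch with assumed details: `next_domain_ratio` is a hypothetical helper, the 2% perplexity-gain cutoff is an illustrative stand-in for "barely moves", and the 0.5-0.9 bounds reflect the range this section treats as reasonable.

```python
def next_domain_ratio(
    domain_ratio: float,
    general_drop: float,
    domain_ppl_rel_gain: float,
    max_general_drop: float = 0.03,
    min_ppl_gain: float = 0.02,
    step: float = 0.10,
) -> float:
    """Lower the domain share by `step` when general benchmarks regress too
    much, raise it by `step` when domain perplexity barely improves, and
    otherwise leave it unchanged. The result stays within [0.5, 0.9]."""
    if general_drop > max_general_drop:
        return round(max(0.5, domain_ratio - step), 2)
    if domain_ppl_rel_gain < min_ppl_gain:
        return round(min(0.9, domain_ratio + step), 2)
    return domain_ratio


# Hypothetical pilot outcomes, starting from the 70/30 split
print("General benchmarks fell 5%:", next_domain_ratio(0.7, general_drop=0.05, domain_ppl_rel_gain=0.10))
print("Domain perplexity stalled:", next_domain_ratio(0.7, general_drop=0.01, domain_ppl_rel_gain=0.005))
print("Both within tolerance:", next_domain_ratio(0.7, general_drop=0.01, domain_ppl_rel_gain=0.10))
```

Forgetting is checked first on purpose: when both signals fire, protecting general capability takes priority over faster domain adaptation.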
Placement in an ML System
Where CPT Fits in the Model Development Lifecycle
Continued pretraining occupies a specific position in the LLM development pipeline: after base model selection and before any task-specific adaptation. The standard pipeline for building a production domain LLM is:
- Base model selection: Choose a pretrained foundation model (e.g., Llama 3 8B, Mistral 7B) based on the target domain, model size, and licensing constraints.
- Corpus curation: Collect, clean, deduplicate, and filter domain-specific text into a high-quality training corpus.
- Continued pretraining: Train the base model on the domain corpus (mixed with general replay data) using the autoregressive language modeling objective.
- Instruction tuning / SFT: Fine-tune the CPT model on domain-specific instruction-response pairs to teach it task-specific behavior.
- Alignment (RLHF/DPO): Optionally align the model with human preferences for safety, helpfulness, and domain-appropriate responses.
- Evaluation and deployment: Benchmark on domain-specific and general evaluations, then deploy to production.
In organizations like Bloomberg, Google (Med-PaLM), or Indian AI companies like Sarvam AI, CPT is managed as a dedicated infrastructure workload with its own compute allocation, data pipelines, and evaluation frameworks. The CPT output -- a domain-adapted base model -- is treated as a shared asset that multiple downstream teams can fine-tune for specific applications.
Organizational Pattern: Large organizations often maintain a "model kitchen" where the CPT team produces domain-adapted base models on a regular cadence (e.g., quarterly), incorporating new domain data. Downstream teams consume these base models and apply task-specific fine-tuning. This separation of concerns between "knowledge injection" (CPT) and "behavior shaping" (fine-tuning) is a key architectural pattern for scalable domain AI development.
Pipeline Stage
Training / Domain Adaptation
Upstream
- Base Model Selection (pretrained checkpoint from model hub)
- Domain Corpus Curation (cleaned, deduplicated, filtered text)
- Tokenizer (original or domain-adapted vocabulary)
Downstream
- Instruction Tuning / SFT
- LoRA or QLoRA Fine-tuning
- RLHF / DPO Alignment
- Model Evaluation & Benchmarking
- Model Registry / Version Control
Scaling Bottlenecks
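The primary bottleneck for CPT is raw training compute. For a 70B model on 100B tokens, the compute requirement is approximately 42,000 GPU-hours on A100 80GB -- about $84,000 (~INR 70 lakh) at cloud rates of roughly $2 per GPU-hour. That figure assumes each GPU sustains close to its peak bf16 throughput; at the 40-50% utilization more typical of large distributed runs, budget roughly twice as many GPU-hours. This makes CPT infeasible for most Indian startups unless they have access to subsidized GPU clusters (e.g., through government programs like India AI Mission or through academic partnerships).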
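A quick way to sanity-check a compute estimate like this is the common approximation that training costs about 6 FLOPs per parameter per token. The sketch below applies it; the approximation, the 312 TFLOPS dense bf16 peak for the A100, the $2 per GPU-hour rate (implied by $84,000 for 42,000 hours), and the utilization levels are all assumptions brought in for illustration, and gradient checkpointing would add recomputation on top.

```python
def cpt_gpu_hours(params: float, tokens: float, sustained_tflops: float) -> float:
    """GPU-hours from the 6 * N * D training-FLOPs approximation."""
    total_flops = 6.0 * params * tokens
    flops_per_gpu_hour = sustained_tflops * 1e12 * 3600
    return total_flops / flops_per_gpu_hour


A100_PEAK_BF16_TFLOPS = 312.0  # dense bf16 peak, assumed from NVIDIA's A100 datasheet
PRICE_PER_GPU_HOUR = 2.0       # assumed USD rate, implied by $84,000 / 42,000 hours

# Model FLOPs utilization (MFU): fraction of peak throughput actually sustained
for mfu in (0.9, 0.5, 0.4):
    hours = cpt_gpu_hours(70e9, 100e9, A100_PEAK_BF16_TFLOPS * mfu)
    cost = hours * PRICE_PER_GPU_HOUR
    print(f"MFU {mfu:.0%}: {hours:,.0f} GPU-hours, about ${cost:,.0f}")
```

Under these assumptions, a figure near 42,000 GPU-hours corresponds to sustaining about 90% of peak, while 40-50% utilization gives roughly 75,000-93,000 GPU-hours for the same run.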
With large-scale CPT, the data pipeline (tokenization, shuffling, batching) can become a bottleneck if not properly parallelized. A training throughput of 50,000 tokens/second per GPU requires the data pipeline to sustain ~400,000 tokens/second for an 8-GPU setup. Pre-tokenizing the corpus and storing it in a memory-mapped format (e.g., using datasets library's memory mapping) is essential.
Running comprehensive evaluation (MMLU, domain benchmarks, perplexity) at every checkpoint adds significant overhead. For a 70B model, a full MMLU evaluation takes 2-4 hours. Evaluate on a subset (e.g., 1,000 MMLU questions instead of 14,000) during training, with full evaluation only at key milestones.
Saving checkpoints for a 70B model produces ~140 GB files (bf16). Saving every 2,500 steps over 50,000 steps generates 20 checkpoints = 2.8 TB of checkpoint data. Budget for this storage and implement a rotation policy that keeps only the last N checkpoints plus milestone checkpoints.
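The storage arithmetic and the rotation policy described above are easy to check in code. This sketch counts model weights only (optimizer state saved for resumption would add to it), and both the helper names and the choice of a milestone every 10,000 steps are illustrative assumptions.

```python
def checkpoint_size_gb(params: float, bytes_per_param: int = 2) -> float:
    """Size of the model weights alone (bf16 = 2 bytes per parameter)."""
    return params * bytes_per_param / 1e9


def checkpoints_to_keep(saved_steps: list[int], keep_last_n: int, milestone_every: int) -> list[int]:
    """Rotation policy: keep the last N checkpoints plus every milestone step."""
    recent = set(saved_steps[-keep_last_n:])
    milestones = {step for step in saved_steps if step % milestone_every == 0}
    return sorted(recent | milestones)


size_gb = checkpoint_size_gb(70e9)
saved = list(range(2_500, 50_001, 2_500))  # a save every 2,500 steps up to 50,000
kept = checkpoints_to_keep(saved, keep_last_n=5, milestone_every=10_000)

print(f"One 70B bf16 checkpoint: {size_gb:.0f} GB")
print(f"Without rotation: {len(saved)} checkpoints, {len(saved) * size_gb / 1000:.2f} TB")
print(f"With rotation: {len(kept)} checkpoints, {len(kept) * size_gb / 1000:.2f} TB")
print(f"Steps kept: {kept}")
```

With the last five checkpoints plus a milestone every 10,000 steps, the run ends with 8 checkpoints on disk instead of 20.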
Production Case Studies
Bloomberg trained a 50B-parameter LLM from scratch on a mixed corpus of 363B tokens of financial data (Bloomberg's proprietary archive of financial documents, news, filings, and reports) and 345B tokens of general web data. While technically a from-scratch pretraining rather than continued pretraining, the data mixing strategy -- roughly 51% domain + 49% general -- has become the reference point for CPT mixing ratios in the financial domain.
BloombergGPT significantly outperformed GPT-3 and BLOOM on financial NLP tasks (sentiment analysis, NER, QA on financial documents) while maintaining competitive performance on general benchmarks. The 51/49 mixing ratio demonstrated that domain performance doesn't require sacrificing general capabilities.
Meta produced Code Llama by continued pretraining of Llama 2 on 500B tokens of primarily code data (85% code from GitHub, 8% natural language about code, 7% general text). This is one of the largest and most successful CPT efforts, transforming a general-purpose language model into a state-of-the-art code generation model. Code Llama was then further specialized into Code Llama-Python (additional 100B Python tokens) and Code Llama-Instruct (instruction tuning).
Code Llama 34B achieved 48.8% on HumanEval and 55.0% on MBPP, significantly outperforming Llama 2 34B's baseline code capabilities. The 7B variant was competitive with much larger general models, demonstrating that CPT can effectively specialize a model while keeping it relatively small.
Sarvam AI, an Indian AI startup, used continued pretraining to adapt Llama 2 7B for Hindi and other Indic languages, producing OpenHathi. The model was trained on Hindi, English, and Hinglish data, extending Llama's capabilities to Indian languages at a fraction of the cost of training from scratch. They later released Sarvam 1, a 2B-parameter model supporting 10 Indic languages with a custom tokenizer optimized for Indic scripts.
OpenHathi achieved GPT-3.5-level performance for Hindi language tasks, demonstrating that continued pretraining can effectively extend English-centric LLMs to low-resource languages. The approach proved far more cost-effective than training an Indic LLM from scratch, enabling a startup with limited resources to compete with much larger organizations.
Azerbayev et al. continued pretraining Code Llama on the Proof-Pile-2, a mixture of scientific papers, web data containing mathematics, and mathematical code. The resulting model, Llemma, was trained in two sizes (7B and 34B) and became the first open-source model to outperform Google's Minerva on an equi-parameter basis for mathematical reasoning.
Llemma 34B came close to Minerva 62B on the MATH benchmark despite being nearly half the size, and Llemma 7B outperformed the comparably sized Minerva 8B, demonstrating the power of targeted continued pretraining on high-quality domain data. Llemma was also capable of tool use (Python, Sage) and formal theorem proving in Lean/Isabelle without additional fine-tuning.
AWS's applied science team demonstrated efficient data selection for continued pretraining in the financial domain. They showed that training on just 10% of a financial corpus -- selected using domain-adaptive data selection methods (ETS-DACP) -- outperformed training on the full corpus. This addressed the common concern that CPT requires massive domain corpora.
Efficient data selection methods achieved better downstream financial task performance than standard CPT with 10x more data, while also showing that continual pretraining does not adversely affect non-domain performance. This validated the importance of data quality over quantity in CPT.
Tooling & Ecosystem
The most accessible framework for CPT. The Trainer class supports autoregressive language modeling with DeepSpeed and FSDP integration. Combined with datasets for data loading and trl for training recipes, this is the standard starting point for CPT projects. Supports gradient checkpointing, mixed precision, and multi-GPU training out of the box.
High-performance framework for pretraining and continued pretraining of large transformer models. Supports 3D parallelism (data + tensor + pipeline), optimized attention kernels, and efficient data loading. The standard choice for CPT of 30B+ parameter models where training throughput is critical. Used by many organizations building production domain LLMs.
Hugging Face's minimalistic large language model training library with 3D parallelism support. More Pythonic and easier to customize than Megatron-LM, while achieving comparable training throughput. Good for teams that need custom data pipelines or training logic without diving into Megatron-LM's codebase.
Unified fine-tuning and continued pretraining framework with a web UI. Supports pretraining (CLM), SFT, RLHF, DPO, and all PEFT methods. Popular for its low barrier to entry -- you can configure and launch a CPT run through the browser. Supports 100+ model architectures.
Python library for probabilistic data structures, including MinHash and MinHash LSH for near-duplicate detection. Essential for corpus deduplication in the CPT data preparation pipeline. Can process millions of documents efficiently using locality-sensitive hashing.
Allen AI's data curation toolkit used to build the Dolma dataset (3 trillion tokens). Provides tools for deduplication, quality filtering, PII removal, and content classification. Can be adapted for domain corpus curation in CPT pipelines. Used in the creation of OLMo's training data.
Research & References
Gururangan, Marasovic, Swayamdipta, Lo, Beltagy, Downey & Smith (2020) ACL 2020
The foundational paper for continued pretraining. Introduced DAPT (Domain-Adaptive Pretraining) and TAPT (Task-Adaptive Pretraining), showing consistent improvements across four domains and eight tasks when further pretraining RoBERTa on domain-specific text. Established that domain-adaptive pretraining is a simple, effective first step for domain specialization.
Wu, Irsoy, Lu, Dabravolski, Dredze, Gehrmann, Kambadur, Rosenberg & Mann (2023). arXiv 2023.
Trained a 50B-parameter LLM from scratch on a mixed corpus of 363B financial tokens and 345B general tokens. Although not a continued pretraining run, it demonstrated that a balanced domain/general data mix achieves strong domain performance without sacrificing general capabilities. The data mixing strategy (51% domain) has become a reference point for CPT practitioners.
Rozière, Gehring, Gloeckle, Sootla, Gat, Tan, et al. (2023). arXiv 2023.
Produced state-of-the-art open code generation models by continued pretraining of Llama 2 on 500B tokens (85% code, 8% code-adjacent natural language, 7% general natural language). Demonstrated multi-stage CPT: Llama 2 -> Code Llama, which then branches into Code Llama Python (additional Python-heavy CPT) and Code Llama Instruct (instruction tuning). A masterclass in staged domain adaptation.
Azerbayev, Schoelkopf, Paster, Dos Santos, McAleer, Jiang, Deng, Biderman & Welleck (2023). ICLR 2024.
Continued pretraining of Code Llama on the Proof-Pile-2 (a mixture of scientific papers, math-heavy web data, and mathematical code). Llemma outperformed the Minerva models on MATH at comparable parameter counts, and Llemma 34B came close to the much larger Minerva 62B, demonstrating that targeted CPT on high-quality domain data can substitute for a good deal of scale.
Xie, Pham, Dong, Du, Liu, Lu, Liang, Le, Ma & Yu (2023). NeurIPS 2023.
Proposed an automatic method for optimizing data mixing ratios using a small proxy model trained with group distributionally robust optimization (Group DRO). The optimized mixing weights improved average few-shot downstream accuracy by 6.5 percentage points and reached baseline accuracy with 2.6x fewer training steps. Directly applicable to optimizing the domain/general data mix in CPT.
Chen, Wu, Ling, Liang & Shen (2024). arXiv 2024.
Presented a comprehensive recipe for continued pretraining of Llama 3 8B, achieving +8.81 on C-Eval and +12.00 on MATH through data mixture and curriculum strategies. Showed that synthetic data (scientific QA pairs) combined with web data can significantly enhance specific reasoning abilities during CPT.
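The Group DRO idea behind automatic mixture optimization can be sketched in a few lines. The function below is a simplified, hypothetical illustration of a DoReMi-style domain weight update, not the authors' implementation: it assumes per-domain average losses for the proxy and reference models are already available, upweights domains where the proxy model lags the reference, and then renormalizes and smooths toward uniform. The function name, `step_size`, and `smoothing` values are assumptions.

```python
import math


def doremi_style_update(weights, proxy_loss, reference_loss,
                        step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient update of domain mixing weights.

    weights, proxy_loss, reference_loss: dicts keyed by domain name.
    Domains with larger excess loss (proxy minus reference, clipped at 0)
    receive more weight in the next step.
    """
    domains = list(weights)
    # Excess loss: how much worse the proxy is than the reference per domain.
    excess = {d: max(proxy_loss[d] - reference_loss[d], 0.0) for d in domains}
    # Multiplicative update, then renormalize to a probability distribution.
    scaled = {d: weights[d] * math.exp(step_size * excess[d]) for d in domains}
    total = sum(scaled.values())
    # Mix with the uniform distribution so no domain weight collapses to zero.
    uniform = 1.0 / len(domains)
    return {
        d: (1.0 - smoothing) * scaled[d] / total + smoothing * uniform
        for d in domains
    }
```

In a full run this update would be applied at every proxy-model training step and the weights averaged over steps; here a single call is enough to see that the domain with the largest excess loss gains weight.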
Interview & Evaluation Perspective
Common Interview Questions
- What is continued pretraining and how does it differ from fine-tuning?
- How do you choose the data mixing ratio between domain and general data for CPT?
- What strategies would you use to mitigate catastrophic forgetting during continued pretraining?
- Walk me through the end-to-end pipeline for building a domain-specific LLM using continued pretraining.
- How would you estimate the compute cost for continued pretraining of a 70B model on 100B tokens?
- Compare CPT vs RAG vs fine-tuning for injecting domain knowledge into an LLM.
- How should the learning rate schedule differ between initial pretraining and continued pretraining?
- When would you choose to adapt the tokenizer for continued pretraining, and how?
Key Points to Mention
- CPT uses the same self-supervised objective as pretraining (next-token prediction) but starts from pretrained weights and trains on domain-specific data. It changes what the model knows, while fine-tuning changes how it behaves.
- The data mixing ratio is critical: 70% domain / 30% general is a strong default. Too much domain data causes forgetting; too little wastes compute. The DoReMi algorithm can optimize this automatically.
- The CPT learning rate should be 10-100x lower than the peak pretraining learning rate. For Llama 3 8B (pretrained at a peak of 3e-4), that means 3e-6 to 3e-5. Use cosine annealing, as in the original pretraining schedule.
- Catastrophic forgetting mitigation has four levers: data replay (general data in the training mix), a lower learning rate, elastic weight consolidation (EWC) regularization, and continuous evaluation on general benchmarks.
- Compute cost formula: FLOPs = 6 * N * D, where N is the parameter count and D is the number of training tokens. For an 8B model on 50B tokens: 2.4e21 FLOPs, or ~2,200 GPU-hours on A100 if the GPUs run near peak throughput = ~$4,400 at about $2 per GPU-hour (~INR 3.7 lakh). Budget more for realistic hardware utilization.
- Real-world examples: Code Llama (500B code tokens on Llama 2), BloombergGPT (363B financial tokens mixed with general data), Llemma (math CPT that beats Minerva at comparable parameter counts), Sarvam OpenHathi (Indic language CPT).
- Corpus curation is roughly half of the work: deduplication (MinHash), quality filtering (perplexity scoring), balanced sampling across sub-domains, and PII removal.
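The 6 * N * D estimate in the key points above can be turned into a small calculator. This is a rough sketch under stated assumptions: the A100's dense BF16 peak of 312 TFLOP/s, a user-supplied model FLOPs utilization (`mfu`), and a price of $2 per GPU-hour (the rate implied by the ~$4,400 for ~2,200 GPU-hours figure). The function names are hypothetical, and real runs also pay for restarts, evaluation, and idle time.

```python
A100_PEAK_BF16_FLOPS = 312e12  # dense BF16 peak, FLOP/s per GPU


def training_flops(num_params, num_tokens):
    """Approximate training compute: 6 FLOPs per parameter per token."""
    return 6 * num_params * num_tokens


def gpu_hours(flops, mfu, peak_flops=A100_PEAK_BF16_FLOPS):
    """GPU-hours needed when each GPU sustains mfu * peak_flops."""
    if not 0 < mfu <= 1:
        raise ValueError("mfu must be in (0, 1]")
    seconds = flops / (peak_flops * mfu)
    return seconds / 3600


def training_cost_usd(num_params, num_tokens, mfu, usd_per_gpu_hour=2.0):
    """Dollar cost under the assumed hourly GPU price."""
    hours = gpu_hours(training_flops(num_params, num_tokens), mfu)
    return hours * usd_per_gpu_hour


if __name__ == "__main__":
    for mfu in (1.0, 0.5, 0.4):
        hours = gpu_hours(training_flops(8e9, 50e9), mfu)
        cost = training_cost_usd(8e9, 50e9, mfu)
        print(f"MFU {mfu:.0%}: {hours:,.0f} GPU-hours, ${cost:,.0f}")
```

Under these assumptions, an 8B model on 50B tokens needs about 2,100 GPU-hours at 100% utilization and about 5,300 at 40%, so the ~2,200 GPU-hour figure is best read as a lower bound when budgeting.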
Pitfalls to Avoid
- Conflating continued pretraining with fine-tuning. CPT uses the self-supervised pretraining objective on large unlabeled corpora; fine-tuning uses supervised objectives on smaller labeled datasets. They operate at different layers of the adaptation stack.
- Ignoring catastrophic forgetting. Any answer about CPT that doesn't discuss forgetting mitigation strategies shows a lack of practical experience.
- Claiming that more data is always better. Research (AWS ETS-DACP) shows that a curated 10% of the corpus can outperform the full corpus. Data quality trumps quantity.
- Not discussing cost and compute requirements. CPT is expensive, and a senior engineer should be able to estimate GPU-hours and cost for a given configuration.
- Forgetting to mention the tokenizer. For cross-lingual CPT (e.g., English model to Hindi), tokenizer adaptation can be as important as the training itself.
Senior-Level Expectation
A senior/staff engineer should discuss CPT at three levels: (1) Technical: articulate the training objective, learning rate scheduling (why 10-100x lower LR), data mixing strategies (with concrete ratios), and forgetting mitigation (replay, EWC, monitoring). (2) Engineering: cover the end-to-end pipeline from corpus curation (deduplication, quality filtering) through distributed training (FSDP/DeepSpeed configuration, checkpoint management) to evaluation (dual tracking of domain and general metrics). (3) Strategic: reason about when CPT is worth the investment vs. alternatives (RAG, LoRA), estimate compute costs in dollars/INR and GPU-hours for specific configurations, and design an organizational workflow where a CPT team produces domain base models consumed by multiple downstream fine-tuning teams. The ability to discuss real examples (Code Llama's 85/8/7 data split, BloombergGPT's 51/49 mix, Sarvam's Indic language approach) and connect them to specific design decisions demonstrates genuine depth.
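When discussing concrete ratios such as a 70/30 domain/general mix or Code Llama's 85/8/7 split, it can help to show how a mix is actually realized in the data loader. The sketch below is a hypothetical, document-level sampler (the name `mixed_stream` and its parameters are assumptions): each draw picks a source with probability proportional to its weight and takes that source's next document, cycling when a source is exhausted. Production pipelines usually mix at the token level instead, since documents differ in length.

```python
import itertools
import random


def mixed_stream(sources, weights, seed=0):
    """Yield (source_name, document) pairs sampled according to weights.

    sources: dict mapping source name -> non-empty list of documents.
    weights: dict mapping source name -> relative sampling weight.
    """
    names = list(sources)
    for name in names:
        if not sources[name]:
            raise ValueError(f"source {name!r} is empty")
    rng = random.Random(seed)  # seeded for reproducible mixes
    # Smaller sources repeat (replay) once exhausted.
    iterators = {name: itertools.cycle(sources[name]) for name in names}
    probs = [weights[name] for name in names]
    while True:
        name = rng.choices(names, weights=probs, k=1)[0]
        yield name, next(iterators[name])


if __name__ == "__main__":
    corpus = {
        "domain": ["judgment 1", "judgment 2", "judgment 3"],
        "general": ["web page 1", "web page 2"],
    }
    stream = mixed_stream(corpus, {"domain": 0.7, "general": 0.3})
    for name, doc in itertools.islice(stream, 5):
        print(name, "->", doc)
```

The same function handles three-way splits by passing three sources and weights such as 0.85, 0.08, and 0.07.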
Summary
What We Covered
Continued pretraining (CPT) is the practice of further training a pretrained language model on domain-specific data using the self-supervised language modeling objective, bridging the gap between general pretraining and task-specific fine-tuning. It is the primary method for injecting deep domain knowledge into LLMs -- fundamentally rewiring the model's internal representations so that domain concepts, terminology, and reasoning patterns become natively encoded in the weights.
The key technical decisions in CPT are: data mixing ratio (70% domain / 30% general replay is the standard default), learning rate (10-100x lower than the peak pretraining rate, with cosine annealing), corpus curation (deduplication, quality filtering, and balanced sampling matter more than raw corpus size), and catastrophic forgetting mitigation (data replay, low learning rate, continuous evaluation on general benchmarks, optionally EWC regularization). Tokenizer adaptation is critical for cross-lingual CPT (e.g., English to Indian languages) and for specialized domains where the base tokenizer fragments domain text excessively.
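As a concrete illustration of the learning rate guidance, here is a minimal sketch of a cosine schedule with linear warmup whose peak is set to a fraction of the pretraining rate. The function name, the 100-step warmup, and the 10% floor are assumptions for illustration, not values prescribed by any of the cited papers.

```python
import math


def cpt_learning_rate(step, total_steps, peak_lr,
                      warmup_steps=100, min_lr_ratio=0.1):
    """Learning rate at a given step: linear warmup, then cosine decay.

    peak_lr should already be scaled down from the pretraining peak
    (for example, pretraining_lr / 10 to pretraining_lr / 100).
    """
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = peak_lr * min_lr_ratio
    # Decays from peak_lr (progress 0) to min_lr (progress 1).
    return min_lr + (peak_lr - min_lr) * cosine


if __name__ == "__main__":
    pretraining_peak = 3e-4            # Llama 3 8B figure quoted in this guide
    cpt_peak = pretraining_peak / 10   # upper end of the 10-100x range
    for step in (0, 99, 100, 5000, 10000):
        lr = cpt_learning_rate(step, total_steps=10000, peak_lr=cpt_peak)
        print(step, f"{lr:.2e}")
```

A short warmup is included because optimizer state is typically reinitialized at the start of CPT; whether to decay to a floor or to zero is a design choice worth validating on held-out general benchmarks.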
Real-world examples span multiple domains: Code Llama (500B code tokens, state-of-the-art open code generation), BloombergGPT (363B financial tokens mixed with 345B general tokens; trained from scratch rather than continued, but a widely cited reference for domain/general mixing), Llemma (mathematical reasoning from Proof-Pile-2), and Indian language models like Sarvam AI's OpenHathi and Ola's Krutrim (Indic language adaptation of Llama-family models). The compute cost for CPT is substantial -- approximately $4,400 (~INR 3.7 lakh) for an 8B model on 50B tokens -- but it's roughly 100-1000x cheaper than pretraining from scratch and produces a domain-adapted base model that serves as a foundation for multiple downstream applications. For teams building serious domain AI, CPT is not optional -- it's the foundation that makes everything downstream work better.