RA-LLM Training:檢索增強語言模型的訓練策略
深入理解 Sequential Training 的兩種方法:Retriever First vs LLMs First
目錄
從 RAG 到 RA-LLM:為什麼需要訓練?
我們已經知道 RAG(Retrieval-Augmented Generation)的基本原理:
Query → Retrieve Documents → Generate Answer但在實際應用中,你可能會遇到這些問題:
❌ 問題 1:檢索器找不到真正相關的文檔
- 通用的檢索模型對你的領域術語理解不夠
❌ 問題 2:LLM 不會利用檢索到的資訊
- 預訓練的 LLM 可能忽略提供的上下文
❌ 問題 3:檢索器和 LLM 不協調
- 檢索器找到的資訊,LLM 不知道如何使用
解決方案:RA-LLM Training(檢索增強 LLM 訓練)
本文將深入解析:
- Sequential Training(順序訓練)的概念
- Retriever First(先訓練檢索器)策略
- LLMs First(先訓練語言模型)策略
- 兩種方法的對比與選擇
🎯 什麼是 RA-LLM Training?
核心概念
RA-LLM = Retrieval-Augmented Large Language Model
不只是在推理時使用檢索(RAG),而是在訓練階段就整合檢索能力。
傳統 RAG(推理時整合):
預訓練的 Retriever + 預訓練的 LLM
↓
在推理時組合使用
RA-LLM(訓練時整合):
訓練 Retriever 和/或 LLM
↓
讓它們學會協作
↓
在推理時更好的配合為什麼需要訓練?
場景:醫療領域 RAG 系統
問題:「EGFR 突變的靶向治療方案有哪些?」
使用通用模型(未訓練):
├─ 檢索器:不理解醫療術語
│ └─ 找到:一般性的癌症治療文章 ❌
├─ LLM:不知道如何利用檢索結果
│ └─ 生成:基於預訓練知識的通用答案 ❌
└─ 結果:答案不準確
使用訓練過的 RA-LLM:
├─ 檢索器:理解醫療術語
│ └─ 找到:EGFR 靶向藥物的具體文獻 ✅
├─ LLM:知道如何整合檢索資訊
│ └─ 生成:基於文獻的專業答案 ✅
└─ 結果:準確且專業的回答🔄 Sequential Training(順序訓練)
核心思想
「不同時訓練所有組件,而是按順序逐個訓練」
為什麼要順序訓練?
同時訓練(Joint Training)的問題:
├─ 計算資源需求極高
├─ 訓練不穩定
├─ 難以除錯
└─ 需要大量標註資料
順序訓練(Sequential Training)的優勢:
├─ 分階段,資源需求較低
├─ 每個階段可以獨立優化
├─ 容易除錯和調整
└─ 可以使用不同的訓練資料兩種順序訓練策略
根據圖中的說明,有兩種主要策略:
策略 1:Retriever First(先訓練檢索器)
第一階段:訓練 Retriever
第二階段:固定 Retriever,訓練 LLM
策略 2:LLMs First(先訓練語言模型)
第一階段:訓練 LLM
第二階段:固定 LLM,訓練 Retriever🔍 策略 1:Retriever First(先訓練檢索器)
核心流程
第一階段(上方):訓練 Retriever
┌─────────────────────────────────────┐
│ Retriever → Datastore → Documents │
│ 🔥 │
│ (訓練中) │
└─────────────────────────────────────┘
↓
第二階段(下方):固定 Retriever,訓練 LLM
┌─────────────────────────────────────┐
│ Retriever → Datastore → Documents │
│ ⚙️ │
│ (固定) │
│ │
│ Input → LLM → Output │
│ 🔥 │
│ (訓練中) │
└─────────────────────────────────────┘詳細步驟
階段 1:訓練 Retriever
目標:讓檢索器學會找到真正相關的文檔
# 階段 1:訓練 Retriever
class RetrieverFirstTraining:
"""先訓練檢索器"""
def stage1_train_retriever(self, training_data):
"""
訓練資料格式:
{
'query': '問題',
'positive_docs': ['相關文檔1', '相關文檔2'],
'negative_docs': ['不相關文檔1', '不相關文檔2']
}
"""
print("🔥 階段 1:訓練 Retriever")
retriever = SentenceTransformer('base-model')
# 準備訓練樣本
train_examples = []
for item in training_data:
query = item['query']
# 正例:相關文檔
for pos_doc in item['positive_docs']:
train_examples.append(
InputExample(texts=[query, pos_doc], label=1.0)
)
# 負例:不相關文檔
for neg_doc in item['negative_docs']:
train_examples.append(
InputExample(texts=[query, neg_doc], label=0.0)
)
# 訓練
train_dataloader = DataLoader(train_examples, batch_size=16)
train_loss = losses.CosineSimilarityLoss(retriever)
retriever.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=3
)
print("✅ Retriever 訓練完成")
return retriever
# 使用
trainer = RetrieverFirstTraining()
trained_retriever = trainer.stage1_train_retriever(retriever_training_data)訓練資料來源:
1. 人工標註
- 為每個查詢標註相關/不相關文檔
- 成本高但效果好
2. 使用者行為資料
- 使用者點擊的文檔 = 相關
- 使用者跳過的文檔 = 不相關
3. LLM 生成
- 用強大的 LLM 評估文檔相關性
- 快速但可能有偏差階段 2:固定 Retriever,訓練 LLM
目標:讓 LLM 學會利用檢索到的資訊
# 階段 2:訓練 LLM
class RetrieverFirstTraining:
def stage2_train_llm(self, retriever, training_data):
"""
訓練資料格式:
{
'query': '問題',
'answer': '正確答案'
}
"""
print("🔥 階段 2:訓練 LLM")
print("⚙️ Retriever 已固定(不再訓練)")
# 固定 Retriever(不更新參數)
retriever.eval()
for param in retriever.parameters():
param.requires_grad = False
# 載入要訓練的 LLM
llm = AutoModelForCausalLM.from_pretrained("base-llm")
# 準備訓練資料
train_examples = []
for item in training_data:
query = item['query']
answer = item['answer']
# 使用固定的 Retriever 檢索
with torch.no_grad():
retrieved_docs = retriever.retrieve(query, k=5)
# 組合成訓練範例
context = '\n'.join(retrieved_docs)
prompt = f"""
根據以下資訊回答問題:
{context}
問題:{query}
答案:
"""
train_examples.append({
'input': prompt,
'output': answer
})
# 訓練 LLM
trainer = Trainer(
model=llm,
train_dataset=train_examples,
# ... 其他訓練參數
)
trainer.train()
print("✅ LLM 訓練完成")
return llm
# 使用
trained_llm = trainer.stage2_train_llm(
retriever=trained_retriever,
training_data=llm_training_data
)Retriever First 的完整實作
class RetrieverFirstRALLM:
"""完整的 Retriever First RA-LLM"""
def __init__(self):
self.retriever = None
self.llm = None
self.datastore = None
def train(self, retriever_data, llm_data, documents):
"""
完整訓練流程
Args:
retriever_data: Retriever 訓練資料
llm_data: LLM 訓練資料
documents: 知識庫文檔
"""
# 建立 Datastore
print("📚 建立 Datastore")
self.datastore = self._build_datastore(documents)
# 階段 1:訓練 Retriever
print("\n" + "="*50)
print("階段 1:訓練 Retriever")
print("="*50)
self.retriever = self._train_retriever(retriever_data)
# 階段 2:訓練 LLM
print("\n" + "="*50)
print("階段 2:訓練 LLM(Retriever 固定)")
print("="*50)
self.llm = self._train_llm(llm_data)
print("\n✅ RA-LLM 訓練完成!")
def _build_datastore(self, documents):
"""建立向量資料庫"""
# 使用基礎模型建立初始向量庫
# 後續會用訓練好的 Retriever 更新
embeddings = embed_documents(documents)
return VectorStore(embeddings, documents)
def _train_retriever(self, training_data):
"""訓練 Retriever"""
retriever = SentenceTransformer('base-model')
# 訓練邏輯(如前所述)
train_examples = self._prepare_retriever_data(training_data)
retriever.fit(train_examples, epochs=3)
# 用訓練好的 Retriever 更新 Datastore
print("🔄 用訓練好的 Retriever 更新 Datastore")
self.datastore.update_embeddings(retriever)
return retriever
def _train_llm(self, training_data):
"""訓練 LLM(Retriever 固定)"""
llm = AutoModelForCausalLM.from_pretrained("base-llm")
# 固定 Retriever
self.retriever.eval()
# 準備訓練資料(使用固定的 Retriever)
train_examples = []
for item in training_data:
# 用訓練好的 Retriever 檢索
docs = self.retriever.retrieve(
item['query'],
datastore=self.datastore
)
# 組合成訓練範例
train_examples.append({
'input': self._format_prompt(item['query'], docs),
'output': item['answer']
})
# 訓練
llm.fit(train_examples, epochs=3)
return llm
def _format_prompt(self, query, documents):
"""格式化 prompt"""
context = '\n\n'.join([f"文檔 {i+1}:{doc}"
for i, doc in enumerate(documents)])
return f"根據以下資訊回答:\n\n{context}\n\n問題:{query}\n\n答案:"
def answer(self, query):
"""使用訓練好的 RA-LLM 回答問題"""
# 1. 用訓練好的 Retriever 檢索
docs = self.retriever.retrieve(query, datastore=self.datastore)
# 2. 用訓練好的 LLM 生成
prompt = self._format_prompt(query, docs)
answer = self.llm.generate(prompt)
return answer
# 使用範例
ra_llm = RetrieverFirstRALLM()
# 訓練
ra_llm.train(
retriever_data=retriever_training_data,
llm_data=llm_training_data,
documents=knowledge_base_documents
)
# 推理
answer = ra_llm.answer("EGFR 突變的靶向治療方案?")
print(answer)Retriever First 的優勢與劣勢
優勢:
- ✅ 檢索品質優先:先確保能找到正確的資訊
- ✅ LLM 訓練更穩定:基於高品質的檢索結果訓練
- ✅ 適合檢索困難的場景:領域術語多、文檔複雜
劣勢:
- ❌ 檢索器錯誤會傳播:如果第一階段訓練不好,影響第二階段
- ❌ 不夠靈活:LLM 只能使用檢索器提供的資訊
適用場景:
✅ 領域特定性強(醫療、法律、科技)
✅ 文檔術語複雜
✅ 檢索準確率是瓶頸
✅ 有足夠的檢索訓練資料🤖 策略 2:LLMs First(先訓練語言模型)
核心流程
第一階段(上方):訓練 LLM
┌─────────────────────────────────────┐
│ Input → LLM → Output │
│ 🔥 │
│ (訓練中) │
└─────────────────────────────────────┘
↓
第二階段(下方):固定 LLM,訓練 Retriever
┌─────────────────────────────────────┐
│ Input → LLM → Output │
│ ⚙️ │
│ (固定) │
│ │
│ Retriever → Datastore → Documents │
│ 🔥 │
│ (訓練中) │
└─────────────────────────────────────┘詳細步驟
階段 1:訓練 LLM
目標:讓 LLM 學會利用外部資訊生成答案
# 階段 1:訓練 LLM
class LLMsFirstTraining:
"""先訓練語言模型"""
def stage1_train_llm(self, training_data):
"""
訓練資料格式:
{
'query': '問題',
'context': '相關文檔(人工提供)',
'answer': '正確答案'
}
注意:這裡的 context 是人工精選的,不是自動檢索的
"""
print("🔥 階段 1:訓練 LLM")
llm = AutoModelForCausalLM.from_pretrained("base-llm")
# 準備訓練資料
train_examples = []
for item in training_data:
# 使用人工提供的高品質 context
prompt = f"""
根據以下資訊回答問題:
{item['context']}
問題:{item['query']}
答案:
"""
train_examples.append({
'input': prompt,
'output': item['answer']
})
# 訓練
trainer = Trainer(
model=llm,
train_dataset=train_examples,
# ... 訓練參數
)
trainer.train()
print("✅ LLM 訓練完成")
print("💡 LLM 已學會如何利用提供的資訊回答問題")
return llm
# 使用
trainer = LLMsFirstTraining()
trained_llm = trainer.stage1_train_llm(llm_training_data)為什麼用人工 context?
原因:
├─ 這個階段重點是讓 LLM 學會「如何使用」資訊
├─ 不是學會「如何找到」資訊
└─ 使用高品質的人工 context 讓訓練更穩定
訓練目標:
LLM 學會:
├─ 從提供的 context 中提取關鍵資訊
├─ 整合多個來源的資訊
├─ 基於 context 生成準確答案
└─ 當 context 不足時,明確表達不確定階段 2:固定 LLM,訓練 Retriever
目標:讓檢索器學會找到 LLM 需要的資訊
# 階段 2:訓練 Retriever
class LLMsFirstTraining:
def stage2_train_retriever(self, llm, training_data, documents):
"""
訓練資料格式:
{
'query': '問題',
'answer': '正確答案' # 由 LLM 期望的答案
}
"""
print("🔥 階段 2:訓練 Retriever")
print("⚙️ LLM 已固定(不再訓練)")
# 固定 LLM
llm.eval()
for param in llm.parameters():
param.requires_grad = False
# 載入要訓練的 Retriever
retriever = SentenceTransformer('base-model')
# 關鍵:使用 LLM 的回饋訓練 Retriever
train_examples = []
for item in training_data:
query = item['query']
expected_answer = item['answer']
# 為這個查詢找候選文檔
candidates = self._get_candidates(query, documents)
# 用固定的 LLM 評估每個候選文檔的有用性
doc_scores = []
for doc in candidates:
# 用這個文檔讓 LLM 生成答案
with torch.no_grad():
prompt = f"根據:{doc}\n問題:{query}\n答案:"
generated = llm.generate(prompt)
# 評估生成的答案與期望答案的相似度
score = self._compute_similarity(generated, expected_answer)
doc_scores.append((doc, score))
# 高分的文檔 = 正例,低分的 = 負例
positive_docs = [doc for doc, score in doc_scores if score > 0.8]
negative_docs = [doc for doc, score in doc_scores if score < 0.3]
# 建立訓練樣本
for pos_doc in positive_docs:
train_examples.append(
InputExample(texts=[query, pos_doc], label=1.0)
)
for neg_doc in negative_docs:
train_examples.append(
InputExample(texts=[query, neg_doc], label=0.0)
)
# 訓練 Retriever
train_dataloader = DataLoader(train_examples, batch_size=16)
train_loss = losses.CosineSimilarityLoss(retriever)
retriever.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=3
)
print("✅ Retriever 訓練完成")
print("💡 Retriever 已學會找到讓 LLM 能回答好的文檔")
return retriever
def _get_candidates(self, query, documents, k=50):
"""用 BM25 等方法獲取候選文檔"""
# 簡單的基於關鍵字的初篩
return bm25_search(query, documents, k=k)
def _compute_similarity(self, generated, expected):
"""計算生成答案與期望答案的相似度"""
# 可以用 ROUGE, BLEU, 或語義相似度
return semantic_similarity(generated, expected)
# 使用
trained_retriever = trainer.stage2_train_retriever(
llm=trained_llm,
training_data=retriever_training_data,
documents=knowledge_base_documents
)LLMs First 的完整實作
class LLMsFirstRALLM:
"""完整的 LLMs First RA-LLM"""
def __init__(self):
self.llm = None
self.retriever = None
self.datastore = None
def train(self, llm_data, retriever_data, documents):
"""
完整訓練流程
Args:
llm_data: LLM 訓練資料(包含人工 context)
retriever_data: Retriever 訓練資料(查詢-答案對)
documents: 知識庫文檔
"""
# 建立 Datastore(初始版本)
print("📚 建立初始 Datastore")
self.datastore = self._build_initial_datastore(documents)
# 階段 1:訓練 LLM
print("\n" + "="*50)
print("階段 1:訓練 LLM")
print("="*50)
self.llm = self._train_llm(llm_data)
# 階段 2:訓練 Retriever(基於 LLM 的回饋)
print("\n" + "="*50)
print("階段 2:訓練 Retriever(LLM 固定)")
print("="*50)
self.retriever = self._train_retriever(retriever_data, documents)
# 用訓練好的 Retriever 更新 Datastore
print("🔄 用訓練好的 Retriever 更新 Datastore")
self.datastore.update_embeddings(self.retriever)
print("\n✅ RA-LLM 訓練完成!")
def _train_llm(self, training_data):
"""訓練 LLM(使用人工 context)"""
llm = AutoModelForCausalLM.from_pretrained("base-llm")
train_examples = []
for item in training_data:
prompt = f"""
根據以下資訊回答:
{item['context']}
問題:{item['query']}
答案:
"""
train_examples.append({
'input': prompt,
'output': item['answer']
})
llm.fit(train_examples, epochs=3)
return llm
def _train_retriever(self, training_data, documents):
"""訓練 Retriever(基於 LLM 回饋)"""
retriever = SentenceTransformer('base-model')
# 固定 LLM
self.llm.eval()
train_examples = []
for item in training_data:
query = item['query']
expected_answer = item['answer']
# 獲取候選文檔
candidates = bm25_search(query, documents, k=50)
# 用 LLM 評估每個候選
positive_docs = []
negative_docs = []
for doc in candidates:
# LLM 生成答案
prompt = f"根據:{doc}\n問題:{query}\n答案:"
with torch.no_grad():
generated = self.llm.generate(prompt)
# 評估答案品質
score = semantic_similarity(generated, expected_answer)
if score > 0.8:
positive_docs.append(doc)
elif score < 0.3:
negative_docs.append(doc)
# 建立訓練樣本
for doc in positive_docs:
train_examples.append(
InputExample(texts=[query, doc], label=1.0)
)
for doc in negative_docs:
train_examples.append(
InputExample(texts=[query, doc], label=0.0)
)
# 訓練
train_dataloader = DataLoader(train_examples, batch_size=16)
train_loss = losses.CosineSimilarityLoss(retriever)
retriever.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=3
)
return retriever
def answer(self, query):
"""使用訓練好的 RA-LLM 回答"""
# 1. 用訓練好的 Retriever 檢索
docs = self.retriever.retrieve(query, datastore=self.datastore)
# 2. 用訓練好的 LLM 生成
context = '\n\n'.join(docs)
prompt = f"根據以下資訊回答:\n\n{context}\n\n問題:{query}\n\n答案:"
answer = self.llm.generate(prompt)
return answer
# 使用範例
ra_llm = LLMsFirstRALLM()
# 訓練
ra_llm.train(
llm_data=llm_training_data, # 包含人工 context
retriever_data=retriever_training_data, # 查詢-答案對
documents=knowledge_base_documents
)
# 推理
answer = ra_llm.answer("EGFR 突變的靶向治療方案?")
print(answer)LLMs First 的優勢與劣勢
優勢:
- ✅ LLM 能力優先:先確保 LLM 會用資訊
- ✅ Retriever 目標明確:找 LLM 需要的資訊
- ✅ 更好的協調性:Retriever 專門服務這個 LLM
劣勢:
- ❌ 需要人工 context:第一階段需要高品質標註
- ❌ LLM 訓練成本高:大模型訓練昂貴
- ❌ 第二階段複雜:需要 LLM 評估文檔品質
適用場景:
✅ LLM 能力是瓶頸(不會用檢索結果)
✅ 有高品質的人工標註資料
✅ 計算資源充足(可以訓練大 LLM)
✅ 需要特定的輸出風格或格式📊 兩種策略對比
核心差異
| 維度 | Retriever First | LLMs First |
|---|---|---|
| 第一階段 | 訓練 Retriever | 訓練 LLM |
| 第二階段 | 訓練 LLM | 訓練 Retriever |
| Retriever 目標 | 找相關文檔 | 找 LLM 需要的文檔 |
| LLM 訓練資料 | 檢索結果 | 人工 context |
| 標註需求 | 查詢-文檔對 | 查詢-context-答案 |
| 訓練穩定性 | LLM 可能受檢索影響 | Retriever 有明確目標 |
流程對比
Retriever First:
階段 1: [Train Retriever] → 學會找相關文檔
↓
固定 Retriever
↓
階段 2: [Train LLM] → 學會用檢索結果
優點:檢索品質優先
缺點:檢索錯誤會影響 LLM 訓練
─────────────────────────────────────
LLMs First:
階段 1: [Train LLM] → 學會用資訊回答
↓
固定 LLM
↓
階段 2: [Train Retriever] → 學會找 LLM 需要的
優點:LLM-Retriever 協調性好
缺點:需要人工提供高品質 context訓練資料對比
Retriever First 的資料需求:
// 階段 1:Retriever 訓練資料
{
"query": "EGFR 突變治療",
"positive_docs": [
"EGFR 靶向藥物包括...",
"EGFR 突變患者可使用..."
],
"negative_docs": [
"化療的副作用包括...",
"手術治療的適應症..."
]
}
// 階段 2:LLM 訓練資料
{
"query": "EGFR 突變治療方案",
"answer": "根據指南,EGFR 突變患者可以使用..."
}LLMs First 的資料需求:
// 階段 1:LLM 訓練資料(需要人工 context)
{
"query": "EGFR 突變治療方案",
"context": "根據 NCCN 指南,EGFR 突變患者...", // 人工提供
"answer": "根據指南,建議使用..."
}
// 階段 2:Retriever 訓練資料
{
"query": "EGFR 突變治療方案",
"answer": "根據指南,建議使用..."
}
// Retriever 會自動找出讓 LLM 生成此答案的文檔成本對比
Retriever First:
├─ 階段 1:訓練小模型(Retriever)
│ 成本:$100-500
├─ 階段 2:訓練 LLM
│ 成本:$1000-5000
└─ 總成本:$1100-5500
LLMs First:
├─ 階段 1:訓練 LLM
│ 成本:$1000-5000
├─ 階段 2:訓練小模型(Retriever)+ LLM 推理
│ 成本:$300-1000(包含大量 LLM 推理)
└─ 總成本:$1300-6000
結論:成本相近,但 LLMs First 在第二階段需要更多 LLM 推理🎯 如何選擇?
決策樹
開始
│
▼
Q1: 瓶頸在哪裡?
│
├─ 檢索不準(找不到相關文檔)
│ └→ Retriever First ✓
│
└─ LLM 不會用資訊(找到了但答不好)
└→ LLMs First ✓
Q2: 有什麼樣的標註資料?
│
├─ 有查詢-文檔配對
│ └→ Retriever First ✓
│
└─ 有查詢-context-答案(人工提供 context)
└→ LLMs First ✓
Q3: 計算資源如何?
│
├─ 有限(先訓練小模型)
│ └→ Retriever First ✓
│
└─ 充足(可以先訓練 LLM)
└→ LLMs First ✓
Q4: 是否需要特定 LLM 風格?
│
├─ 是(需要客製化 LLM)
│ └→ LLMs First ✓
│
└─ 否(標準 LLM 即可)
└→ Retriever First ✓實際場景建議
場景 1:醫療問答系統
特點:
- 術語複雜,通用檢索器效果差
- 文檔專業,檢索是主要挑戰
推薦:Retriever First
理由:
✓ 先解決檢索問題
✓ 醫療術語需要專門訓練
✓ 有論文-摘要對(天然的訓練資料)場景 2:客服機器人
特點:
- 需要特定說話風格
- LLM 需要學會禮貌、專業的回答方式
推薦:LLMs First
理由:
✓ 先訓練 LLM 的說話風格
✓ 有客服對話記錄(人工 context)
✓ 檢索相對簡單(FAQ 查找)場景 3:法律文書分析
特點:
- 文書結構複雜
- 需要找到特定條款
推薦:Retriever First
理由:
✓ 檢索精確性要求高
✓ 法律術語需要專門理解
✓ 有判例-條文對(訓練資料)場景 4:教育輔導
特點:
- 需要特定教學風格
- LLM 需要逐步引導學生
推薦:LLMs First
理由:
✓ 教學風格很重要
✓ 有教師範例對話(人工 context)
✓ 檢索教材相對容易🔬 進階話題
混合訓練策略
不只有兩種純策略,還可以混合!
策略 A:迭代訓練
class IterativeRALLMTraining:
"""迭代式訓練:交替優化"""
def train(self, data, documents, iterations=3):
"""
多輪迭代訓練
每輪:
1. 訓練 Retriever
2. 固定 Retriever,訓練 LLM
3. 固定 LLM,再訓練 Retriever
...
"""
retriever = SentenceTransformer('base-model')
llm = AutoModelForCausalLM.from_pretrained('base-llm')
for iteration in range(1, iterations + 1):
print(f"\n{'='*50}")
print(f"迭代 {iteration}/{iterations}")
print(f"{'='*50}")
# 1. 訓練 Retriever(基於當前 LLM)
print("🔥 訓練 Retriever")
retriever = self._train_retriever(
retriever, llm, data, documents
)
# 2. 訓練 LLM(基於當前 Retriever)
print("🔥 訓練 LLM")
llm = self._train_llm(
llm, retriever, data
)
return retriever, llm策略 B:聯合微調(Joint Fine-tuning)
class JointRALLMTraining:
"""
聯合訓練:同時優化 Retriever 和 LLM
注意:計算成本極高,但效果最好
"""
def train(self, data, documents):
retriever = SentenceTransformer('base-model')
llm = AutoModelForCausalLM.from_pretrained('base-llm')
# 兩個模型同時訓練
for batch in data:
# 1. 用當前 Retriever 檢索
retrieved = retriever.retrieve(batch['query'])
# 2. 用當前 LLM 生成
generated = llm.generate(batch['query'], retrieved)
# 3. 計算 loss
loss = compute_loss(generated, batch['answer'])
# 4. 反向傳播(更新兩個模型)
loss.backward()
# 更新 Retriever
retriever_optimizer.step()
# 更新 LLM
llm_optimizer.step()Parameter-Efficient Training
降低訓練成本的技巧
# 使用 LoRA(Low-Rank Adaptation)
from peft import LoraConfig, get_peft_model
# 只訓練 LLM 的一小部分參數
lora_config = LoraConfig(
r=16, # Low-rank dimension
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1
)
# 應用到 LLM
llm = AutoModelForCausalLM.from_pretrained("base-llm")
llm = get_peft_model(llm, lora_config)
# 訓練參數大幅減少
print(f"可訓練參數:{llm.num_parameters() / 1e6:.2f}M")
# 輸出:可訓練參數:4.19M(vs 完整訓練的 7000M)🏁 總結
核心要點
1. Sequential Training 的智慧
- 不同時訓練所有組件
- 分階段,逐個優化
- 降低複雜度和資源需求
2. 兩種策略各有優勢
Retriever First:
✓ 檢索品質優先
✓ 適合檢索是瓶頸的場景
✓ 訓練資料要求:查詢-文檔對
✓ 推薦:領域特定、術語複雜的場景LLMs First:
✓ LLM 能力優先
✓ 適合 LLM 不會用資訊的場景
✓ 訓練資料要求:查詢-context-答案
✓ 推薦:需要特定風格或格式的場景3. 選擇建議
檢索困難 → Retriever First
LLM 不會用資訊 → LLMs First
兩者都有問題 → 迭代訓練
資源充足 → Joint Fine-tuning實施檢查清單
Retriever First 實施:
- ☑️ 準備查詢-文檔配對資料
- ☑️ 訓練 Retriever(3-5 epochs)
- ☑️ 固定 Retriever
- ☑️ 準備查詢-答案資料
- ☑️ 訓練 LLM(3-5 epochs)
- ☑️ 評估整體效果
LLMs First 實施:
- ☑️ 準備查詢-人工context-答案資料
- ☑️ 訓練 LLM(3-5 epochs)
- ☑️ 固定 LLM
- ☑️ 用 LLM 評估文檔品質
- ☑️ 訓練 Retriever(3-5 epochs)
- ☑️ 評估整體效果
最後的建議
從評估開始
- 先用通用模型測試
- 找出真正的瓶頸
- 針對瓶頸選擇策略
小規模驗證
- 不要一開始就大規模訓練
- 用小資料集驗證方法
- 確認有效後再擴大
持續監控
- 記錄每個階段的效果
- 分析錯誤案例
- 適時調整策略
考慮成本
- 計算訓練時間和費用
- 評估 ROI
- 選擇性價比最高的方案
記住:最好的訓練策略是最適合你的場景和資源的策略!
🔗 延伸閱讀
- 📄 REALM: Retrieval-Augmented Language Model Pre-Training
- 💻 RETRO: Improving language models by retrieving from trillions of tokens
- 📚 Atlas: Few-shot Learning with Retrieval Augmented Language Models
- 🎓 Parameter-Efficient Fine-Tuning (PEFT)
- 📊 LoRA: Low-Rank Adaptation of Large Language Models
🔗 相關文章系列
- 《Prompting vs RAG vs Fine-tuning》- 理解三種方法的差異
- 《RAG 系統實戰教學》- 建立基礎 RAG
- 《Advanced RAG 優化技術》- 優化檢索品質
- 《Reranking 深度解析》- 提升準確率