Hands-On RAG Reranking and Post-Processing
Reranking is a key step for improving retrieval quality in RAG systems. How do you further refine results after the initial retrieval? How do you process the retrieved content to improve generation? This article works through practical approaches to reranking and post-processing.
1. Reranking Overview
1.1 Why Reranking Is Needed
Why reranking matters:
┌───────────────────────────────────────┐
│ 1. Limits of vector retrieval         │
│    - Semantic similarity ≠ relevance  │
│    - Keyword matches can be lost      │
├───────────────────────────────────────┤
│ 2. Uneven result quality              │
│    - Retrieved chunks vary in quality │
│    - Further filtering is needed      │
├───────────────────────────────────────┤
│ 3. Context optimization               │
│    - Remove redundant information     │
│    - Improve generation quality       │
├───────────────────────────────────────┤
│ 4. Diversity requirements             │
│    - Avoid homogeneous results        │
│    - Provide multiple perspectives    │
└───────────────────────────────────────┘
1.2 Reranking Pipeline
graph LR
A[User query] --> B[Vector retrieval]
B --> C[Top-K coarse ranking]
C --> D[Reranking model]
D --> E[Fine-ranked results]
E --> F[Post-processing]
F --> G[Final context]
G --> H[LLM generation]
1.3 Mainstream Reranking Methods
| Method | Mechanism | Pros | Cons |
|---|---|---|---|
| Cross-Encoder | Joint query-document encoding with full cross-attention | High accuracy | Slow |
| ColBERT | Late interaction | Good accuracy/speed balance | Complex to implement |
| LLM Rerank | LLM scores each document | Strong semantic understanding | Expensive |
| RRF (Reciprocal Rank Fusion) | Fuses multiple rankings | Simple and fast | Mediocre quality |
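RRF from the table is not covered again later in the article, so here is a minimal sketch first; it assumes the input is one ranked list of document IDs per retrieval route:
# rrf_sketch.py - minimal RRF illustration (not a production implementation)
from typing import Dict, List

def reciprocal_rank_fusion(
    rankings: List[List[str]],  # one ranked list of document IDs per route
    k: int = 60                 # smoothing constant; 60 is a common choice
) -> List[str]:
    """Fuse rankings via score(d) = sum over routes of 1 / (k + rank_i(d))"""
    scores: Dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, 1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)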
2. Cross-Encoder Reranking
2.1 How Cross-Encoders Work
# cross_encoder_reranker.py
from typing import List, Dict
from sentence_transformers import CrossEncoder
import numpy as np

class CrossEncoderReranker:
    """Cross-Encoder reranker"""

    def __init__(
        self,
        model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        device: str = None
    ):
        """
        Initialize.
        Args:
            model_name: model name
            device: device to run on
        """
        self.model = CrossEncoder(model_name, device=device)

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: int = 5
    ) -> List[Dict]:
        """
        Rerank documents.
        Args:
            query: the query
            documents: candidate documents
            top_k: number of results to return
        Returns:
            Reranked documents with scores
        """
        # Build (query, document) input pairs
        pairs = [[query, doc] for doc in documents]
        # Predict relevance scores
        scores = self.model.predict(pairs)
        # Sort by score, descending
        sorted_indices = np.argsort(scores)[::-1]
        # Build results
        results = []
        for idx in sorted_indices[:top_k]:
            results.append({
                'document': documents[idx],
                'score': float(scores[idx]),
                'rank': len(results) + 1
            })
        return results

    def rerank_with_metadata(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5
    ) -> List[Dict]:
        """Rerank while carrying metadata along"""
        contents = [doc['content'] for doc in documents]
        ranked = self.rerank(query, contents, top_k)
        # Attach metadata back to the ranked results
        results = []
        for r in ranked:
            for doc in documents:
                if doc['content'] == r['document']:
                    r['metadata'] = doc.get('metadata', {})
                    results.append(r)
                    break
        return results
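A minimal usage sketch (requires sentence-transformers; the model is downloaded on first run):
# Usage sketch
reranker = CrossEncoderReranker()
results = reranker.rerank(
    query="What is reranking?",
    documents=["Reranking reorders initial retrieval results.",
               "The weather is nice today."],
    top_k=2
)
# The first document should receive a clearly higher score
print(results[0]['document'], results[0]['score'])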
2.2 Optimizing Batch Reranking
# batch_reranker.py
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor

class BatchReranker:
    """Batch reranker"""

    def __init__(self, reranker, batch_size: int = 32):
        self.reranker = reranker
        self.batch_size = batch_size

    def rerank_batch(
        self,
        queries: List[str],
        documents_list: List[List[str]],
        top_k: int = 5
    ) -> List[List[Dict]]:
        """
        Rerank a batch of queries.
        Args:
            queries: list of queries
            documents_list: list of document lists, one per query
            top_k: results per query
        Returns:
            List of reranked result lists
        """
        all_results = []
        # Process in batches
        for i in range(0, len(queries), self.batch_size):
            batch_queries = queries[i:i + self.batch_size]
            batch_docs = documents_list[i:i + self.batch_size]
            batch_results = self._process_batch(
                batch_queries,
                batch_docs,
                top_k
            )
            all_results.extend(batch_results)
        return all_results

    def _process_batch(
        self,
        queries: List[str],
        documents_list: List[List[str]],
        top_k: int
    ) -> List[List[Dict]]:
        """Process one batch. Queries are submitted to a thread pool; note
        that model inference is GIL/GPU bound, so the speedup is limited."""
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(self.reranker.rerank, query, docs, top_k)
                for query, docs in zip(queries, documents_list)
            ]
            return [f.result() for f in futures]
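Usage sketch (assumes queries and documents_list were built by the upstream retriever):
# Usage sketch
batch_reranker = BatchReranker(CrossEncoderReranker(), batch_size=16)
all_ranked = batch_reranker.rerank_batch(queries, documents_list, top_k=5)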
2.3 Multi-Stage Reranking
# multi_stage_reranker.py
from typing import List, Dict

class MultiStageReranker:
    """Multi-stage reranker"""

    def __init__(self, rerankers: List):
        """
        Initialize.
        Args:
            rerankers: rerankers to apply, in order
        """
        self.rerankers = rerankers

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k_list: List[int] = None
    ) -> List[Dict]:
        """
        Multi-stage reranking.
        Args:
            query: the query
            documents: candidate documents
            top_k_list: how many documents each stage keeps
        Returns:
            Final ranked documents
        """
        if top_k_list is None:
            top_k_list = [50, 20, 10, 5]
        current_docs = documents
        # Rerank stage by stage
        for i, reranker in enumerate(self.rerankers):
            top_k = top_k_list[i] if i < len(top_k_list) else 5
            contents = [doc['content'] for doc in current_docs]
            ranked = reranker.rerank(query, contents, top_k)
            # Map results back to the original documents. Collect into a
            # separate list: clearing current_docs first would leave the
            # inner loop nothing to match against and always return [].
            next_docs = []
            for r in ranked:
                for doc in current_docs:
                    if doc['content'] == r['document']:
                        doc.setdefault('stage_scores', []).append(r['score'])
                        next_docs.append(doc)
                        break
            current_docs = next_docs
        return current_docs

# Usage example
def create_multi_stage_reranker():
    """Create a multi-stage reranker (fast to slow, filtering at each stage)"""
    # Stage 1: fast coarse ranking
    reranker1 = CrossEncoderReranker(
        'cross-encoder/ms-marco-TinyBERT-L-2-v2'
    )
    # Stage 2: medium accuracy
    reranker2 = CrossEncoderReranker(
        'cross-encoder/ms-marco-electra-base'
    )
    # Stage 3: high accuracy
    reranker3 = CrossEncoderReranker(
        'cross-encoder/ms-marco-MiniLM-L-6-v2'
    )
    return MultiStageReranker([reranker1, reranker2, reranker3])
3. Diversity-Aware Reranking
3.1 MMR (Maximal Marginal Relevance)
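The MMR selection rule is: MMR = argmax_{d ∈ R\S} [ λ·sim(q, d) - (1-λ)·max_{s ∈ S} sim(d, s) ], where R is the candidate set and S the set already selected. Note that the redundancy penalty uses the maximum similarity to the selected documents.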
# mmr_reranker.py
from typing import List, Dict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class MMRReranker:
    """MMR diversity reranker"""

    def __init__(self, embedding_model, lambda_param: float = 0.5):
        """
        Initialize.
        Args:
            embedding_model: embedding model
            lambda_param: relevance weight in [0, 1];
                higher favors relevance, lower favors diversity
        """
        self.embedding_model = embedding_model
        self.lambda_param = lambda_param

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: int = 5
    ) -> List[Dict]:
        """
        MMR reranking.
        Args:
            query: the query
            documents: candidate documents
            top_k: number of results
        Returns:
            Diversity-aware ranking
        """
        if not documents:
            return []
        # Compute embeddings
        query_embedding = self.embedding_model.encode([query])[0]
        doc_embeddings = self.embedding_model.encode(documents)
        # Query-document similarities
        query_similarities = cosine_similarity(
            [query_embedding],
            doc_embeddings
        )[0]
        # Document-document similarities
        doc_similarities = cosine_similarity(doc_embeddings)
        # Greedy MMR selection
        selected = []
        remaining = list(range(len(documents)))
        while len(selected) < min(top_k, len(documents)):
            best_score = -float('inf')
            best_idx = None
            for idx in remaining:
                # Relevance to the query
                query_score = query_similarities[idx]
                # Maximum similarity to already-selected documents
                # (standard MMR penalizes redundancy via the max)
                if selected:
                    max_sim = max(
                        doc_similarities[idx][s]
                        for s in selected
                    )
                else:
                    max_sim = 0
                # MMR score
                mmr_score = (
                    self.lambda_param * query_score -
                    (1 - self.lambda_param) * max_sim
                )
                if mmr_score > best_score:
                    best_score = mmr_score
                    best_idx = idx
            if best_idx is not None:
                selected.append(best_idx)
                remaining.remove(best_idx)
        # Build results
        results = []
        for rank, idx in enumerate(selected, 1):
            results.append({
                'document': documents[idx],
                # Query relevance; the final order comes from MMR selection
                'score': float(query_similarities[idx]),
                'rank': rank
            })
        return results
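Usage sketch (the model name is just an example; candidate_docs is a list of candidate document strings you provide):
# Usage sketch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
mmr_reranker = MMRReranker(model, lambda_param=0.7)
results = mmr_reranker.rerank("How to optimize RAG retrieval?", candidate_docs, top_k=5)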
3.2 Cluster-Based Diversity
# cluster_diversity_reranker.py
from typing import List, Dict
from sklearn.cluster import KMeans

class ClusterDiversityReranker:
    """Cluster-based diversity reranker"""

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5,
        n_clusters: int = 3
    ) -> List[Dict]:
        """
        Cluster-based diversity reranking.
        Args:
            query: the query
            documents: candidate documents (with scores)
            top_k: number of results
            n_clusters: number of clusters
        Returns:
            Diversity-aware ranking
        """
        if not documents:
            return []
        # Compute embeddings
        contents = [doc['content'] for doc in documents]
        embeddings = self.embedding_model.encode(contents)
        # Cluster
        n_clusters = min(n_clusters, len(documents))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(embeddings)
        # Pick the top documents from each cluster
        results = []
        docs_per_cluster = max(1, top_k // n_clusters)
        for cluster_id in range(n_clusters):
            cluster_docs = [
                doc for doc, c in zip(documents, clusters)
                if c == cluster_id
            ]
            # Sort by the original retrieval score
            cluster_docs.sort(
                key=lambda x: x.get('score', 0),
                reverse=True
            )
            # Keep the top of this cluster
            selected = cluster_docs[:docs_per_cluster]
            results.extend(selected)
        # Re-sort the combined picks by score, then truncate
        results.sort(key=lambda x: x.get('score', 0), reverse=True)
        results = results[:top_k]
        for i, doc in enumerate(results, 1):
            doc['rank'] = i
        return results
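Usage sketch (assumes docs is a list of retrieval results with content and score fields):
# Usage sketch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
diversity_reranker = ClusterDiversityReranker(model)
diverse_docs = diversity_reranker.rerank("RAG optimization", docs, top_k=5, n_clusters=3)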
4. Context Post-Processing
4.1 Context Compression
# context_compressor.py
from typing import List, Dict

class ContextCompressor:
    """Context compressor"""

    def __init__(
        self,
        llm=None,
        max_tokens: int = 2000
    ):
        """
        Initialize.
        Args:
            llm: LLM used for smart compression (optional)
            max_tokens: token budget
        """
        self.llm = llm
        self.max_tokens = max_tokens

    def compress(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """
        Compress the context.
        Args:
            documents: documents to compress
            query: the query
        Returns:
            Compressed context
        """
        # Count current tokens
        current_tokens = self._count_tokens(documents)
        if current_tokens <= self.max_tokens:
            return self._concatenate_documents(documents)
        # Compression is needed
        if self.llm:
            return self._llm_compress(documents, query)
        else:
            return self._heuristic_compress(documents)

    def _llm_compress(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """LLM-based compression"""
        context = self._concatenate_documents(documents)
        prompt = f"""
Compress the following context, keeping only the information most relevant to the question.

Question: {query}

Context:
{context}

Extract the core information relevant to the question, within 2000 characters.
"""
        return self.llm.generate(prompt)

    def _heuristic_compress(
        self,
        documents: List[Dict]
    ) -> str:
        """Heuristic compression"""
        # Sort by score
        sorted_docs = sorted(
            documents,
            key=lambda x: x.get('score', 0),
            reverse=True
        )
        # Add documents until the budget is reached
        result = []
        current_tokens = 0
        for doc in sorted_docs:
            doc_tokens = self._count_tokens([doc])
            if current_tokens + doc_tokens > self.max_tokens:
                break
            result.append(doc['content'])
            current_tokens += doc_tokens
        return '\n\n'.join(result)

    def _count_tokens(self, documents: List[Dict]) -> int:
        """Count tokens"""
        text = self._concatenate_documents(documents)
        # Rough estimate: ~1.5 tokens per character;
        # the real ratio depends on the tokenizer
        return int(len(text) * 1.5)

    def _concatenate_documents(self, documents: List[Dict]) -> str:
        """Concatenate documents"""
        return '\n\n---\n\n'.join([
            f"[Document {i+1}]\n{doc['content']}"
            for i, doc in enumerate(documents)
        ])
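For precise token counts, the character-based estimate above can be replaced with tiktoken (a sketch, assuming an OpenAI-style target model; other models would use their own tokenizer):
# Exact token counting sketch (requires tiktoken)
import tiktoken

def count_tokens_exact(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count tokens with the target model's tokenizer"""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))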
4.2 Relevant-Information Extraction
# relevant_extraction.py
from typing import List, Dict

class RelevantExtractor:
    """Relevant-information extractor"""

    def __init__(self, llm):
        self.llm = llm

    def extract(
        self,
        documents: List[Dict],
        query: str,
        max_length: int = 1000
    ) -> str:
        """
        Extract relevant information.
        Args:
            documents: source documents
            query: the query
            max_length: maximum output length
        Returns:
            Extracted information
        """
        context = '\n\n'.join([doc['content'] for doc in documents])
        prompt = f"""
Extract the information most relevant to the question from the context below.

Question: {query}

Context:
{context}

Extract only information directly relevant to the question; ignore everything else.
Keep the extraction under {max_length} characters.
"""
        return self.llm.generate(prompt)
4.3 Context Formatting
# context_formatter.py
from typing import List, Dict

class ContextFormatter:
    """Context formatter"""

    def __init__(self, template: str = None):
        """
        Initialize.
        Args:
            template: formatting template
        """
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        """Default template"""
        return """
[Relevant documents]
{documents}

[Question]
{query}

[Instructions]
Answer the question based on the documents above, citing the relevant passages.
"""

    def format(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """
        Format the context.
        Args:
            documents: documents to format
            query: the query
        Returns:
            Formatted context
        """
        docs_text = self._format_documents(documents)
        return self.template.format(documents=docs_text, query=query)

    def _format_documents(self, documents: List[Dict]) -> str:
        """Format the documents"""
        formatted = []
        for i, doc in enumerate(documents, 1):
            doc_text = f"""
[Document {i}]
{doc['content']}
(relevance: {doc.get('score', 'N/A')})
"""
            formatted.append(doc_text)
        return '\n'.join(formatted)
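Usage sketch:
# Usage sketch
formatter = ContextFormatter()
context = formatter.format(
    [{'content': 'Reranking significantly improves relevance.', 'score': 0.92}],
    query="What is reranking?"
)
print(context)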
5. Evaluating Reranking
5.1 Evaluation Metrics
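For reference, DCG@K = Σ_{i=1}^{K} rel_i / log2(i + 1) and NDCG@K = DCG@K / IDCG@K. The implementation below uses binary relevance (rel_i ∈ {0, 1}), and MRR is the reciprocal rank of the first relevant document.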
# reranking_metrics.py
from typing import List, Dict
import numpy as np

class RerankingMetrics:
    """Reranking evaluation metrics"""

    @staticmethod
    def calculate_ndcg(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """
        Compute NDCG@K.
        Args:
            ranked_docs: ranked documents
            ground_truth: IDs of relevant documents
            k: cutoff position
        """
        dcg = 0
        idcg = 0
        # DCG
        for i, doc in enumerate(ranked_docs[:k]):
            if doc['id'] in ground_truth:
                dcg += 1.0 / np.log2(i + 2)
        # IDCG
        ideal_count = min(len(ground_truth), k)
        for i in range(ideal_count):
            idcg += 1.0 / np.log2(i + 2)
        return dcg / idcg if idcg > 0 else 0

    @staticmethod
    def calculate_precision_at_k(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """Compute Precision@K"""
        top_k = ranked_docs[:k]
        relevant_count = sum(
            1 for doc in top_k
            if doc['id'] in ground_truth
        )
        return relevant_count / k

    @staticmethod
    def calculate_recall_at_k(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """Compute Recall@K"""
        top_k = ranked_docs[:k]
        relevant_in_top_k = sum(
            1 for doc in top_k
            if doc['id'] in ground_truth
        )
        return relevant_in_top_k / len(ground_truth) if ground_truth else 0.0

    @staticmethod
    def calculate_mrr(
        ranked_docs: List[Dict],
        ground_truth: List[str]
    ) -> float:
        """Compute MRR"""
        for i, doc in enumerate(ranked_docs, 1):
            if doc['id'] in ground_truth:
                return 1.0 / i
        return 0
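A small example you can verify by hand:
# Usage sketch
ranked = [{'id': 'd1'}, {'id': 'd3'}, {'id': 'd2'}]
truth = ['d1', 'd2']
# DCG = 1/log2(2) + 1/log2(4) = 1.5; IDCG = 1/log2(2) + 1/log2(3) ≈ 1.631
print(RerankingMetrics.calculate_ndcg(ranked, truth, k=3))          # ≈ 0.92
print(RerankingMetrics.calculate_precision_at_k(ranked, truth, 3))  # ≈ 0.67
print(RerankingMetrics.calculate_mrr(ranked, truth))                # 1.0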
5.2 A/B Testing Framework
# rerank_ab_test.py
from typing import List, Dict
import random
import numpy as np

from reranking_metrics import RerankingMetrics

class RerankABTester:
    """Reranking A/B tester"""

    def __init__(self):
        self.results = {'A': [], 'B': []}

    def run_test(
        self,
        queries: List[str],
        documents_list: List[List[Dict]],
        ground_truth_list: List[List[str]],
        reranker_a,
        reranker_b,
        n_tests: int = 100
    ) -> Dict:
        """
        Run an A/B test. The rerankers are expected to accept these
        document dicts and preserve the 'id' field for metric computation.
        Args:
            queries: queries
            documents_list: candidate documents per query
            ground_truth_list: relevant document IDs per query
            reranker_a: reranker A
            reranker_b: reranker B
            n_tests: number of trials
        """
        metrics_a = []
        metrics_b = []
        for _ in range(n_tests):
            # Sample a query at random
            idx = random.randint(0, len(queries) - 1)
            query = queries[idx]
            docs = documents_list[idx]
            ground_truth = ground_truth_list[idx]
            # Variant A
            ranked_a = reranker_a.rerank(query, docs, top_k=10)
            ndcg_a = RerankingMetrics.calculate_ndcg(
                ranked_a, ground_truth
            )
            metrics_a.append(ndcg_a)
            # Variant B
            ranked_b = reranker_b.rerank(query, docs, top_k=10)
            ndcg_b = RerankingMetrics.calculate_ndcg(
                ranked_b, ground_truth
            )
            metrics_b.append(ndcg_b)
        # Aggregate
        return {
            'A': {
                'mean_ndcg': np.mean(metrics_a),
                'std_ndcg': np.std(metrics_a),
                'samples': len(metrics_a)
            },
            'B': {
                'mean_ndcg': np.mean(metrics_b),
                'std_ndcg': np.std(metrics_b),
                'samples': len(metrics_b)
            },
            'improvement': (np.mean(metrics_b) - np.mean(metrics_a)) / np.mean(metrics_a)
        }
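Beyond comparing means, a paired significance test is a useful addition (a sketch assuming scipy is installed and metrics_a / metrics_b are paired over the same sampled queries, as in run_test above):
# Significance test sketch
from scipy import stats

t_stat, p_value = stats.ttest_rel(metrics_a, metrics_b)
# p_value < 0.05 suggests the difference between rerankers is statistically significant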
6. A Complete Example
6.1 End-to-End RAG Reranking Pipeline
# complete_rag_rerank.py
from typing import List, Dict

class CompleteRAGRerank:
    """End-to-end RAG reranking pipeline"""

    def __init__(self, config: Dict):
        """
        Initialize.
        Args:
            config: configuration dict
        """
        # Retriever
        self.retriever = config['retriever']
        # Rerankers
        self.rerankers = config['rerankers']
        # Post-processors
        self.compressor = config.get('compressor')
        self.formatter = config.get('formatter')

    def search_and_rerank(
        self,
        query: str,
        top_k_initial: int = 50,
        top_k_final: int = 5
    ) -> Dict:
        """
        Retrieve and rerank.
        Args:
            query: the query
            top_k_initial: initial retrieval size
            top_k_final: final result size
        Returns:
            Result dict
        """
        # 1. Vector retrieval
        retrieved_docs = self.retriever.search(
            query,
            top_k=top_k_initial
        )
        # 2. Multi-stage reranking. The retrieved docs are dicts with a
        # 'content' field, so use the metadata-aware interface.
        current_docs = retrieved_docs
        for reranker in self.rerankers:
            ranked = reranker.rerank_with_metadata(
                query,
                current_docs,
                top_k=top_k_final
            )
            # Normalize back to content/score dicts for the next stage
            # and for post-processing
            current_docs = [
                {'content': r['document'], 'score': r['score'],
                 'metadata': r.get('metadata', {})}
                for r in ranked
            ]
        # 3. Context compression
        if self.compressor:
            context = self.compressor.compress(
                current_docs,
                query
            )
        else:
            context = '\n\n'.join([
                doc['content'] for doc in current_docs
            ])
        # 4. Formatting
        if self.formatter:
            formatted_context = self.formatter.format(
                current_docs,
                query
            )
        else:
            formatted_context = context
        return {
            'query': query,
            'retrieved_count': len(retrieved_docs),
            'reranked_docs': current_docs,
            'context': formatted_context
        }

# Usage example (assumes vector_retriever and llm are initialized elsewhere)
config = {
    'retriever': vector_retriever,
    'rerankers': [
        CrossEncoderReranker('cross-encoder/ms-marco-TinyBERT-L-2-v2'),
        CrossEncoderReranker('cross-encoder/ms-marco-MiniLM-L-6-v2')
    ],
    'compressor': ContextCompressor(llm, max_tokens=2000),
    'formatter': ContextFormatter()
}
rag_system = CompleteRAGRerank(config)
result = rag_system.search_and_rerank(
    "How do I improve retrieval quality in a RAG system?",
    top_k_initial=50,
    top_k_final=5
)
7. Summary
7.1 Key Takeaways
- Choosing a reranking method
  - Speed-critical: Reciprocal Rank Fusion
  - Balanced: Cross-Encoder
  - Highest quality: multi-stage reranking
- Handling diversity
  - MMR: balances relevance and diversity
  - Clustering: ensures coverage of different topics
- Post-processing
  - Compression: cuts token usage
  - Extraction: focuses on relevant information
  - Formatting: improves generation quality
7.2 Performance Reference (indicative figures; actual numbers depend on dataset and models)
| Method | NDCG@10 | Latency | Cost |
|---|---|---|---|
| No reranking | 0.65 | fastest | lowest |
| Cross-Encoder | 0.78 | medium | medium |
| Multi-stage reranking | 0.82 | slow | high |
| LLM reranking | 0.85 | slowest | highest |