
RAG Reranking and Post-Processing in Practice

Reranking is a key lever for improving retrieval quality in RAG systems. How do you refine results after the initial retrieval pass? How do you process the retrieved content so the generation step benefits? This article takes a practical, in-depth look at reranking and post-processing.

1. Reranking Overview

1.1 Why Reranking Is Needed

Why reranking matters:

┌────────────────────────────────────────────┐
│ 1. Limits of vector retrieval              │
│    - Semantic similarity ≠ relevance       │
│    - Keyword matches can be lost           │
├────────────────────────────────────────────┤
│ 2. Uneven result quality                   │
│    - Retrieved chunks vary in quality      │
│    - Further filtering is needed           │
├────────────────────────────────────────────┤
│ 3. Context optimization                    │
│    - Strip redundant information           │
│    - Improve generation quality            │
├────────────────────────────────────────────┤
│ 4. Diversity requirements                  │
│    - Avoid one-sided results               │
│    - Cover multiple perspectives           │
└────────────────────────────────────────────┘

1.2 The Reranking Pipeline

graph LR
    A[User query] --> B[Vector retrieval]
    B --> C[Top-K coarse ranking]
    C --> D[Reranking model]
    D --> E[Fine-ranked results]
    E --> F[Post-processing]
    F --> G[Final context]
    G --> H[LLM generation]

1.3 Mainstream Reranking Methods

| Method | How it works | Pros | Cons |
|---|---|---|---|
| Cross-Encoder | Full query-document cross-attention | High accuracy | Slow |
| ColBERT | Late interaction | Good speed/quality balance | Complex to implement |
| LLM Rerank | LLM-based scoring | Strong semantic understanding | Expensive |
| Reciprocal Rank | Rank-list fusion | Simple and fast | Moderate quality |
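
Of these, Reciprocal Rank Fusion (RRF) is simple enough to show in full. A minimal sketch, assuming each input ranking is a list of document IDs ordered best-first (the smoothing constant k = 60 follows the original RRF paper; nothing here comes from a specific library):

# rrf.py (illustrative sketch)
from typing import Dict, List

def reciprocal_rank_fusion(
    rankings: List[List[str]],
    k: int = 60
) -> List[Dict]:
    """Fuse ranked lists: score(d) = sum over lists of 1 / (k + rank)."""
    scores: Dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)

    fused = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [{'id': doc_id, 'score': score} for doc_id, score in fused]

# Example: fuse a dense-vector ranking with a BM25 ranking
# reciprocal_rank_fusion([['d1', 'd2', 'd3'], ['d2', 'd3', 'd1']])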

2. Cross-Encoder Reranking

2.1 How Cross-Encoders Work

# cross_encoder_reranker.py
from typing import List, Dict, Optional
from sentence_transformers import CrossEncoder
import numpy as np

class CrossEncoderReranker:
    """Cross-Encoder reranker"""
    
    def __init__(
        self,
        model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        device: Optional[str] = None
    ):
        """
        Args:
            model_name: model name
            device: device to run on (e.g. 'cuda' or 'cpu')
        """
        self.model = CrossEncoder(model_name, device=device)
    
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: int = 5
    ) -> List[Dict]:
        """
        Rerank documents for a query.
        
        Args:
            query: the query
            documents: candidate documents
            top_k: number of results to return
        
        Returns:
            ranked documents with scores
        """
        # Build (query, document) input pairs
        pairs = [[query, doc] for doc in documents]
        
        # Predict relevance scores
        scores = self.model.predict(pairs)
        
        # Sort by score, descending
        sorted_indices = np.argsort(scores)[::-1]
        
        # Build results
        results = []
        for idx in sorted_indices[:top_k]:
            results.append({
                'document': documents[idx],
                'score': float(scores[idx]),
                'rank': len(results) + 1
            })
        
        return results
    
    def rerank_with_metadata(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5
    ) -> List[Dict]:
        """Rerank documents that carry metadata"""
        contents = [doc['content'] for doc in documents]
        ranked = self.rerank(query, contents, top_k)
        
        # Re-attach metadata by matching on content
        results = []
        for r in ranked:
            for doc in documents:
                if doc['content'] == r['document']:
                    r['metadata'] = doc.get('metadata', {})
                    results.append(r)
                    break
        
        return results
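
A quick usage sketch (the query and documents are illustrative; the default model is downloaded from the Hugging Face hub on first use):

# Usage sketch (illustrative data)
reranker = CrossEncoderReranker()
results = reranker.rerank(
    query="How do I improve retrieval quality in RAG?",
    documents=[
        "Cross-encoder reranking improves precision after vector retrieval.",
        "Redis supports RDB and AOF persistence modes.",
        "Hybrid search combines BM25 with dense vectors.",
    ],
    top_k=2
)
for r in results:
    print(r['rank'], round(r['score'], 3), r['document'])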

2.2 Optimizing Batch Reranking

# batch_reranker.py
from typing import List, Dict

class BatchReranker:
    """Batch reranker"""
    
    def __init__(self, reranker, batch_size: int = 32):
        self.reranker = reranker
        self.batch_size = batch_size
    
    def rerank_batch(
        self,
        queries: List[str],
        documents_list: List[List[str]],
        top_k: int = 5
    ) -> List[List[Dict]]:
        """
        Rerank many queries in batches.
        
        Args:
            queries: list of queries
            documents_list: one document list per query
            top_k: results to return per query
        
        Returns:
            one reranked result list per query
        """
        all_results = []
        
        # Process in batches
        for i in range(0, len(queries), self.batch_size):
            batch_queries = queries[i:i + self.batch_size]
            batch_docs = documents_list[i:i + self.batch_size]
            
            batch_results = self._process_batch(
                batch_queries,
                batch_docs,
                top_k
            )
            
            all_results.extend(batch_results)
        
        return all_results
    
    def _process_batch(
        self,
        queries: List[str],
        documents_list: List[List[str]],
        top_k: int
    ) -> List[List[Dict]]:
        """Process one batch"""
        results = []
        
        for query, docs in zip(queries, documents_list):
            ranked = self.reranker.rerank(query, docs, top_k)
            results.append(ranked)
        
        return results
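
The loop above still scores one query at a time. With a CrossEncoder, the bigger win is usually to flatten all (query, document) pairs into a single predict() call so the model batches them together. A sketch under the assumption that the wrapped reranker exposes its underlying CrossEncoder as .model, as CrossEncoderReranker above does:

# flattened_batch.py (sketch; assumes reranker.model is a CrossEncoder)
import numpy as np

def rerank_flat(reranker, queries, documents_list, top_k=5):
    # Flatten all pairs, remembering which slice belongs to which query
    pairs, spans = [], []
    for query, docs in zip(queries, documents_list):
        spans.append((len(pairs), len(pairs) + len(docs)))
        pairs.extend([query, doc] for doc in docs)

    # One batched forward pass over every pair
    scores = reranker.model.predict(pairs, batch_size=64)

    results = []
    for (start, end), docs in zip(spans, documents_list):
        query_scores = scores[start:end]
        order = np.argsort(query_scores)[::-1][:top_k]
        results.append([
            {'document': docs[i], 'score': float(query_scores[i])}
            for i in order
        ])
    return results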

2.3 Multi-Stage Reranking

# multi_stage_reranker.py
from typing import List, Dict

class MultiStageReranker:
    """Multi-stage reranker: cheap models first, expensive models last"""
    
    def __init__(self, rerankers: List):
        """
        Args:
            rerankers: rerankers, ordered from coarse to fine
        """
        self.rerankers = rerankers
    
    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k_list: List[int] = None
    ) -> List[Dict]:
        """
        Multi-stage reranking.
        
        Args:
            query: the query
            documents: candidate documents
            top_k_list: how many documents each stage keeps
        
        Returns:
            final ranked documents
        """
        if top_k_list is None:
            top_k_list = [50, 20, 10, 5]
        
        current_docs = documents
        
        # Rerank stage by stage, shrinking the candidate set each time
        for i, reranker in enumerate(self.rerankers):
            top_k = top_k_list[i] if i < len(top_k_list) else 5
            
            contents = [doc['content'] for doc in current_docs]
            ranked = reranker.rerank(query, contents, top_k)
            
            # Map reranked results back to the original documents
            # (match by content, then carry the per-stage scores along)
            next_docs = []
            for r in ranked:
                for doc in current_docs:
                    if doc['content'] == r['document']:
                        doc.setdefault('stage_scores', []).append(r['score'])
                        doc['score'] = r['score']
                        next_docs.append(doc)
                        break
            current_docs = next_docs
        
        return current_docs

# Usage example
def create_multi_stage_reranker():
    """Build a three-stage reranker, coarse to fine"""
    # Stage 1: fast coarse ranking
    reranker1 = CrossEncoderReranker(
        'cross-encoder/ms-marco-TinyBERT-L-2-v2'
    )
    
    # Stage 2: medium accuracy
    reranker2 = CrossEncoderReranker(
        'cross-encoder/ms-marco-electra-base'
    )
    
    # Stage 3: high accuracy
    reranker3 = CrossEncoderReranker(
        'cross-encoder/ms-marco-MiniLM-L-6-v2'
    )
    
    return MultiStageReranker([reranker1, reranker2, reranker3])

3. Diversity-Aware Reranking

3.1 MMR (Maximal Marginal Relevance)
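
MMR selects documents greedily: at each step it picks, from the remaining candidates R \ S, the one that best trades off relevance to the query q against redundancy with the already-selected set S:

    MMR(d) = λ · sim(q, d) − (1 − λ) · max_{d' ∈ S} sim(d, d')

A larger λ favors relevance; a smaller λ favors diversity.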

# mmr_reranker.py
from typing import List, Dict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class MMRReranker:
    """MMR diversity reranker"""
    
    def __init__(self, embedding_model, lambda_param: float = 0.5):
        """
        Args:
            embedding_model: embedding model
            lambda_param: relevance weight in [0, 1];
                          higher values favor relevance,
                          lower values favor diversity
        """
        self.embedding_model = embedding_model
        self.lambda_param = lambda_param
    
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: int = 5
    ) -> List[Dict]:
        """
        MMR reranking.
        
        Args:
            query: the query
            documents: candidate documents
            top_k: number of results to return
        
        Returns:
            diversity-aware ranked results
        """
        if not documents:
            return []
        
        # Compute embeddings
        query_embedding = self.embedding_model.encode([query])[0]
        doc_embeddings = self.embedding_model.encode(documents)
        
        # Query-document similarities
        query_similarities = cosine_similarity(
            [query_embedding],
            doc_embeddings
        )[0]
        
        # Document-document similarities
        doc_similarities = cosine_similarity(doc_embeddings)
        
        # Greedy MMR selection
        selected = []
        selected_scores = []
        remaining = list(range(len(documents)))
        
        while len(selected) < min(top_k, len(documents)):
            best_score = -float('inf')
            best_idx = None
            
            for idx in remaining:
                # Relevance to the query
                query_score = query_similarities[idx]
                
                # Redundancy: maximum similarity to any selected document
                if selected:
                    max_sim = max(
                        doc_similarities[idx][s] 
                        for s in selected
                    )
                else:
                    max_sim = 0.0
                
                # MMR score: relevance penalized by redundancy
                mmr_score = (
                    self.lambda_param * query_score -
                    (1 - self.lambda_param) * max_sim
                )
                
                if mmr_score > best_score:
                    best_score = mmr_score
                    best_idx = idx
            
            if best_idx is not None:
                selected.append(best_idx)
                selected_scores.append(best_score)
                remaining.remove(best_idx)
        
        # Build results
        results = []
        for rank, (idx, mmr) in enumerate(zip(selected, selected_scores), 1):
            results.append({
                'document': documents[idx],
                'mmr_score': float(mmr),
                'relevance': float(query_similarities[idx]),
                'rank': rank
            })
        
        return results
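
A usage sketch showing the effect of lambda_param (the model name and documents are illustrative; any model with a SentenceTransformer-style encode() works):

# Usage sketch (illustrative model and data)
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')

relevance_first = MMRReranker(model, lambda_param=0.8)  # mostly relevance
diversity_first = MMRReranker(model, lambda_param=0.3)  # mostly diversity

docs = [
    "Cross-encoders rerank by scoring query-document pairs.",
    "Cross-encoder reranking scores each query-document pair.",  # near-duplicate
    "MMR penalizes documents similar to those already selected.",
]
print(relevance_first.rerank("how does reranking work", docs, top_k=2))
print(diversity_first.rerank("how does reranking work", docs, top_k=2))

With the lower lambda, the near-duplicate should typically be displaced by the less redundant third document.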

3.2 Cluster-Based Diversity

# cluster_diversity_reranker.py
from typing import List, Dict
from sklearn.cluster import KMeans

class ClusterDiversityReranker:
    """Cluster-based diversity reranker"""
    
    def __init__(self, embedding_model):
        self.embedding_model = embedding_model
    
    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5,
        n_clusters: int = 3
    ) -> List[Dict]:
        """
        Diversity reranking via clustering.
        
        Args:
            query: the query
            documents: documents (carrying retrieval scores)
            top_k: number of results to return
            n_clusters: number of clusters
        
        Returns:
            diversity-aware ranked results
        """
        if not documents:
            return []
        
        # Compute embeddings
        contents = [doc['content'] for doc in documents]
        embeddings = self.embedding_model.encode(contents)
        
        # Cluster the documents
        n_clusters = min(n_clusters, len(documents))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
        clusters = kmeans.fit_predict(embeddings)
        
        # Pick the top documents from each cluster
        results = []
        docs_per_cluster = max(1, top_k // n_clusters)
        
        for cluster_id in range(n_clusters):
            cluster_docs = [
                doc for doc, c in zip(documents, clusters) 
                if c == cluster_id
            ]
            
            # Sort by the original retrieval score
            cluster_docs.sort(
                key=lambda x: x.get('score', 0), 
                reverse=True
            )
            
            # Keep the top of each cluster
            selected = cluster_docs[:docs_per_cluster]
            results.extend(selected)
        
        # Trim and assign final ranks
        results = results[:top_k]
        for i, doc in enumerate(results, 1):
            doc['rank'] = i
        
        return results
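
A usage sketch (the scores and documents are illustrative; model is any encoder with an encode() method, such as the SentenceTransformer from the MMR example):

# Usage sketch (illustrative data; `model` as in the MMR example above)
docs = [
    {'content': 'Reranking improves retrieval precision.', 'score': 0.91},
    {'content': 'Cross-encoders score query-document pairs.', 'score': 0.88},
    {'content': 'Context compression reduces token usage.', 'score': 0.80},
    {'content': 'Prompt formatting guides the LLM.', 'score': 0.74},
]
reranker = ClusterDiversityReranker(model)
top = reranker.rerank("improve RAG quality", docs, top_k=3, n_clusters=2)

Because each cluster contributes its own top documents, the result spans the topical groups instead of stacking near-duplicates from one.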

4. Context Post-Processing

4.1 Context Compression

# context_compressor.py
from typing import List, Dict

class ContextCompressor:
    """Context compressor"""
    
    def __init__(
        self,
        llm=None,
        max_tokens: int = 2000
    ):
        """
        Args:
            llm: LLM used for smart compression (optional)
            max_tokens: token budget
        """
        self.llm = llm
        self.max_tokens = max_tokens
    
    def compress(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """
        Compress the context.
        
        Args:
            documents: documents
            query: the query
        
        Returns:
            compressed context
        """
        # Estimate the current token count
        current_tokens = self._count_tokens(documents)
        
        if current_tokens <= self.max_tokens:
            return self._concatenate_documents(documents)
        
        # Compression is needed
        if self.llm:
            return self._llm_compress(documents, query)
        else:
            return self._heuristic_compress(documents)
    
    def _llm_compress(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """LLM-based compression"""
        context = self._concatenate_documents(documents)
        
        prompt = f"""
Compress the following context, keeping the information most relevant to the question:

Question: {query}

Context:
{context}

Extract only the core information relevant to the question, within 2000 characters.
"""
        return self.llm.generate(prompt)
    
    def _heuristic_compress(
        self,
        documents: List[Dict]
    ) -> str:
        """Heuristic compression"""
        # Sort by score
        sorted_docs = sorted(
            documents,
            key=lambda x: x.get('score', 0),
            reverse=True
        )
        
        # Add documents until the budget is reached
        result = []
        current_tokens = 0
        
        for doc in sorted_docs:
            doc_tokens = self._count_tokens([doc])
            if current_tokens + doc_tokens > self.max_tokens:
                break
            
            result.append(doc['content'])
            current_tokens += doc_tokens
        
        return '\n\n'.join(result)
    
    def _count_tokens(self, documents: List[Dict]) -> int:
        """Estimate the token count"""
        text = self._concatenate_documents(documents)
        # Rough estimate: 1 CJK character ≈ 1.5 tokens
        return int(len(text) * 1.5)
    
    def _concatenate_documents(self, documents: List[Dict]) -> str:
        """Concatenate documents"""
        return '\n\n---\n\n'.join([
            f"[Document {i+1}]\n{doc['content']}"
            for i, doc in enumerate(documents)
        ])
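
The character-based estimate in _count_tokens is deliberately crude. If the generator is an OpenAI-family model, a more accurate count is available via the tiktoken library (an optional extra dependency; the encoding name below is an assumption that should match your actual model):

# token_count.py (optional; assumes an OpenAI-family tokenizer)
import tiktoken

_ENCODER = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    """Exact token count under the chosen encoding"""
    return len(_ENCODER.encode(text))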

4.2 Relevant Information Extraction

# relevant_extraction.py
from typing import List, Dict

class RelevantExtractor:
    """Relevant information extractor"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def extract(
        self,
        documents: List[Dict],
        query: str,
        max_length: int = 1000
    ) -> str:
        """
        Extract relevant information.
        
        Args:
            documents: documents
            query: the query
            max_length: maximum length of the extraction
        
        Returns:
            extracted relevant information
        """
        context = '\n\n'.join([doc['content'] for doc in documents])
        
        prompt = f"""
Extract the information most relevant to the question from the context below:

Question: {query}

Context:
{context}

Extract only information directly related to the question and ignore everything else.
Keep the extraction under {max_length} characters.
"""
        return self.llm.generate(prompt)

4.3 Context Formatting

# context_formatter.py
from typing import List, Dict

class ContextFormatter:
    """Context formatter"""
    
    def __init__(self, template: str = None):
        """
        Args:
            template: formatting template
        """
        self.template = template or self._default_template()
    
    def _default_template(self) -> str:
        """Default template"""
        return """
[Relevant Documents]
{documents}

[Answer Requirements]
Answer the question based on the documents above, citing the relevant document content.
"""
    
    def format(
        self,
        documents: List[Dict],
        query: str
    ) -> str:
        """
        Format the context.
        
        Args:
            documents: documents
            query: the query
        
        Returns:
            formatted context
        """
        docs_text = self._format_documents(documents)
        return self.template.format(documents=docs_text)
    
    def _format_documents(self, documents: List[Dict]) -> str:
        """Format the documents"""
        formatted = []
        
        for i, doc in enumerate(documents, 1):
            doc_text = f"""
[Document {i}]
{doc['content']}
(Relevance: {doc.get('score', 'N/A')})
"""
            formatted.append(doc_text)
        
        return '\n'.join(formatted)

5. Evaluating Reranking

5.1 Evaluation Metrics

# reranking_metrics.py
from typing import List, Dict
import numpy as np

class RerankingMetrics:
    """Reranking evaluation metrics"""
    
    @staticmethod
    def calculate_ndcg(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """
        NDCG@K (binary relevance).
        
        Args:
            ranked_docs: ranked documents
            ground_truth: IDs of the relevant documents
            k: cutoff position
        """
        dcg = 0.0
        idcg = 0.0
        
        # DCG
        for i, doc in enumerate(ranked_docs[:k]):
            if doc['id'] in ground_truth:
                dcg += 1.0 / np.log2(i + 2)
        
        # IDCG
        ideal_count = min(len(ground_truth), k)
        for i in range(ideal_count):
            idcg += 1.0 / np.log2(i + 2)
        
        return dcg / idcg if idcg > 0 else 0.0
    
    @staticmethod
    def calculate_precision_at_k(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """Precision@K"""
        top_k = ranked_docs[:k]
        relevant_count = sum(
            1 for doc in top_k 
            if doc['id'] in ground_truth
        )
        return relevant_count / k
    
    @staticmethod
    def calculate_recall_at_k(
        ranked_docs: List[Dict],
        ground_truth: List[str],
        k: int = 10
    ) -> float:
        """Recall@K"""
        if not ground_truth:
            return 0.0
        top_k = ranked_docs[:k]
        relevant_in_top_k = sum(
            1 for doc in top_k 
            if doc['id'] in ground_truth
        )
        return relevant_in_top_k / len(ground_truth)
    
    @staticmethod
    def calculate_mrr(
        ranked_docs: List[Dict],
        ground_truth: List[str]
    ) -> float:
        """MRR (reciprocal rank of the first relevant document)"""
        for i, doc in enumerate(ranked_docs, 1):
            if doc['id'] in ground_truth:
                return 1.0 / i
        return 0.0
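
A small worked example (the IDs and relevance labels are made up for illustration):

# Worked example with made-up data
ranked = [{'id': 'd3'}, {'id': 'd1'}, {'id': 'd7'}, {'id': 'd2'}]
truth = ['d1', 'd2']

print(RerankingMetrics.calculate_ndcg(ranked, truth, k=4))           # ~0.65
print(RerankingMetrics.calculate_precision_at_k(ranked, truth, k=4)) # 0.5
print(RerankingMetrics.calculate_recall_at_k(ranked, truth, k=4))    # 1.0
print(RerankingMetrics.calculate_mrr(ranked, truth))                 # 0.5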

5.2 An A/B Testing Framework

# rerank_ab_test.py
from typing import List, Dict
import random

import numpy as np

from reranking_metrics import RerankingMetrics

class RerankABTester:
    """Reranking A/B tester"""
    
    def __init__(self):
        self.results = {'A': [], 'B': []}
    
    def run_test(
        self,
        queries: List[str],
        documents_list: List[List[Dict]],
        ground_truth_list: List[List[str]],
        reranker_a,
        reranker_b,
        n_tests: int = 100
    ) -> Dict:
        """
        Run an A/B test.
        
        Args:
            queries: queries
            documents_list: candidate documents per query
            ground_truth_list: relevant document IDs per query
            reranker_a: reranker A
            reranker_b: reranker B
            n_tests: number of trials
        """
        metrics_a = []
        metrics_b = []
        
        for _ in range(n_tests):
            # Sample a query at random
            idx = random.randint(0, len(queries) - 1)
            query = queries[idx]
            docs = documents_list[idx]
            ground_truth = ground_truth_list[idx]
            
            # Variant A
            ranked_a = reranker_a.rerank(query, docs, top_k=10)
            ndcg_a = RerankingMetrics.calculate_ndcg(
                ranked_a, ground_truth
            )
            metrics_a.append(ndcg_a)
            
            # Variant B
            ranked_b = reranker_b.rerank(query, docs, top_k=10)
            ndcg_b = RerankingMetrics.calculate_ndcg(
                ranked_b, ground_truth
            )
            metrics_b.append(ndcg_b)
        
        # Keep the raw per-query metrics for later inspection
        self.results['A'] = metrics_a
        self.results['B'] = metrics_b
        
        # Aggregate
        return {
            'A': {
                'mean_ndcg': np.mean(metrics_a),
                'std_ndcg': np.std(metrics_a),
                'samples': len(metrics_a)
            },
            'B': {
                'mean_ndcg': np.mean(metrics_b),
                'std_ndcg': np.std(metrics_b),
                'samples': len(metrics_b)
            },
            'improvement': (np.mean(metrics_b) - np.mean(metrics_a)) / np.mean(metrics_a)
        }
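
Since both variants are evaluated on the same sampled queries, a paired significance test is worth running before trusting the improvement number. A sketch using SciPy (an extra dependency not used elsewhere in this post):

# significance_check.py (paired t-test on the per-query NDCG lists)
from scipy import stats

def is_significant(metrics_a, metrics_b, alpha: float = 0.05) -> dict:
    t_stat, p_value = stats.ttest_rel(metrics_a, metrics_b)
    return {
        't': float(t_stat),
        'p': float(p_value),
        'significant': p_value < alpha
    }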

6. Hands-On Example

6.1 A Complete RAG Reranking Pipeline

# complete_rag_rerank.py
from typing import List, Dict

class CompleteRAGRerank:
    """Complete RAG reranking pipeline"""
    
    def __init__(self, config: Dict):
        """
        Args:
            config: configuration dictionary
        """
        # Retriever
        self.retriever = config['retriever']
        
        # Rerankers
        self.rerankers = config['rerankers']
        
        # Post-processors
        self.compressor = config.get('compressor')
        self.formatter = config.get('formatter')
    
    def search_and_rerank(
        self,
        query: str,
        top_k_initial: int = 50,
        top_k_final: int = 5
    ) -> Dict:
        """
        Retrieve, rerank, and post-process.
        
        Args:
            query: the query
            top_k_initial: initial retrieval size
            top_k_final: final result count
        
        Returns:
            pipeline result
        """
        # 1. Vector retrieval
        retrieved_docs = self.retriever.search(
            query, 
            top_k=top_k_initial
        )
        
        # 2. Multi-stage reranking (the candidates are dicts, so use
        #    rerank_with_metadata and normalize keys between stages)
        current_docs = retrieved_docs
        for reranker in self.rerankers:
            ranked = reranker.rerank_with_metadata(
                query,
                current_docs,
                top_k=top_k_final
            )
            current_docs = [
                {
                    'content': r['document'],
                    'score': r['score'],
                    'metadata': r.get('metadata', {})
                }
                for r in ranked
            ]
        
        # 3. Context compression
        if self.compressor:
            context = self.compressor.compress(
                current_docs,
                query
            )
        else:
            context = '\n\n'.join([
                doc['content'] for doc in current_docs
            ])
        
        # 4. Formatting
        if self.formatter:
            formatted_context = self.formatter.format(
                current_docs,
                query
            )
        else:
            formatted_context = context
        
        return {
            'query': query,
            'retrieved_count': len(retrieved_docs),
            'reranked_docs': current_docs,
            'context': formatted_context
        }

# Usage example (vector_retriever and llm are assumed to be built elsewhere)
config = {
    'retriever': vector_retriever,
    'rerankers': [
        CrossEncoderReranker('cross-encoder/ms-marco-TinyBERT-L-2-v2'),
        CrossEncoderReranker('cross-encoder/ms-marco-MiniLM-L-6-v2')
    ],
    'compressor': ContextCompressor(llm, max_tokens=2000),
    'formatter': ContextFormatter()
}

rag_system = CompleteRAGRerank(config)
result = rag_system.search_and_rerank(
    "How can I improve retrieval quality in a RAG system?",
    top_k_initial=50,
    top_k_final=5
)

7. Summary

7.1 Key Takeaways

  1. Choosing a reranking method

    • Speed-critical: Reciprocal Rank Fusion
    • Balanced: Cross-Encoder
    • Highest quality: multi-stage reranking
  2. Handling diversity

    • MMR: balances relevance and diversity
    • Clustering: covers distinct topics
  3. Post-processing

    • Compression: reduces token consumption
    • Extraction: focuses on relevant information
    • Formatting: improves generation quality

7.2 Performance Benchmarks

| Method | NDCG@10 | Latency | Cost |
|---|---|---|---|
| No reranking | 0.65 | fastest | lowest |
| Cross-Encoder | 0.78 | | |
| Multi-stage reranking | 0.82 | | |
| LLM reranking | 0.85 | slowest | highest |
