A Practical Guide to RAG Performance Optimization

Performance is a make-or-break factor for taking a RAG system to production. How do you cut response latency? How do you raise throughput? How do you rein in resource consumption? This article walks through a complete, hands-on approach to RAG performance optimization.

1. Performance Bottleneck Analysis

1.1 Breaking Down the RAG Pipeline

graph LR
    A[User query] --> B[Query processing]
    B --> C[Vector retrieval]
    C --> D[Reranking]
    D --> E[Context construction]
    E --> F[LLM generation]
    F --> G[Response returned]

    style B fill:#ff9999
    style C fill:#99ccff
    style D fill:#99ff99
    style E fill:#ffff99
    style F fill:#ff99ff

1.2 Typical Latency Distribution

| Stage | Latency (P95) | Share | Optimization headroom |
|---|---|---|---|
| Query processing | 10-50 ms | 2-5% | |
| Vector retrieval | 50-200 ms | 10-20% | |
| Reranking | 100-500 ms | 20-40% | |
| Context construction | 20-100 ms | 5-10% | |
| LLM generation | 500-3000 ms | 40-70% | Highest |

1.3 Profiling Tools

# performance_profiler.py
import time
from contextlib import contextmanager
from typing import Dict, List

import numpy as np

class RAGProfiler:
    """RAG performance profiler."""

    def __init__(self):
        self.timings: Dict[str, List[float]] = {}

    @contextmanager
    def profile(self, stage: str):
        """Context manager that times one pipeline stage."""
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            if stage not in self.timings:
                self.timings[stage] = []
            self.timings[stage].append(elapsed)

    def get_statistics(self) -> Dict:
        """Compute per-stage latency statistics."""
        stats = {}
        for stage, times in self.timings.items():
            stats[stage] = {
                'count': len(times),
                'mean': np.mean(times),
                'p50': np.percentile(times, 50),
                'p95': np.percentile(times, 95),
                'p99': np.percentile(times, 99),
                'min': np.min(times),
                'max': np.max(times)
            }

        return stats

    def print_report(self):
        """Print a latency report (times in milliseconds)."""
        stats = self.get_statistics()

        print("\n=== RAG performance breakdown ===\n")
        print(f"{'Stage':<15} {'P50':<10} {'P95':<10} {'P99':<10} {'Share':<10}")
        print("-" * 65)

        total_p95 = sum(s['p95'] for s in stats.values())

        for stage, s in stats.items():
            percentage = (s['p95'] / total_p95 * 100) if total_p95 > 0 else 0
            print(f"{stage:<15} {s['p50']*1000:<10.1f} {s['p95']*1000:<10.1f} "
                  f"{s['p99']*1000:<10.1f} {percentage:<10.1f}%")

# Usage example
profiler = RAGProfiler()

with profiler.profile('retrieval'):
    pass  # retrieval code goes here

with profiler.profile('reranking'):
    pass  # reranking code goes here

with profiler.profile('generation'):
    pass  # generation code goes here

profiler.print_report()

2. Retrieval Optimization

2.1 Index Optimization

# index_optimization.py
import faiss
import numpy as np

class IndexOptimizer:
    """Index optimizer."""

    @staticmethod
    def optimize_for_recall(
        vectors: np.ndarray,
        target_recall: float = 0.95
    ) -> faiss.Index:
        """
        Tune an index for recall.

        Args:
            vectors: corpus vectors, shape (n, d), float32
            target_recall: target recall@10
        """
        dimension = vectors.shape[1]

        # IVF index; a common heuristic is nlist ~ 4 * sqrt(n)
        nlist = int(4 * np.sqrt(len(vectors)))
        quantizer = faiss.IndexFlatIP(dimension)
        index = faiss.IndexIVFFlat(quantizer, dimension, nlist)

        # Train the coarse quantizer, then add the vectors
        index.train(vectors)
        index.add(vectors)

        # Grow nprobe until the estimated recall hits the target
        nprobe = 1
        while nprobe < nlist:
            recall = IndexOptimizer._estimate_recall(index, vectors, nprobe)
            if recall >= target_recall:
                break
            nprobe *= 2  # double instead of incrementing to keep tuning cheap

        index.nprobe = min(nprobe, nlist)

        return index

    @staticmethod
    def optimize_for_latency(
        vectors: np.ndarray,
        target_latency_ms: float = 50
    ) -> faiss.Index:
        """
        Tune an index for latency.

        Args:
            vectors: corpus vectors
            target_latency_ms: latency budget in milliseconds
        """
        dimension = vectors.shape[1]

        # HNSW trades memory for much faster queries
        M = 32
        index = faiss.IndexHNSWFlat(dimension, M)
        index.hnsw.efSearch = 64  # lower efSearch = faster but less accurate search

        index.add(vectors)

        return index

    @staticmethod
    def _estimate_recall(
        index: faiss.Index,
        vectors: np.ndarray,
        nprobe: int,
        k: int = 10
    ) -> float:
        """Estimate recall@k against exact search on a sample of queries."""
        index.nprobe = nprobe

        # Exact baseline: a brute-force index over the same vectors
        exact_index = faiss.IndexFlatIP(index.d)
        exact_index.add(vectors)

        # Sample queries from the corpus itself
        sample_size = min(1000, len(vectors))
        sample_indices = np.random.choice(len(vectors), sample_size, replace=False)

        recall_scores = []
        for idx in sample_indices:
            query = vectors[idx:idx+1]

            # Exact vs. approximate neighbors
            _, exact_I = exact_index.search(query, k)
            _, approx_I = index.search(query, k)

            # Fraction of exact neighbors recovered by the approximate search
            intersection = len(set(exact_I[0]) & set(approx_I[0]))
            recall_scores.append(intersection / k)

        return float(np.mean(recall_scores))
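
A minimal usage sketch: the random vectors below are placeholders for real embeddings, and normalizing them makes inner product behave like cosine similarity.

# Usage sketch (random vectors stand in for real embeddings)
import faiss
import numpy as np

vectors = np.random.rand(50000, 384).astype('float32')
faiss.normalize_L2(vectors)  # inner product == cosine similarity after normalization

recall_index = IndexOptimizer.optimize_for_recall(vectors, target_recall=0.95)
latency_index = IndexOptimizer.optimize_for_latency(vectors)

query = vectors[:1]
_, recall_ids = recall_index.search(query, 10)    # high-recall path
_, latency_ids = latency_index.search(query, 10)  # low-latency path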

2.2 Caching Strategy

# retrieval_cache.py
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import numpy as np

class RetrievalCache:
    """Retrieval cache with TTL expiry and LRU eviction."""

    def __init__(
        self,
        max_size: int = 10000,
        ttl_seconds: int = 3600
    ):
        """
        Args:
            max_size: maximum number of cached entries
            ttl_seconds: time-to-live in seconds
        """
        self.max_size = max_size
        self.ttl = timedelta(seconds=ttl_seconds)
        self.cache: Dict[str, List[Dict]] = {}
        self.access_times: Dict[str, datetime] = {}

    def _compute_key(self, query: str, query_vector: np.ndarray) -> str:
        """Compute the cache key."""
        # Key on the query text (more stable than hashing float vectors)
        return hashlib.md5(query.encode()).hexdigest()

    def get(
        self,
        query: str,
        query_vector: np.ndarray
    ) -> Optional[List[Dict]]:
        """Look up cached results, honoring the TTL."""
        key = self._compute_key(query, query_vector)

        if key in self.cache:
            # Expire stale entries
            if datetime.now() - self.access_times[key] > self.ttl:
                del self.cache[key]
                del self.access_times[key]
                return None

            # Refresh the access time (LRU)
            self.access_times[key] = datetime.now()
            return self.cache[key]

        return None

    def set(
        self,
        query: str,
        query_vector: np.ndarray,
        results: List[Dict]
    ):
        """Store results in the cache."""
        key = self._compute_key(query, query_vector)

        # LRU eviction when full
        if len(self.cache) >= self.max_size:
            oldest_key = min(
                self.access_times,
                key=lambda k: self.access_times[k]
            )
            del self.cache[oldest_key]
            del self.access_times[oldest_key]

        self.cache[key] = results
        self.access_times[key] = datetime.now()

    def clear(self):
        """Empty the cache."""
        self.cache.clear()
        self.access_times.clear()

    def get_stats(self) -> Dict:
        """Cache statistics."""
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'utilization': len(self.cache) / self.max_size
        }

# Usage example
cache = RetrievalCache(max_size=10000, ttl_seconds=3600)

def search_with_cache(query: str, query_vector: np.ndarray):
    # Try the cache first ("is not None" so cached empty results still count as hits)
    cached_results = cache.get(query, query_vector)
    if cached_results is not None:
        return cached_results

    # Cache miss: run the actual retrieval (vector_retriever is your retriever instance)
    results = vector_retriever.search(query_vector, top_k=10)

    # Cache the results
    cache.set(query, query_vector, results)

    return results

2.3 Batch Retrieval

# batch_retrieval.py
from typing import Dict, List

import numpy as np

class BatchRetriever:
    """Batch retriever."""

    def __init__(self, retriever, batch_size: int = 32):
        self.retriever = retriever
        self.batch_size = batch_size

    def batch_search(
        self,
        queries: List[str],
        query_vectors: np.ndarray,
        top_k: int = 10
    ) -> List[List[Dict]]:
        """
        Search many queries at once.

        Args:
            queries: list of query strings
            query_vectors: query vectors, shape (n, d)
            top_k: results per query
        """
        all_results = []

        # Process in batches
        for i in range(0, len(queries), self.batch_size):
            batch_queries = queries[i:i + self.batch_size]
            batch_vectors = query_vectors[i:i + self.batch_size]

            batch_results = self._process_batch(
                batch_queries,
                batch_vectors,
                top_k
            )
            all_results.extend(batch_results)

        return all_results

    def _process_batch(
        self,
        queries: List[str],
        vectors: np.ndarray,
        top_k: int
    ) -> List[List[Dict]]:
        """Process one batch."""
        # FAISS searches a whole batch in one call, amortizing per-query overhead
        distances, indices = self.retriever.index.search(vectors, top_k)

        results = []
        for i in range(len(queries)):
            batch_results = []
            for j in range(top_k):
                batch_results.append({
                    'id': indices[i][j],
                    'score': distances[i][j],
                    'content': self.retriever.get_content(indices[i][j])
                })
            results.append(batch_results)

        return results
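
A usage sketch, assuming `retriever` exposes the FAISS `index` and `get_content` lookup used above, and `embedding_model` is a SentenceTransformer-style encoder (both are stand-ins here):

# Usage sketch
import numpy as np

queries = ["what is RAG?", "how do I tune nprobe?"]
query_vectors = np.asarray(embedding_model.encode(queries), dtype='float32')

batch_retriever = BatchRetriever(retriever, batch_size=32)
all_results = batch_retriever.batch_search(queries, query_vectors, top_k=10)

for query, hits in zip(queries, all_results):
    print(query, '->', hits[0]['id'], hits[0]['score'])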

3. Reranking Optimization

3.1 Model Distillation

# reranker_distillation.py
from typing import Dict, List

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class RerankerDistillation:
    """Distill a large cross-encoder reranker into a smaller one."""

    def __init__(
        self,
        teacher_model: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
        student_model: str = 'cross-encoder/ms-marco-TinyBERT-L-2-v2'
    ):
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model)
        self.teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model)

        self.student_tokenizer = AutoTokenizer.from_pretrained(student_model)
        self.student_model = AutoModelForSequenceClassification.from_pretrained(student_model)

    def distill(
        self,
        training_data: List[Dict],
        epochs: int = 3,
        batch_size: int = 32
    ):
        """
        Distillation training loop.

        Args:
            training_data: list of {'query': ..., 'document': ...} pairs
            epochs: number of epochs
            batch_size: batch size
        """
        self.student_model.train()
        self.teacher_model.eval()

        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=2e-5)

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0

            for i in range(0, len(training_data), batch_size):
                batch = training_data[i:i + batch_size]
                pairs = [(t['query'], t['document']) for t in batch]

                # Teacher scores (no gradients)
                teacher_inputs = self.teacher_tokenizer(
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )

                with torch.no_grad():
                    teacher_logits = self.teacher_model(**teacher_inputs).logits

                # Student scores
                student_inputs = self.student_tokenizer(
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )

                student_logits = self.student_model(**student_inputs).logits

                # These ms-marco cross-encoders emit a single relevance logit,
                # so distill with MSE on the scores; with a multi-class head
                # you would use KL divergence over the softmax instead.
                loss = torch.nn.functional.mse_loss(student_logits, teacher_logits)

                # Backpropagate
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / max(num_batches, 1):.4f}")

3.2 Lightweight Reranking

# lightweight_reranker.py
from typing import Dict, List

import numpy as np

class LightweightReranker:
    """Lightweight reranker."""

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5
    ) -> List[Dict]:
        """
        Lightweight reranking.

        Uses embedding cosine similarity instead of a Cross-Encoder.
        """
        if not documents:
            return []

        # Compute embeddings
        query_embedding = self.embedding_model.encode([query])[0]
        doc_embeddings = self.embedding_model.encode(
            [doc['content'] for doc in documents]
        )

        # Cosine similarity = dot product of L2-normalized vectors
        query_embedding = query_embedding / np.linalg.norm(query_embedding)
        doc_embeddings = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
        similarities = np.dot(doc_embeddings, query_embedding)

        # Sort by similarity, descending
        sorted_indices = np.argsort(similarities)[::-1]

        # Build the results
        results = []
        for rank, idx in enumerate(sorted_indices[:top_k], 1):
            results.append({
                **documents[idx],
                'rerank_score': float(similarities[idx]),
                'rank': rank
            })

        return results

    def rerank_hybrid(
        self,
        query: str,
        documents: List[Dict],
        original_scores: List[float],
        top_k: int = 5,
        alpha: float = 0.5
    ) -> List[Dict]:
        """
        Hybrid reranking (embedding score + original retrieval score).

        Args:
            alpha: weight on the embedding score
        """
        # Embedding-based reranking over a wider candidate set
        vector_results = self.rerank(query, documents, top_k=top_k * 2)

        # Score fusion
        score_min, score_max = min(original_scores), max(original_scores)
        score_range = (score_max - score_min) or 1.0  # guard against identical scores

        for doc in vector_results:
            # Min-max normalize the original retrieval score
            original_norm = (doc.get('score', 0) - score_min) / score_range

            # Blend the two scores
            doc['hybrid_score'] = (
                alpha * doc['rerank_score'] +
                (1 - alpha) * original_norm
            )

        # Sort by the blended score
        vector_results.sort(key=lambda x: x['hybrid_score'], reverse=True)

        return vector_results[:top_k]
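
A usage sketch, assuming a SentenceTransformer-style embedding model:

# Usage sketch
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
reranker = LightweightReranker(embedding_model)

documents = [
    {'content': 'FAISS supports IVF and HNSW indexes.', 'score': 0.7},
    {'content': 'Rerankers reorder retrieved candidates.', 'score': 0.6},
]
top = reranker.rerank('which index types does FAISS support?', documents, top_k=1)
print(top[0]['rerank_score'], top[0]['content'])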

4. Generation Optimization

4.1 Prompt Optimization

# prompt_optimization.py
from typing import Dict, List

class PromptOptimizer:
    """Prompt optimizer."""

    @staticmethod
    def optimize_context(
        documents: List[Dict],
        max_tokens: int = 2000
    ) -> str:
        """
        Trim the context to a token budget.

        Args:
            documents: document list
            max_tokens: token budget
        """
        # 1. Sort by relevance
        sorted_docs = sorted(
            documents,
            key=lambda x: x.get('score', 0),
            reverse=True
        )

        # 2. Greedily pick documents until the budget is exhausted
        selected_docs = []
        current_tokens = 0

        for doc in sorted_docs:
            doc_tokens = len(doc['content']) // 4  # rough estimate: ~4 chars per token
            if current_tokens + doc_tokens > max_tokens:
                break

            selected_docs.append(doc)
            current_tokens += doc_tokens

        # 3. Assemble the trimmed context
        context_parts = []
        for i, doc in enumerate(selected_docs, 1):
            context_parts.append(f"[Document {i}]\n{doc['content']}")

        return '\n\n'.join(context_parts)

    @staticmethod
    def optimize_instructions(query: str) -> str:
        """Build concise instructions."""
        return f"""Answer the question based on the documents above.
If the documents do not contain the relevant information, say "The provided documents cannot answer this question."
Keep the answer concise and accurate, and cite the relevant document content.

Question: {query}
"""

4.2 Streaming Output

# streaming_generation.py
import time
from typing import Dict, Generator

class StreamingGenerator:
    """Streaming generator."""

    def __init__(self, llm):
        self.llm = llm

    def generate_stream(
        self,
        prompt: str,
        max_tokens: int = 2000
    ) -> Generator[str, None, None]:
        """
        Stream the generation.

        Yields:
            generated text fragments
        """
        # Streaming interfaces differ between LLM providers; this generic
        # example handles OpenAI-style chunks and plain-text chunks.
        response = self.llm.generate(
            prompt,
            max_tokens=max_tokens,
            stream=True
        )

        for chunk in response:
            if hasattr(chunk, 'choices'):
                text = chunk.choices[0].delta.content
            else:
                text = chunk.text

            if text:
                yield text

    def generate_with_timing(
        self,
        prompt: str
    ) -> Dict:
        """Generate while measuring latency."""
        start_time = time.time()
        first_token_time = None
        full_response = []

        for chunk in self.generate_stream(prompt):
            if first_token_time is None:
                first_token_time = time.time()
            full_response.append(chunk)

        end_time = time.time()
        elapsed = end_time - start_time

        return {
            'response': ''.join(full_response),
            'time_to_first_token': (first_token_time or end_time) - start_time,
            'total_time': elapsed,
            # chunk count is a rough proxy for token count
            'tokens_per_second': len(full_response) / elapsed if elapsed > 0 else 0
        }
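
A usage sketch; `llm` is a hypothetical client whose `generate(..., stream=True)` yields chunks in one of the two shapes handled above:

# Usage sketch
generator = StreamingGenerator(llm)

# Render tokens as they arrive to cut perceived latency
for chunk in generator.generate_stream("Summarize the retrieved context."):
    print(chunk, end='', flush=True)

# Or measure time-to-first-token explicitly
timing = generator.generate_with_timing("Summarize the retrieved context.")
print(f"\nTTFT: {timing['time_to_first_token']:.2f}s, total: {timing['total_time']:.2f}s")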

4.3 Model Selection

# model_selection.py
from typing import Dict

class ModelSelector:
    """Model selector."""

    MODELS = {
        'fast': {
            'name': 'gpt-3.5-turbo',
            'latency': 'low',
            'cost': 'low',
            'quality': 'medium'
        },
        'balanced': {
            'name': 'gpt-4-turbo',
            'latency': 'medium',
            'cost': 'medium',
            'quality': 'high'
        },
        'quality': {
            'name': 'gpt-4',
            'latency': 'high',
            'cost': 'high',
            'quality': 'highest'
        }
    }

    @classmethod
    def select_model(
        cls,
        query_complexity: str,
        latency_requirement: str,
        budget_constraint: str
    ) -> str:
        """
        Pick a model.

        Args:
            query_complexity: query complexity (simple/medium/complex)
            latency_requirement: latency the caller can tolerate (low/medium/high)
            budget_constraint: budget constraint (low/medium/high)
        """
        # Tight budget -> cheapest model regardless of complexity
        if budget_constraint == 'low':
            return cls.MODELS['fast']['name']

        # Simple query that must return fast -> fast model
        if query_complexity == 'simple' and latency_requirement == 'low':
            return cls.MODELS['fast']['name']

        # Complex query where high latency is acceptable -> quality model
        if query_complexity == 'complex' and latency_requirement == 'high':
            return cls.MODELS['quality']['name']

        # Otherwise, balance
        return cls.MODELS['balanced']['name']
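
A usage sketch: classify the request, then route it before calling the LLM:

# Usage sketch
model_name = ModelSelector.select_model(
    query_complexity='simple',
    latency_requirement='low',
    budget_constraint='medium'
)
print(model_name)  # -> gpt-3.5-turbo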

5. System-Level Optimization

5.1 Parallel Processing

# parallel_processing.py
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List

class ParallelProcessor:
    """Parallel processor."""

    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers

    def parallel_rerank(
        self,
        queries: List[str],
        documents_list: List[List[Dict]],
        reranker
    ) -> List[List[Dict]]:
        """Rerank multiple queries concurrently."""
        results = []

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(reranker.rerank, query, docs): i
                for i, (query, docs) in enumerate(zip(queries, documents_list))
            }

            for future in as_completed(futures):
                idx = futures[future]
                results.append((idx, future.result()))

        # Restore the original order
        results.sort(key=lambda x: x[0])
        return [r[1] for r in results]

    def parallel_embed(
        self,
        texts: List[str],
        embedding_model
    ) -> List:
        """Embed texts in parallel batches."""
        batch_size = self.max_workers * 4
        batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

        all_embeddings = []
        # executor.map runs batches concurrently and preserves their order
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for embeddings in executor.map(embedding_model.encode, batches):
                all_embeddings.extend(embeddings)

        return all_embeddings
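
A usage sketch; `reranker` can be the LightweightReranker from section 3.2, and the document lists are placeholders:

# Usage sketch
processor = ParallelProcessor(max_workers=4)

queries = ['what is faiss?', 'what is reranking?']
documents_list = [
    [{'content': 'FAISS is a vector search library.', 'score': 0.7}],
    [{'content': 'Reranking reorders retrieved candidates.', 'score': 0.6}],
]
reranked = processor.parallel_rerank(queries, documents_list, reranker)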

5.2 Resource Optimization

# resource_optimization.py
from typing import Dict

import psutil
import torch

class ResourceOptimizer:
    """Resource optimizer."""

    @staticmethod
    def get_system_resources() -> Dict:
        """System resource usage."""
        return {
            'cpu_percent': psutil.cpu_percent(),
            'memory_percent': psutil.virtual_memory().percent,
            'memory_available': psutil.virtual_memory().available,  # bytes
            'disk_percent': psutil.disk_usage('/').percent
        }

    @staticmethod
    def get_gpu_resources() -> Dict:
        """GPU resource usage (memory figures in MB)."""
        if not torch.cuda.is_available():
            return {'available': False}

        return {
            'available': True,
            'gpu_count': torch.cuda.device_count(),
            'gpu_memory_used': torch.cuda.memory_allocated() / 1024**2,
            'gpu_memory_total': torch.cuda.get_device_properties(0).total_memory / 1024**2
        }

    @staticmethod
    def optimize_batch_size(
        available_memory_mb: int,
        model_memory_per_batch_mb: int
    ) -> int:
        """Derive a safe batch size from available memory."""
        # Keep a 20% memory headroom
        safe_memory = available_memory_mb * 0.8
        batch_size = int(safe_memory / model_memory_per_batch_mb)
        return max(1, batch_size)

    @staticmethod
    def enable_memory_efficient_mode():
        """Enable a memory-conservative mode."""
        if torch.cuda.is_available():
            # Release cached GPU memory
            torch.cuda.empty_cache()
            # Cap this process at 80% of GPU memory
            torch.cuda.set_per_process_memory_fraction(0.8)
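
A usage sketch tying the pieces together: measure free GPU memory, then derive the batch size. The per-batch memory figure is an assumption you would profile for your own model:

# Usage sketch
gpu = ResourceOptimizer.get_gpu_resources()
if gpu['available']:
    free_mb = gpu['gpu_memory_total'] - gpu['gpu_memory_used']
    batch_size = ResourceOptimizer.optimize_batch_size(
        available_memory_mb=int(free_mb),
        model_memory_per_batch_mb=512  # assumption: measure this for your model
    )
    print(f"batch_size={batch_size}")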

6. Performance Monitoring

6.1 Metrics Collection

# performance_monitoring.py
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import numpy as np

class PerformanceMonitor:
    """Performance monitor."""

    def __init__(self):
        self.metrics_history: List[Dict] = []

    def record(
        self,
        latency_p50: float,
        latency_p95: float,
        latency_p99: float,
        qps: float,
        error_rate: float,
        metadata: Optional[Dict] = None
    ):
        """Record one metrics sample."""
        self.metrics_history.append({
            'timestamp': datetime.now().isoformat(),
            'latency_p50': latency_p50,
            'latency_p95': latency_p95,
            'latency_p99': latency_p99,
            'qps': qps,
            'error_rate': error_rate,
            'metadata': metadata or {}
        })

    def get_summary(self, hours: int = 1) -> Dict:
        """Summarize metrics over a recent window."""
        cutoff = datetime.now() - timedelta(hours=hours)
        recent = [
            m for m in self.metrics_history
            if datetime.fromisoformat(m['timestamp']) >= cutoff
        ]

        if not recent:
            return {}

        return {
            'avg_latency_p95': np.mean([m['latency_p95'] for m in recent]),
            'avg_qps': np.mean([m['qps'] for m in recent]),
            'avg_error_rate': np.mean([m['error_rate'] for m in recent]),
            'min_latency_p95': np.min([m['latency_p95'] for m in recent]),
            'max_latency_p95': np.max([m['latency_p95'] for m in recent])
        }
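
A usage sketch, recording one sample per evaluation window:

# Usage sketch
monitor = PerformanceMonitor()
monitor.record(
    latency_p50=0.4, latency_p95=0.8, latency_p99=1.2,  # seconds
    qps=50, error_rate=0.01,
    metadata={'deployment': 'v2'}
)
print(monitor.get_summary(hours=1))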

7. Summary

7.1 Optimization Checklist

Retrieval: pick the index for your recall/latency trade-off (IVF nprobe vs. HNSW efSearch), cache hot queries with TTL and LRU eviction, and batch searches.

Reranking: distill large cross-encoders into smaller students, or fall back to lightweight embedding-based reranking when latency is tight.

Generation: trim the context to a token budget, stream output to reduce time-to-first-token, and route each query to the cheapest model that can handle it.

System: parallelize independent stages, size batches to available memory, and monitor P95/P99 latency, QPS, and error rate continuously.

7.2 Performance Benchmarks

| Metric | Before | After | Improvement |
|---|---|---|---|
| P95 latency | 3000 ms | 800 ms | 73% |
| QPS | 10 | 50 | 400% |
| Cost per request | $0.05 | $0.02 | 60% |

