RAG 向量检索工程实践
向量检索是 RAG 系统的核心组件,直接影响检索质量和系统性能。如何构建高效的向量索引?如何优化检索性能?本文将深入解析向量检索的工程实践。
一、向量检索基础
1.1 检索流程
graph LR
A[用户查询] --> B[Query 嵌入]
B --> C[向量检索]
C --> D[相似度计算]
D --> E[Top-K 结果]
E --> F[重排序]
F --> G[返回结果]
1.2 核心挑战
向量检索挑战:
┌─────────────────────────────────────┐
│ 1. 性能要求 │
│ - 毫秒级响应 │
│ - 高并发支持 │
├─────────────────────────────────────┤
│ 2. 规模挑战 │
│ - 亿级向量规模 │
│ - 内存和存储限制 │
├─────────────────────────────────────┤
│ 3. 精度要求 │
│ - 高召回率 │
│ - 高准确率 │
├─────────────────────────────────────┤
│ 4. 实时更新 │
│ - 动态增删向量 │
│ - 索引重建成本 │
└─────────────────────────────────────┘
1.3 主流向量数据库
| 数据库 | 特点 | 适用场景 |
|---|---|---|
| Pinecone | 托管服务、易用 | 快速上线、小规模 |
| Milvus | 开源、功能全 | 大规模、定制化 |
| Weaviate | 图数据库、混合检索 | 知识图谱 |
| Qdrant | 高性能、过滤强 | 生产环境 |
| Chroma | 轻量、易集成 | 开发测试 |
| FAISS | Facebook、高性能 | 科研、定制 |
二、Embedding 模型选择
2.1 主流 Embedding 模型
# embedding_models.py
from typing import List
import numpy as np
class EmbeddingModelManager:
    """Registry of embedding models plus a simple recommendation helper."""

    MODELS = {
        # Chinese models
        'bge-large-zh': {
            'name': 'BAAI/bge-large-zh-v1.5',
            'dimension': 1024,
            'max_length': 512,
            'language': 'zh',
            'performance': 'high'
        },
        'bge-base-zh': {
            'name': 'BAAI/bge-base-zh-v1.5',
            'dimension': 768,
            'max_length': 512,
            'language': 'zh',
            'performance': 'medium'
        },
        # English models
        'text-embedding-ada-002': {
            'name': 'OpenAI Ada V2',
            'dimension': 1536,
            'max_length': 8191,
            'language': 'en',
            'performance': 'high'
        },
        'all-MiniLM-L6-v2': {
            'name': 'sentence-transformers/all-MiniLM-L6-v2',
            'dimension': 384,
            'max_length': 256,
            'language': 'en',
            'performance': 'fast'
        },
        # Multilingual models
        'multilingual-e5-large': {
            'name': 'intfloat/multilingual-e5-large',
            'dimension': 1024,
            'max_length': 512,
            'language': 'multi',
            'performance': 'high'
        }
    }

    @classmethod
    def get_model_info(cls, model_name: str) -> dict:
        """Return the registered metadata for *model_name* ({} if unknown)."""
        return cls.MODELS.get(model_name, {})

    @classmethod
    def recommend_model(
        cls,
        language: str = 'zh',
        performance_requirement: str = 'balanced'
    ) -> str:
        """Recommend a model key for a (language, performance) combination.

        Any combination not in the table below — including the default
        'balanced' requirement — falls back to the multilingual model.
        """
        preferred = {
            ('zh', 'high'): 'bge-large-zh',
            ('zh', 'fast'): 'bge-base-zh',
            ('en', 'high'): 'text-embedding-ada-002',
            ('en', 'fast'): 'all-MiniLM-L6-v2',
        }
        return preferred.get((language, performance_requirement), 'multilingual-e5-large')
2.2 本地 Embedding 服务
# embedding_service.py
from typing import List, Union
from sentence_transformers import SentenceTransformer
import numpy as np
class LocalEmbeddingService:
    """Local embedding service backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'BAAI/bge-large-zh-v1.5'):
        """
        Initialize the service.

        Args:
            model_name: HuggingFace model id to load.
        """
        self.model = SentenceTransformer(model_name)
        self.dimension = self.model.get_sentence_embedding_dimension()

    def embed(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        show_progress: bool = False
    ) -> np.ndarray:
        """
        Encode text(s) into L2-normalized embeddings.

        Args:
            texts: a single string or a list of strings
            batch_size: encoder batch size
            show_progress: whether to show a progress bar

        Returns:
            Array of shape (n_texts, dimension).
        """
        if isinstance(texts, str):
            texts = [texts]
        # normalize_embeddings=True makes the dot product equal cosine similarity.
        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            normalize_embeddings=True
        )
        return embeddings

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a query with the BGE-recommended retrieval instruction prefix."""
        query_with_prefix = "为这个句子生成表示以用于检索:" + query
        return self.embed(query_with_prefix)[0]

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = 32
    ) -> List[List[float]]:
        """Embed documents, returning plain Python lists.

        Fix: the original annotated the return as List[np.ndarray], but
        ndarray.tolist() actually yields nested lists of floats.
        """
        embeddings = self.embed(documents, batch_size)
        return embeddings.tolist()

    def similarity(
        self,
        embedding1: np.ndarray,
        embedding2: np.ndarray
    ) -> float:
        """Cosine similarity; valid because embed() normalizes its output."""
        return np.dot(embedding1, embedding2)
2.3 API Embedding 服务
# api_embedding_service.py
from typing import List
import openai
import requests
class APIEmbeddingService:
    """Embedding service backed by a remote API (OpenAI or a custom HTTP endpoint)."""

    # Default endpoint for the non-OpenAI path; override via base_url.
    DEFAULT_CUSTOM_ENDPOINT = 'https://api.example.com/embeddings'

    def __init__(
        self,
        api_key: str,
        model: str = 'text-embedding-ada-002',
        base_url: str = DEFAULT_CUSTOM_ENDPOINT,
    ):
        """
        Args:
            api_key: API key / bearer token.
            model: model name; 'text-embedding-*' names route to OpenAI.
            base_url: endpoint for the custom-API path (previously hard-coded).
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        if model.startswith('text-embedding'):
            self.client = openai.OpenAI(api_key=api_key)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts, dispatching on the model name."""
        if self.model.startswith('text-embedding'):
            return self._embed_openai(texts)
        return self._embed_custom_api(texts)

    def _embed_openai(self, texts: List[str]) -> List[List[float]]:
        """Call the OpenAI embeddings endpoint."""
        response = self.client.embeddings.create(
            model=self.model,
            input=texts
        )
        return [item.embedding for item in response.data]

    def _embed_custom_api(self, texts: List[str]) -> List[List[float]]:
        """POST to the configured custom embeddings endpoint."""
        response = requests.post(
            self.base_url,
            headers={'Authorization': f'Bearer {self.api_key}'},
            json={'texts': texts, 'model': self.model}
        )
        # Surface HTTP 4xx/5xx errors instead of a confusing KeyError below.
        response.raise_for_status()
        return response.json()['embeddings']
三、向量索引构建
3.1 索引类型选择
# index_types.py
from enum import Enum


class IndexType(Enum):
    """Vector index families."""
    # Exact search
    FLAT = "flat"  # brute force, 100% accurate
    # Inverted-file (clustering) indexes
    IVF_FLAT = "ivf_flat"  # inverted file
    IVF_PQ = "ivf_pq"  # product quantization
    IVF_SQ8 = "ivf_sq8"  # scalar quantization
    # Graph-based indexes
    HNSW = "hnsw"  # hierarchical navigable small world
    # Hybrid indexes
    IVF_HNSW = "ivf_hnsw"  # IVF + HNSW


class IndexSelector:
    """Heuristic chooser of an index type from scale/accuracy/memory constraints."""

    # Below this many vectors, brute force is fast enough and exact.
    SMALL_SCALE_THRESHOLD = 100000
    # Above this recall requirement, prefer a graph index.
    HIGH_ACCURACY_THRESHOLD = 0.99
    _BYTES_PER_FLOAT32 = 4

    @staticmethod
    def select_index_type(
        vector_count: int,
        dimension: int,
        accuracy_requirement: float,
        memory_limit_mb: int
    ) -> IndexType:
        """
        Pick an index type.

        Args:
            vector_count: number of vectors to index
            dimension: vector dimensionality
            accuracy_requirement: required accuracy in (0, 1)
            memory_limit_mb: memory budget in MB

        Returns:
            The recommended IndexType.
        """
        # Small scale: exact search wins.
        if vector_count < IndexSelector.SMALL_SCALE_THRESHOLD:
            return IndexType.FLAT

        # Large scale with a high accuracy bar: graph index.
        if accuracy_requirement > IndexSelector.HIGH_ACCURACY_THRESHOLD:
            return IndexType.HNSW

        # Raw float32 footprint in MB; use PQ compression when it won't fit.
        footprint_mb = vector_count * dimension * IndexSelector._BYTES_PER_FLOAT32 / 1024 / 1024
        if footprint_mb > memory_limit_mb:
            return IndexType.IVF_PQ

        # Otherwise balance speed and accuracy.
        return IndexType.IVF_FLAT
3.2 FAISS 索引构建
# faiss_index_builder.py
import faiss
import numpy as np
from typing import List
class FaissIndexBuilder:
    """Builds, persists, and queries FAISS vector indexes."""

    def __init__(self, dimension: int, index_type: str = 'ivf_flat'):
        """
        Args:
            dimension: vector dimensionality
            index_type: index family label (informational)
        """
        self.dimension = dimension
        self.index_type = index_type
        self.index = None  # set by the build_* methods or load()

    def build_flat_index(self) -> faiss.IndexFlatIP:
        """Create an exact (brute-force) inner-product index."""
        flat = faiss.IndexFlatIP(self.dimension)
        self.index = flat
        return flat

    def build_ivf_index(
        self,
        nlist: int = 100,
        quantizer_type: str = 'flat'
    ) -> faiss.IndexIVFFlat:
        """
        Create an IVF index.

        Args:
            nlist: number of coarse cluster centroids
            quantizer_type: 'flat' for inner-product quantizer, anything else for L2
        """
        quantizer_cls = faiss.IndexFlatIP if quantizer_type == 'flat' else faiss.IndexFlatL2
        coarse_quantizer = quantizer_cls(self.dimension)
        ivf = faiss.IndexIVFFlat(
            coarse_quantizer,
            self.dimension,
            nlist,
            faiss.METRIC_INNER_PRODUCT
        )
        self.index = ivf
        return ivf

    def build_hnsw_index(
        self,
        M: int = 32,
        efConstruction: int = 200
    ) -> faiss.IndexHNSWFlat:
        """
        Create an HNSW graph index.

        Args:
            M: max number of graph neighbors per node
            efConstruction: search depth used while building the graph
        """
        hnsw = faiss.IndexHNSWFlat(self.dimension, M, faiss.METRIC_INNER_PRODUCT)
        hnsw.hnsw.efConstruction = efConstruction
        self.index = hnsw
        return hnsw

    def train(self, vectors: np.ndarray):
        """Train the index if it requires training (IVF family); no-op otherwise."""
        if not self.index.is_trained:
            self.index.train(vectors)

    def add(self, vectors: np.ndarray, ids: np.ndarray = None):
        """Add vectors, optionally with explicit integer ids."""
        if ids is None:
            self.index.add(vectors)
        else:
            self.index.add_with_ids(vectors, ids)

    def search(
        self,
        query_vector: np.ndarray,
        k: int = 10
    ) -> tuple:
        """
        Search for the k nearest neighbors of a single query vector.

        Args:
            query_vector: 1-D query vector (reshaped internally to a batch of one)
            k: number of results

        Returns:
            (distances, indices) arrays of shape (1, k).
        """
        return self.index.search(query_vector.reshape(1, -1), k)

    def save(self, filepath: str):
        """Serialize the index to disk."""
        faiss.write_index(self.index, filepath)

    def load(self, filepath: str):
        """Load an index previously written with save()."""
        self.index = faiss.read_index(filepath)
3.3 Milvus 索引构建
# milvus_index_builder.py
from typing import List

from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, Index


class MilvusIndexBuilder:
    """Milvus collection and index builder.

    Fix: `List` is used in annotations below but was never imported, so
    importing this module raised NameError at class-definition time.
    """

    def __init__(self, uri: str = "localhost:19530"):
        """Connect to the Milvus server at `uri` on the default alias."""
        connections.connect("default", uri=uri)

    def create_collection(
        self,
        collection_name: str,
        dimension: int,
        metric_type: str = "COSINE"
    ) -> Collection:
        """
        Create a collection with an HNSW index on its embedding field.

        Args:
            collection_name: collection name
            dimension: embedding dimensionality
            metric_type: distance metric (e.g. COSINE, L2, IP)
        """
        # Schema: explicit int64 primary key, raw text, vector, free-form JSON metadata.
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="metadata", dtype=DataType.JSON)
        ]
        schema = CollectionSchema(fields, "RAG document collection")
        collection = Collection(collection_name, schema)
        index_params = {
            "metric_type": metric_type,
            "index_type": "HNSW",
            "params": {"M": 32, "efConstruction": 200}
        }
        collection.create_index("embedding", index_params)
        return collection

    def insert_documents(
        self,
        collection: Collection,
        documents: List[dict]
    ):
        """Insert documents (each with id/content/embedding/metadata keys) and flush."""
        # Column-oriented layout, one list per schema field, in schema order.
        entities = [
            [doc['id'] for doc in documents],
            [doc['content'] for doc in documents],
            [doc['embedding'] for doc in documents],
            [doc['metadata'] for doc in documents]
        ]
        collection.insert(entities)
        collection.flush()

    def search(
        self,
        collection: Collection,
        query_vector: List[float],
        limit: int = 10,
        filter_expr: str = None,
        metric_type: str = "COSINE"
    ) -> List[dict]:
        """
        ANN search over the embedding field.

        Args:
            collection: target collection
            query_vector: query embedding
            limit: number of hits to return
            filter_expr: optional boolean filter expression
            metric_type: metric to search with; must match the index metric
                (previously hard-coded to COSINE regardless of how the
                collection was created)

        Returns:
            Milvus search results.
        """
        search_params = {
            "metric_type": metric_type,
            "params": {"ef": 100}  # HNSW search-time depth
        }
        results = collection.search(
            data=[query_vector],
            anns_field="embedding",
            param=search_params,
            limit=limit,
            expr=filter_expr
        )
        return results
四、检索优化
4.1 混合检索
# hybrid_search.py
from typing import List, Dict, Tuple
import numpy as np
class HybridSearcher:
    """Hybrid retriever combining dense (vector) and sparse (keyword) channels
    via weighted Reciprocal Rank Fusion (RRF)."""

    def __init__(
        self,
        vector_searcher,
        keyword_searcher,
        vector_weight: float = 0.7
    ):
        """
        Args:
            vector_searcher: exposes search(query_vector, top_k=...) -> List[Dict]
            keyword_searcher: exposes search(query, top_k=...) -> List[Dict]
            vector_weight: weight of the vector channel in (0, 1); the keyword
                channel gets 1 - vector_weight
        """
        self.vector_searcher = vector_searcher
        self.keyword_searcher = keyword_searcher
        self.vector_weight = vector_weight

    def search(
        self,
        query: str,
        query_vector: np.ndarray,
        top_k: int = 10
    ) -> List[Dict]:
        """
        Run both channels and fuse their rankings.

        Args:
            query: query text (keyword channel)
            query_vector: query embedding (vector channel)
            top_k: number of fused results to return

        Returns:
            Fused results, each carrying a 'hybrid_score'.
        """
        # Retrieve 2*top_k per channel so fusion has overlap to work with.
        vector_results = self.vector_searcher.search(
            query_vector,
            top_k=top_k * 2
        )
        keyword_results = self.keyword_searcher.search(
            query,
            top_k=top_k * 2
        )
        return self._reciprocal_rank_fusion(
            vector_results,
            keyword_results,
            top_k
        )

    def _reciprocal_rank_fusion(
        self,
        vector_results: List[Dict],
        keyword_results: List[Dict],
        top_k: int
    ) -> List[Dict]:
        """
        Weighted RRF: score(d) = sum over channels of w_c / (rank_c(d) + 60).

        Fix: fused documents are now taken from the channel results already in
        hand. The original re-fetched each document through the unimplemented
        _get_document_by_id() (which returns None), so building the output
        crashed with a TypeError on `result['hybrid_score'] = score`.
        """
        scores: Dict = {}
        docs_by_id: Dict = {}

        def _accumulate(results: List[Dict], weight: float):
            # rank is 1-based; 60 is the standard RRF smoothing constant.
            for rank, result in enumerate(results, start=1):
                doc_id = result['id']
                docs_by_id.setdefault(doc_id, result)
                scores[doc_id] = scores.get(doc_id, 0) + weight / (rank + 60)

        _accumulate(vector_results, self.vector_weight)
        _accumulate(keyword_results, 1 - self.vector_weight)

        top_docs = sorted(
            scores.items(),
            key=lambda item: item[1],
            reverse=True
        )[:top_k]

        fused = []
        for doc_id, score in top_docs:
            doc = docs_by_id.get(doc_id)
            if doc is None:
                # Fallback hook for documents not present in either channel.
                doc = self._get_document_by_id(doc_id)
            result = dict(doc)  # copy so channel results are not mutated
            result['hybrid_score'] = score
            fused.append(result)
        return fused

    def _get_document_by_id(self, doc_id: str) -> Dict:
        """Storage-backed document lookup by id (hook; not implemented here)."""
        pass
4.2 多路召回
# multi_recall_search.py
from typing import List, Dict

import numpy as np


class MultiRecallSearcher:
    """Multi-strategy recall searcher: fans a query out to several recall
    strategies, merges and deduplicates the candidates, then reranks them
    with a cross-encoder.

    Fix: `np.ndarray` is used in the annotations below but numpy was never
    imported in this module, raising NameError at class-definition time.
    """

    def __init__(self, recall_strategies: List):
        """
        Args:
            recall_strategies: strategies, each exposing
                search(query, query_vector, top_k=...) -> List[Dict]
        """
        self.recall_strategies = recall_strategies
        # Lazily-created CrossEncoder, cached so the model loads once rather
        # than on every _rerank() call.
        self._reranker = None

    def search(
        self,
        query: str,
        query_vector: np.ndarray,
        top_k_per_strategy: int = 20,
        final_top_k: int = 10
    ) -> List[Dict]:
        """
        Run all recall strategies, deduplicate, rerank, and truncate.

        Args:
            query: query text
            query_vector: query embedding
            top_k_per_strategy: candidates retrieved per strategy
            final_top_k: number of results after reranking

        Returns:
            Reranked results, each carrying a 'rerank_score'.
        """
        all_results = []
        for strategy in self.recall_strategies:
            strategy_results = strategy.search(
                query,
                query_vector,
                top_k=top_k_per_strategy
            )
            all_results.extend(strategy_results)

        unique_results = self._deduplicate(all_results)
        ranked_results = self._rerank(unique_results, query, query_vector)
        return ranked_results[:final_top_k]

    def _deduplicate(self, results: List[Dict]) -> List[Dict]:
        """Drop duplicate ids, keeping the first occurrence (stable order)."""
        seen_ids = set()
        unique_results = []
        for result in results:
            if result['id'] not in seen_ids:
                seen_ids.add(result['id'])
                unique_results.append(result)
        return unique_results

    def _rerank(
        self,
        results: List[Dict],
        query: str,
        query_vector: np.ndarray
    ) -> List[Dict]:
        """Rerank candidates with a Cross-Encoder relevance model."""
        if self._reranker is None:
            # Imported lazily so the heavy dependency loads only when needed.
            from sentence_transformers import CrossEncoder
            self._reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        pairs = [[query, result['content']] for result in results]
        scores = self._reranker.predict(pairs)
        for result, score in zip(results, scores):
            result['rerank_score'] = score
        results.sort(key=lambda x: x['rerank_score'], reverse=True)
        return results
4.3 元数据过滤
# metadata_filter.py
from typing import Any, Dict, List, Optional


class MetadataFilter:
    """Builds metadata filter expressions (Milvus-style) and applies the same
    conditions to in-memory result lists.

    Fix: the documented gte/lte operators were silently dropped by
    build_expression(), and _match_filters() ignored ne/gt/lt/gte/lte.
    """

    # operator name -> comparison symbol used in filter expressions
    _COMPARISON_OPS = {
        'eq': '==',
        'ne': '!=',
        'gt': '>',
        'lt': '<',
        'gte': '>=',
        'lte': '<=',
    }

    def __init__(self):
        # Each filter is a dict with 'field', 'operator', 'value' keys.
        self.filters: List[Dict[str, Any]] = []

    def add_filter(
        self,
        field: str,
        operator: str,
        value: Any
    ):
        """
        Add a filter condition.

        Args:
            field: metadata field name
            operator: one of eq, ne, gt, lt, gte, lte, in, like
            value: comparison value
        """
        self.filters.append({
            'field': field,
            'operator': operator,
            'value': value
        })

    def build_expression(self) -> str:
        """Build an AND-joined boolean filter expression (e.g. Milvus `expr`)."""
        expressions = []
        for f in self.filters:
            op = f['operator']
            if op in self._COMPARISON_OPS:
                expr = f"{f['field']} {self._COMPARISON_OPS[op]} {repr(f['value'])}"
            elif op == 'in':
                values = ', '.join(repr(v) for v in f['value'])
                expr = f"{f['field']} in [{values}]"
            elif op == 'like':
                expr = f"{f['field']} like '{f['value']}'"
            else:
                # Unknown operator: skip rather than emit invalid syntax.
                continue
            expressions.append(expr)
        return ' && '.join(expressions)

    def apply_filter(
        self,
        results: List[Dict],
        metadata_field: str = 'metadata'
    ) -> List[Dict]:
        """Keep only results whose metadata satisfies every filter."""
        return [
            result for result in results
            if self._match_filters(result.get(metadata_field, {}))
        ]

    def _match_filters(self, metadata: Dict) -> bool:
        """Return True if metadata satisfies ALL filter conditions."""
        for f in self.filters:
            field_value = metadata.get(f['field'])
            op = f['operator']
            value = f['value']
            if op == 'eq':
                if field_value != value:
                    return False
            elif op == 'ne':
                if field_value == value:
                    return False
            elif op in ('gt', 'lt', 'gte', 'lte'):
                # A missing field can never satisfy an ordering comparison.
                if field_value is None:
                    return False
                if op == 'gt' and not field_value > value:
                    return False
                if op == 'lt' and not field_value < value:
                    return False
                if op == 'gte' and not field_value >= value:
                    return False
                if op == 'lte' and not field_value <= value:
                    return False
            elif op == 'in':
                if field_value not in value:
                    return False
            elif op == 'like':
                # Substring match on the stringified field value.
                if value not in str(field_value):
                    return False
        return True
五、性能优化
5.1 批量检索
# batch_search.py
from typing import List
import numpy as np
class BatchSearcher:
    """Runs many query vectors through an index in fixed-size batches."""

    def __init__(self, index, batch_size: int = 32):
        # index must expose search(vectors, k) -> (distances, indices)
        self.index = index
        self.batch_size = batch_size

    def batch_search(
        self,
        query_vectors: np.ndarray,
        k: int = 10
    ) -> tuple:
        """
        Search the index for every query vector, batch by batch.

        Args:
            query_vectors: 2-D array of query vectors
            k: neighbors to return per query

        Returns:
            (distances, indices), each stacked to shape (n_queries, k).
        """
        step = self.batch_size
        # One (distances, indices) pair per batch, in input order.
        per_batch = [
            self.index.search(query_vectors[start:start + step], k)
            for start in range(0, len(query_vectors), step)
        ]
        stacked_distances = np.vstack([dist for dist, _ in per_batch])
        stacked_indices = np.vstack([idx for _, idx in per_batch])
        return stacked_distances, stacked_indices
5.2 缓存优化
# search_cache.py
from typing import Optional, Dict, List
from functools import lru_cache
import hashlib
import numpy as np
class SearchCache:
    """LRU cache for vector-search results, keyed by an MD5 hash of the
    query vector's raw bytes.

    Fix: set() previously appended the key to access_order unconditionally,
    so overwriting an existing key left stale duplicates in the order list.
    That caused premature eviction of recently-used entries and eventually a
    KeyError when a stale key was popped after its entry was already gone.
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size  # maximum number of cached queries
        self.cache: Dict[str, List[Dict]] = {}
        # Keys ordered oldest -> most recently used; one entry per cached key.
        self.access_order = []

    def _compute_key(self, query_vector: np.ndarray) -> str:
        """Hash the vector bytes into a stable key (non-cryptographic use)."""
        vector_bytes = query_vector.tobytes()
        return hashlib.md5(vector_bytes).hexdigest()

    def get(self, query_vector: np.ndarray) -> Optional[List[Dict]]:
        """Return cached results for the query vector, or None on a miss."""
        key = self._compute_key(query_vector)
        if key in self.cache:
            # Refresh recency on hit.
            if key in self.access_order:
                self.access_order.remove(key)
            self.access_order.append(key)
            return self.cache[key]
        return None

    def set(
        self,
        query_vector: np.ndarray,
        results: List[Dict]
    ):
        """Store results, evicting the least-recently-used entry when full."""
        key = self._compute_key(query_vector)
        if key in self.cache:
            # Overwrite: refresh recency instead of appending a duplicate.
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            oldest_key = self.access_order.pop(0)
            del self.cache[oldest_key]
        self.cache[key] = results
        self.access_order.append(key)

    def clear(self):
        """Drop all cached entries."""
        self.cache.clear()
        self.access_order.clear()
六、监控与评估
6.1 检索质量评估
# retrieval_evaluator.py
from typing import List, Dict
import numpy as np
class RetrievalEvaluator:
    """Computes standard retrieval-quality metrics: MRR, Recall@K, NDCG@K."""

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Mean as a plain float; 0.0 for an empty list.

        np.mean([]) returns nan and emits a RuntimeWarning, which the
        original code could hit on empty inputs.
        """
        return float(np.mean(values)) if values else 0.0

    def evaluate(
        self,
        queries: List[str],
        retrieved_docs: List[List[Dict]],
        ground_truth: List[List[str]]
    ) -> Dict:
        """
        Evaluate retrieval quality.

        Args:
            queries: query list (parallel to the other arguments)
            retrieved_docs: per-query ranked results, each doc a dict with 'id'
            ground_truth: per-query lists of relevant doc ids

        Returns:
            Dict with mrr, recall_at_{1,5,10}, ndcg_at_10.
        """
        metrics = {}
        # MRR (Mean Reciprocal Rank)
        metrics['mrr'] = self._calculate_mrr(retrieved_docs, ground_truth)
        # Recall@K
        metrics['recall_at_1'] = self._calculate_recall_at_k(retrieved_docs, ground_truth, k=1)
        metrics['recall_at_5'] = self._calculate_recall_at_k(retrieved_docs, ground_truth, k=5)
        metrics['recall_at_10'] = self._calculate_recall_at_k(retrieved_docs, ground_truth, k=10)
        # NDCG@K
        metrics['ndcg_at_10'] = self._calculate_ndcg_at_k(retrieved_docs, ground_truth, k=10)
        return metrics

    def _calculate_mrr(
        self,
        retrieved_docs: List[List[Dict]],
        ground_truth: List[List[str]]
    ) -> float:
        """Mean Reciprocal Rank of the first relevant hit per query."""
        reciprocal_ranks = []
        for results, truths in zip(retrieved_docs, ground_truth):
            rr = 0.0  # queries with no relevant hit contribute 0
            for i, doc in enumerate(results):
                if doc['id'] in truths:
                    rr = 1.0 / (i + 1)
                    break
            reciprocal_ranks.append(rr)
        return self._mean(reciprocal_ranks)

    def _calculate_recall_at_k(
        self,
        retrieved_docs: List[List[Dict]],
        ground_truth: List[List[str]],
        k: int
    ) -> float:
        """Fraction of relevant docs found in the top-k, averaged over queries."""
        recalls = []
        for results, truths in zip(retrieved_docs, ground_truth):
            retrieved_ids = {doc['id'] for doc in results[:k]}
            truth_ids = set(truths)
            # Queries with no ground truth are skipped (recall undefined).
            if truth_ids:
                recalls.append(len(retrieved_ids & truth_ids) / len(truth_ids))
        return self._mean(recalls)

    def _calculate_ndcg_at_k(
        self,
        retrieved_docs: List[List[Dict]],
        ground_truth: List[List[str]],
        k: int
    ) -> float:
        """NDCG@K with binary relevance, averaged over queries."""
        ndcgs = []
        for results, truths in zip(retrieved_docs, ground_truth):
            # DCG: discounted gain of relevant docs at their ranks.
            dcg = 0.0
            for i, doc in enumerate(results[:k]):
                if doc['id'] in truths:
                    dcg += 1.0 / np.log2(i + 2)
            # IDCG: all relevant docs ranked at the top.
            idcg = 0.0
            for i in range(min(len(truths), k)):
                idcg += 1.0 / np.log2(i + 2)
            ndcgs.append(dcg / idcg if idcg > 0 else 0)
        return self._mean(ndcgs)
七、总结
7.1 核心要点
- **Embedding 模型选择**
- 中文场景:BGE 系列
- 英文场景:OpenAI 或 MiniLM
- 多语言:multilingual-e5
- **索引类型选择**
- 小规模:FLAT
- 大规模高准确:HNSW
- 内存受限:IVF_PQ
- **检索优化**
- 混合检索提升准确率
- 多路召回提升召回率
- 重排序提升最终质量
7.2 性能基准
| 索引类型 | 构建时间 | 内存占用 | 检索延迟 | 准确率 |
|---|---|---|---|---|
| FLAT | 快 | 高 | 中 | 100% |
| IVF_FLAT | 中 | 中 | 快 | 95-98% |
| HNSW | 慢 | 高 | 最快 | 98-99% |
| IVF_PQ | 快 | 最低 | 快 | 90-95% |
参考资料