Agent 记忆系统设计实战
记忆是智能 Agent 的核心能力之一。如何让 Agent 记住历史交互?如何高效管理记忆?本文深入解析 Agent 记忆系统的设计与实现。
一、记忆系统架构
1.1 记忆类型
Agent 记忆体系:
┌─────────────────────────────────────┐
│ 短期记忆(Short-term Memory) │
│ - 当前对话上下文 │
│ - 容量有限(窗口大小) │
│ - 快速访问 │
├─────────────────────────────────────┤
│ 长期记忆(Long-term Memory) │
│ - 历史对话记录 │
│ - 用户偏好信息 │
│ - 容量大,可检索 │
├─────────────────────────────────────┤
│ 工作记忆(Working Memory) │
│ - 当前任务状态 │
│ - 临时计算结果 │
│ - 任务完成后清除 │
└─────────────────────────────────────┘
1.2 记忆系统架构
graph LR
A[用户输入] --> B[短期记忆]
B --> C[记忆检索]
C --> D[相关记忆]
D --> E[上下文构建]
E --> F[LLM 处理]
F --> G[响应生成]
G --> H[记忆存储]
H --> I[长期记忆]
H --> J[记忆压缩]
二、短期记忆实现
2.1 滑动窗口记忆
# short_term_memory.py
from typing import List, Dict
from collections import deque
from dataclasses import dataclass
@dataclass
class Message:
    """A single chat message: who said it, what was said, and when."""
    role: str  # one of: user, assistant, system
    content: str  # message text
    timestamp: float  # Unix epoch seconds at creation time
class ShortTermMemory:
    """Short-term memory backed by a fixed-size sliding window.

    Once the window is full, the oldest message is evicted automatically.
    """

    def __init__(self, max_messages: int = 20):
        """
        Args:
            max_messages: Upper bound on retained messages.
        """
        self.max_messages = max_messages
        # deque(maxlen=...) drops the oldest entry on overflow for us.
        self.messages = deque(maxlen=max_messages)

    def add(self, role: str, content: str):
        """Append a message stamped with the current wall-clock time."""
        import time
        self.messages.append(
            Message(role=role, content=content, timestamp=time.time())
        )

    def get_context(self) -> List[Dict]:
        """Return the window as a list of {'role', 'content'} dicts."""
        context = []
        for msg in self.messages:
            context.append({'role': msg.role, 'content': msg.content})
        return context

    def get_token_count(self) -> int:
        """Rough token estimate, assuming ~4 characters per token."""
        return sum(len(msg.content) // 4 for msg in self.messages)

    def clear(self):
        """Drop every stored message."""
        self.messages.clear()
# Usage example: exercise the sliding-window short-term memory.
memory = ShortTermMemory(max_messages=20)
# Seed the window with a short exchange.
memory.add('system', '你是一个有用的助手')
memory.add('user', '你好')
memory.add('assistant', '你好!有什么可以帮你的吗?')
# Snapshot the window in chat-API message format.
context = memory.get_context()
2.2 重要消息保留
# important_message_memory.py
from typing import List, Dict, Optional
import re
class ImportantMessageMemory:
    """Memory that keeps recent messages plus heuristically important ones.

    Each message gets a heuristic importance score in [0, 1]; when the
    buffer approaches ``max_messages`` the low-importance, non-recent
    messages are pruned.
    """

    def __init__(self, llm, max_messages: int = 50):
        """
        Args:
            llm: Language-model handle (currently unused; reserved for
                LLM-based importance scoring).
            max_messages: Soft capacity; pruning triggers at 80% fill.
        """
        self.llm = llm
        self.max_messages = max_messages
        self.all_messages = []
        # Indices (into all_messages) explicitly pinned as important.
        self.important_indices = set()

    def add(self, role: str, content: str):
        """Append a message, scoring its importance on the way in."""
        import time
        self.all_messages.append({
            'role': role,
            'content': content,
            'timestamp': time.time(),
            'importance': self._calculate_importance(role, content)
        })
        # Prune once the buffer reaches 80% of capacity.
        if len(self.all_messages) >= self.max_messages * 0.8:
            self._prune_unimportant()

    def _calculate_importance(self, role: str, content: str) -> float:
        """Heuristic importance score, clamped to [0, 1]."""
        score = 0.5  # baseline
        # System prompts are usually load-bearing.
        if role == 'system':
            score += 0.3
        # Each keyword hit adds a small boost.
        important_keywords = [
            '记住', '重要', '关键', '必须',
            '偏好', '喜欢', '不喜欢', '要求'
        ]
        for keyword in important_keywords:
            if keyword in content:
                score += 0.1
        # Direct user questions/requests matter more.
        if role == 'user' and ('?' in content or '请' in content):
            score += 0.1
        return min(1.0, score)

    def _prune_unimportant(self):
        """Drop messages that are neither important nor recent.

        Bug fix: the previous version reset ``important_indices`` to
        *every* surviving index after pruning, which made all later
        prunes no-ops. Pinned indices are now remapped to the survivors'
        new positions instead.
        """
        total = len(self.all_messages)
        # Keep messages that score high or were explicitly pinned.
        important = {
            i for i, msg in enumerate(self.all_messages)
            if msg['importance'] > 0.7 or i in self.important_indices
        }
        # Always keep the 20 most recent (clamped for short buffers).
        recent = set(range(max(0, total - 20), total))
        keep_indices = important | recent
        # Rebuild the buffer, remembering where each old index landed.
        old_to_new = {}
        kept = []
        for i, msg in enumerate(self.all_messages):
            if i in keep_indices:
                old_to_new[i] = len(kept)
                kept.append(msg)
        self.all_messages = kept
        # Re-point explicit pins at their new positions; pins whose
        # message was pruned are dropped.
        self.important_indices = {
            old_to_new[i] for i in self.important_indices if i in old_to_new
        }

    def get_context(self) -> List[Dict]:
        """Return retained messages in chat-API format."""
        return [
            {'role': m['role'], 'content': m['content']}
            for m in self.all_messages
        ]
三、长期记忆实现
3.1 向量记忆存储
# long_term_memory.py
from typing import List, Dict, Optional
import numpy as np
from datetime import datetime
class LongTermMemory:
    """Long-term memory persisted in a vector store with side metadata."""

    def __init__(self, embedding_model, vector_store):
        """
        Args:
            embedding_model: Model exposing ``encode(list[str]) -> list[vec]``.
            vector_store: Vector index exposing ``add``/``search``/``delete``
                (e.g. a FAISS or Chroma wrapper).
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        # memory_id -> {'content', 'created_at', 'access_count', ...}
        self.metadata_store = {}

    def add(self, content: str, metadata: Dict = None):
        """Embed *content* and store it with bookkeeping metadata.

        Args:
            content: Text of the memory.
            metadata: Optional extra fields merged into the stored record.
        """
        embedding = self.embedding_model.encode([content])[0]
        import uuid
        memory_id = str(uuid.uuid4())
        self.metadata_store[memory_id] = {
            'content': content,
            'created_at': datetime.now().isoformat(),
            'access_count': 0,
            'last_accessed': None,
            **(metadata or {})
        }
        self.vector_store.add(
            ids=[memory_id],
            embeddings=[embedding],
            metadatas=[self.metadata_store[memory_id]]
        )

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Vector-search memories relevant to *query*.

        Updates per-memory access statistics as a side effect.

        Args:
            query: Free-text query.
            top_k: Number of hits to request.

        Returns:
            Hits as dicts with 'id', 'content', 'score', 'metadata'.
        """
        query_embedding = self.embedding_model.encode([query])[0]
        # NOTE(review): assumes the store returns dicts with 'id' and
        # 'score' keys — confirm against the concrete wrapper.
        results = self.vector_store.search(
            query_embeddings=[query_embedding],
            k=top_k
        )
        now = datetime.now().isoformat()
        for result in results:
            meta = self.metadata_store.get(result['id'])
            if meta is not None:
                meta['access_count'] += 1
                meta['last_accessed'] = now
        # Only surface hits we still hold metadata for — robust against
        # a vector index that is stale after deletions (the previous
        # version raised KeyError here).
        return [
            {
                'id': r['id'],
                'content': self.metadata_store[r['id']]['content'],
                'score': r['score'],
                'metadata': self.metadata_store[r['id']]
            }
            for r in results
            if r['id'] in self.metadata_store
        ]

    def search_by_time(
        self,
        start_time: datetime = None,
        end_time: datetime = None,
        limit: int = 100
    ) -> List[Dict]:
        """Linear scan of memories filtered by creation-time window."""
        results = []
        for memory_id, metadata in self.metadata_store.items():
            created_at = datetime.fromisoformat(metadata['created_at'])
            if start_time and created_at < start_time:
                continue
            if end_time and created_at > end_time:
                continue
            results.append({
                'id': memory_id,
                'content': metadata['content'],
                'created_at': created_at
            })
            if len(results) >= limit:
                break
        return results

    def forget(self, memory_ids: List[str]):
        """Delete the given memories; unknown ids are ignored.

        Bug fix: the previous ``del`` raised KeyError when an id had
        already been forgotten.
        """
        for memory_id in memory_ids:
            self.vector_store.delete(ids=[memory_id])
            self.metadata_store.pop(memory_id, None)
3.2 记忆分类存储
# categorized_memory.py
from typing import Dict, List
class CategorizedMemory:
    """In-process memory buckets keyed by a fixed set of categories."""

    def __init__(self):
        # Fixed category set; add() rejects anything else.
        self.memories = {
            'user_preferences': [],     # user preferences
            'task_context': [],         # task context
            'factual_knowledge': [],    # factual knowledge
            'conversation_history': []  # conversation history
        }

    def add(self, category: str, content: str, metadata: Dict = None):
        """Append a memory record to *category*.

        Raises:
            ValueError: If *category* is not one of the known buckets.
        """
        if category not in self.memories:
            raise ValueError(f"Unknown category: {category}")
        # Bug fix: datetime was used here without being imported in this
        # snippet; a function-scope import keeps the module runnable.
        from datetime import datetime
        self.memories[category].append({
            'content': content,
            'metadata': metadata or {},
            'created_at': datetime.now().isoformat()
        })

    def get(self, category: str, limit: int = 10) -> List[Dict]:
        """Return up to *limit* most recent records for *category*."""
        return self.memories.get(category, [])[-limit:]

    def search_all(
        self,
        query: str,
        embedding_model,
        vector_stores: Dict[str, any]
    ) -> List[Dict]:
        """Search every category's vector store and merge hits by score."""
        all_results = []
        for category, store in vector_stores.items():
            results = store.search(query, top_k=3)
            # Tag each hit with its source category.
            for r in results:
                r['category'] = category
            all_results.extend(results)
        # Highest score first across all categories.
        all_results.sort(key=lambda x: x['score'], reverse=True)
        return all_results
四、记忆压缩
4.1 摘要压缩
# memory_compression.py
from typing import List, Dict
class MemoryCompressor:
    """LLM-backed compressor that turns conversations into summaries."""

    def __init__(self, llm):
        self.llm = llm

    def compress_conversation(
        self,
        messages: List[Dict],
        max_length: int = 500
    ) -> str:
        """Summarize *messages* into one text.

        The length limit is requested in the prompt, not enforced in code.

        Args:
            messages: Chat messages as {'role', 'content'} dicts.
            max_length: Requested summary length ceiling (characters).
        """
        lines = []
        for m in messages:
            lines.append(f"{m['role']}: {m['content']}")
        conversation_text = '\n'.join(lines)
        prompt = f"""
请总结以下对话,保留关键信息:
{conversation_text}
请用简洁的语言总结对话要点,不超过{max_length}字。
"""
        return self.llm.generate(prompt)

    def compress_with_hierarchy(
        self,
        messages: List[Dict],
        compression_levels: int = 3
    ) -> Dict:
        """Produce progressively coarser summaries of *messages*.

        Args:
            messages: Original messages (kept verbatim under 'raw').
            compression_levels: Number of summarization passes.

        Returns:
            {'raw': messages, 'summaries': [{'level', 'summary',
            'original_count'}, ...]}.
        """
        results = {'raw': messages, 'summaries': []}
        current = messages
        for level in range(compression_levels):
            summary = self.compress_conversation(
                current,
                max_length=200 * (level + 1)
            )
            results['summaries'].append({
                'level': level + 1,
                'summary': summary,
                'original_count': len(current)
            })
            # Each later pass re-summarizes the previous pass's summary.
            current = [{
                'role': 'system',
                'content': f"对话摘要(级别{level+1}): {summary}"
            }]
        return results
4.2 关键信息提取
# key_information_extraction.py
from typing import List, Dict
class KeyInformationExtractor:
    """LLM-backed extraction of facts and user preferences."""

    def __init__(self, llm):
        self.llm = llm

    def extract_facts(self, conversation: List[Dict]) -> List[Dict]:
        """Extract factual statements from a conversation.

        Returns:
            A list of {'fact', 'confidence', 'source'} dicts; empty list
            when the LLM output contains no parseable JSON array.
        """
        conversation_text = '\n'.join([
            f"{m['role']}: {m['content']}"
            for m in conversation
        ])
        prompt = f"""
从以下对话中提取事实信息:
{conversation_text}
请提取所有事实性陈述,按以下 JSON 格式输出:
[
{{"fact": "事实内容", "confidence": 0.0-1.0, "source": "user/assistant"}}
]
"""
        response = self.llm.generate(prompt)
        facts = self._parse_json(response)
        # Normalize parse failures to [] so callers can iterate
        # unconditionally (matches the annotated return type).
        return facts if isinstance(facts, list) else []

    def extract_preferences(self, conversation: List[Dict]) -> Dict:
        """Extract user preferences from the user's turns only.

        Returns:
            A dict shaped like {'preferences': [...]}; empty dict when
            the LLM output contains no parseable JSON object.
        """
        conversation_text = '\n'.join([
            f"{m['content']}"
            for m in conversation if m['role'] == 'user'
        ])
        prompt = f"""
从以下用户输入中提取偏好信息:
{conversation_text}
请提取用户的偏好、喜好、要求等,按以下 JSON 格式输出:
{{
"preferences": [
{{"category": "类别", "preference": "偏好内容", "strength": "strong/medium/weak"}}
]
}}
"""
        response = self.llm.generate(prompt)
        preferences = self._parse_json(response)
        return preferences if isinstance(preferences, dict) else {}

    def _parse_json(self, text: str):
        """Best-effort parse of the first JSON object/array in *text*.

        Bug fixes vs. the naive version:
        - arrays are tried before objects, otherwise a JSON array of
          objects would match only its first inner object;
        - malformed JSON returns None instead of raising.
        """
        import json
        import re
        match = re.search(r'\[.*\]|\{.*\}', text, re.DOTALL)
        if match is None:
            return None
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            return None
五、记忆检索优化
5.1 混合检索
# hybrid_memory_retrieval.py
from typing import List, Dict
class HybridMemoryRetriever:
    """Fuses vector, keyword and time-based retrieval with weighted scores."""

    def __init__(
        self,
        vector_retriever,
        keyword_retriever,
        time_retriever,
        weights: Dict[str, float] = None
    ):
        """
        Args:
            vector_retriever: Semantic retriever with ``search(query, top_k)``.
            keyword_retriever: Lexical retriever with the same interface.
            time_retriever: Recency retriever; called as ``search(filter, top_k)``.
            weights: Per-source fusion weights; defaults to
                vector 0.5 / keyword 0.3 / time 0.2.
        """
        self.vector_retriever = vector_retriever
        self.keyword_retriever = keyword_retriever
        self.time_retriever = time_retriever
        self.weights = weights or {
            'vector': 0.5,
            'keyword': 0.3,
            'time': 0.2
        }

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        time_filter: Dict = None
    ) -> List[Dict]:
        """Run all three retrievers and fuse their scores.

        Args:
            query: Free-text query.
            top_k: Number of fused results to return.
            time_filter: Passed through to the time retriever.
        """
        # Over-fetch from each source so fusion has candidates to rank.
        vector_results = self.vector_retriever.search(query, top_k=top_k * 2)
        keyword_results = self.keyword_retriever.search(query, top_k=top_k * 2)
        time_results = self.time_retriever.search(time_filter, top_k=top_k * 2)
        # Normalize each source's scores to [0, 1] so the fusion weights
        # compare like with like.
        vector_results = self._normalize_scores(vector_results, 'vector')
        keyword_results = self._normalize_scores(keyword_results, 'keyword')
        time_results = self._normalize_scores(time_results, 'time')
        # Merge per-memory scores across sources.
        all_results = {}
        for results, source in [
            (vector_results, 'vector'),
            (keyword_results, 'keyword'),
            (time_results, 'time')
        ]:
            for r in results:
                memory_id = r['id']
                if memory_id not in all_results:
                    all_results[memory_id] = {
                        'id': memory_id,
                        'content': r['content'],
                        'scores': {}
                    }
                # Bug fix: fuse the *normalized* score. The raw scores from
                # different sources are not on a common scale, so the old
                # code's normalization pass was dead work.
                all_results[memory_id]['scores'][source] = r['normalized_score']
        # Weighted fusion; a missing source contributes 0.
        for memory in all_results.values():
            memory['combined_score'] = sum(
                memory['scores'].get(source, 0) * weight
                for source, weight in self.weights.items()
            )
        sorted_results = sorted(
            all_results.values(),
            key=lambda x: x['combined_score'],
            reverse=True
        )
        return sorted_results[:top_k]

    def _normalize_scores(
        self,
        results: List[Dict],
        source: str
    ) -> List[Dict]:
        """Min-max normalize 'score' into 'normalized_score' in place."""
        if not results:
            return []
        scores = [r['score'] for r in results]
        min_score = min(scores)
        max_score = max(scores)
        if max_score == min_score:
            # Degenerate case: all scores equal -> neutral 0.5.
            for r in results:
                r['normalized_score'] = 0.5
        else:
            span = max_score - min_score
            for r in results:
                r['normalized_score'] = (r['score'] - min_score) / span
        return results
5.2 记忆关联检索
# associative_memory_retrieval.py
from typing import List, Dict, Set
class AssociativeMemoryRetriever:
    """Retrieval that follows association edges out from the initial hits."""

    def __init__(self, base_retriever, memory_graph):
        """
        Args:
            base_retriever: Retriever exposing ``search(query, top_k)``.
            memory_graph: Graph exposing ``get_associations(id, top_k)``;
                nodes are memories, edges are associations.
        """
        self.base_retriever = base_retriever
        self.memory_graph = memory_graph

    def retrieve_with_association(
        self,
        query: str,
        top_k: int = 5,
        association_depth: int = 2
    ) -> List[Dict]:
        """Retrieve, then expand along associations up to a given depth.

        Args:
            query: Free-text query.
            top_k: Number of results to return.
            association_depth: How many association hops to follow.
        """
        initial_results = self.base_retriever.search(query, top_k=top_k)
        # Bug fix: seed every direct hit with a relevance score so direct
        # hits and association hits sort on the same key. Previously the
        # direct hits defaulted to 0 and sank below their own neighbors.
        for r in initial_results:
            r.setdefault('relevance_score', r.get('score', 0))
        expanded_results = self._expand_with_associations(
            initial_results,
            association_depth
        )
        # De-duplicate by id, then rank by relevance.
        unique_results = {r['id']: r for r in expanded_results}
        sorted_results = sorted(
            unique_results.values(),
            key=lambda x: x.get('relevance_score', 0),
            reverse=True
        )
        return sorted_results[:top_k]

    def _expand_with_associations(
        self,
        results: List[Dict],
        depth: int
    ) -> List[Dict]:
        """Breadth-first expansion; relevance decays by 0.8 per hop."""
        expanded = list(results)
        visited = {r['id'] for r in results}
        frontier = results
        for _ in range(depth):
            next_frontier = []
            for result in frontier:
                associations = self.memory_graph.get_associations(
                    result['id'],
                    top_k=3
                )
                for assoc in associations:
                    if assoc['id'] in visited:
                        continue
                    # Bug fix: chain the decay through hops. Previously
                    # depth>=2 nodes read the missing 'score' key and all
                    # collapsed to relevance 0.
                    assoc['relevance_score'] = (
                        result.get('relevance_score', result.get('score', 0))
                        * 0.8
                    )
                    next_frontier.append(assoc)
                    visited.add(assoc['id'])
            expanded.extend(next_frontier)
            frontier = next_frontier
        return expanded
六、实战案例
6.1 完整记忆系统
# complete_memory_system.py
class CompleteMemorySystem:
    """Wires short-term, long-term, compression and retrieval together."""

    def __init__(self, config: Dict):
        """
        Args:
            config: Expects keys 'embedding_model', 'vector_store', 'llm',
                'keyword_retriever'; optional 'short_term_max' (default 20).
        """
        # Sliding-window short-term memory.
        self.short_term = ShortTermMemory(
            max_messages=config.get('short_term_max', 20)
        )
        # Vector-backed long-term memory.
        self.long_term = LongTermMemory(
            embedding_model=config['embedding_model'],
            vector_store=config['vector_store']
        )
        self.compressor = MemoryCompressor(config['llm'])
        self.extractor = KeyInformationExtractor(config['llm'])
        # The long-term store doubles as the vector and time retriever.
        self.retriever = HybridMemoryRetriever(
            vector_retriever=self.long_term,
            keyword_retriever=config['keyword_retriever'],
            time_retriever=self.long_term
        )

    def process_interaction(
        self,
        user_input: str,
        assistant_response: str
    ):
        """Record one user/assistant exchange and mine it for memories."""
        self.short_term.add('user', user_input)
        self.short_term.add('assistant', assistant_response)
        conversation = self.short_term.get_context()
        # Persist extracted facts. The `or []` / `or {}` guards cover
        # extractor implementations that return None on unparseable
        # LLM output (previously this crashed the loop below).
        facts = self.extractor.extract_facts(conversation) or []
        for fact in facts:
            self.long_term.add(
                fact['fact'],
                metadata={'type': 'fact', 'confidence': fact['confidence']}
            )
        preferences = self.extractor.extract_preferences(conversation) or {}
        for pref in preferences.get('preferences', []):
            self.long_term.add(
                pref['preference'],
                metadata={'type': 'preference', 'category': pref['category']}
            )
        # Compress once the window grows past ~3000 estimated tokens.
        if self.short_term.get_token_count() > 3000:
            self._compress_and_archive()

    def _compress_and_archive(self):
        """Summarize the window into long-term memory; keep last 5 turns."""
        messages = self.short_term.get_context()
        summary = self.compressor.compress_conversation(messages)
        self.long_term.add(
            summary,
            metadata={'type': 'conversation_summary'}
        )
        # Reset the window but re-seed it with the 5 most recent turns.
        recent = list(self.short_term.messages)[-5:]
        self.short_term.clear()
        for msg in recent:
            self.short_term.add(msg.role, msg.content)

    def build_context(
        self,
        current_query: str,
        max_context_tokens: int = 3000
    ) -> str:
        """Assemble prompt context: relevant memories + recent dialogue.

        Args:
            current_query: Query used for memory retrieval.
            max_context_tokens: Rough token budget (~4 chars per token).
        """
        relevant_memories = self.retriever.retrieve(current_query, top_k=5)
        short_term_context = self.short_term.get_context()
        context_parts = []
        if relevant_memories:
            context_parts.append("【相关记忆】")
            for mem in relevant_memories:
                context_parts.append(f"- {mem['content']}")
        if short_term_context:
            context_parts.append("\n【最近对话】")
            for msg in short_term_context:
                context_parts.append(f"{msg['role']}: {msg['content']}")
        full_context = '\n'.join(context_parts)
        # Crude character-based truncation at ~4 chars per token.
        if len(full_context) // 4 > max_context_tokens:
            full_context = full_context[:max_context_tokens * 4]
        return full_context
七、总结
7.1 核心要点
1. 记忆分层
   - 短期记忆:快速访问,容量有限
   - 长期记忆:持久存储,可检索
   - 工作记忆:任务临时状态
2. 记忆压缩
   - 摘要压缩:减少存储
   - 关键信息提取:保留要点
   - 分层存储:多级粒度
3. 检索优化
   - 混合检索:向量 + 关键词 + 时间
   - 关联检索:图结构扩展
   - 排序融合:加权评分
7.2 最佳实践
1. 及时压缩
   - 对话达到阈值时压缩
   - 保留重要信息
   - 归档历史摘要
2. 智能检索
   - 根据查询选择检索策略
   - 平衡相关性和时效性
   - 去重和排序
3. 持续优化
   - 监控记忆质量
   - 调整压缩策略
   - 优化检索参数
参考资料