Skip to content
清晨的一缕阳光
返回

Agent 记忆系统设计实战

Agent 记忆系统设计实战

记忆是智能 Agent 的核心能力之一。如何让 Agent 记住历史交互?如何高效管理记忆?本文深入解析 Agent 记忆系统的设计与实现。

一、记忆系统架构

1.1 记忆类型

Agent 记忆体系:

┌─────────────────────────────────────┐
│ 短期记忆(Short-term Memory)        │
│ - 当前对话上下文                     │
│ - 容量有限(窗口大小)               │
│ - 快速访问                           │
├─────────────────────────────────────┤
│ 长期记忆(Long-term Memory)         │
│ - 历史对话记录                       │
│ - 用户偏好信息                       │
│ - 容量大,可检索                     │
├─────────────────────────────────────┤
│ 工作记忆(Working Memory)           │
│ - 当前任务状态                       │
│ - 临时计算结果                       │
│ - 任务完成后清除                     │
└─────────────────────────────────────┘

1.2 记忆系统架构

graph LR
    A[用户输入] --> B[短期记忆]
    B --> C[记忆检索]
    C --> D[相关记忆]
    D --> E[上下文构建]
    E --> F[LLM 处理]
    F --> G[响应生成]
    G --> H[记忆存储]
    H --> I[长期记忆]
    H --> J[记忆压缩]

二、短期记忆实现

2.1 滑动窗口记忆

# short_term_memory.py
from typing import List, Dict
from collections import deque
from dataclasses import dataclass

@dataclass
class Message:
    """消息"""
    role: str  # user, assistant, system
    content: str
    timestamp: float

class ShortTermMemory:
    """短期记忆(滑动窗口)"""
    
    def __init__(self, max_messages: int = 20):
        """
        初始化
        
        Args:
            max_messages: 最大消息数
        """
        self.max_messages = max_messages
        self.messages = deque(maxlen=max_messages)
    
    def add(self, role: str, content: str):
        """添加消息"""
        import time
        message = Message(
            role=role,
            content=content,
            timestamp=time.time()
        )
        self.messages.append(message)
    
    def get_context(self) -> List[Dict]:
        """获取上下文"""
        return [
            {'role': m.role, 'content': m.content}
            for m in self.messages
        ]
    
    def get_token_count(self) -> int:
        """估算 Token 数"""
        total = 0
        for m in self.messages:
            total += len(m.content) // 4  # 粗略估算
        return total
    
    def clear(self):
        """清空记忆"""
        self.messages.clear()

# 使用示例
memory = ShortTermMemory(max_messages=20)

# 添加消息
memory.add('system', '你是一个有用的助手')
memory.add('user', '你好')
memory.add('assistant', '你好!有什么可以帮你的吗?')

# 获取上下文
context = memory.get_context()

2.2 重要消息保留

# important_message_memory.py
from typing import List, Dict, Optional
import re

class ImportantMessageMemory:
    """重要消息记忆"""
    
    def __init__(self, llm, max_messages: int = 50):
        self.llm = llm
        self.max_messages = max_messages
        self.all_messages = []
        self.important_indices = set()
    
    def add(self, role: str, content: str):
        """添加消息"""
        import time
        
        self.all_messages.append({
            'role': role,
            'content': content,
            'timestamp': time.time(),
            'importance': self._calculate_importance(role, content)
        })
        
        # 自动标记重要消息
        if len(self.all_messages) >= self.max_messages * 0.8:
            self._prune_unimportant()
    
    def _calculate_importance(self, role: str, content: str) -> float:
        """计算重要性分数"""
        score = 0.5  # 基础分数
        
        # 系统消息重要
        if role == 'system':
            score += 0.3
        
        # 包含特定关键词的消息重要
        important_keywords = [
            '记住', '重要', '关键', '必须',
            '偏好', '喜欢', '不喜欢', '要求'
        ]
        for keyword in important_keywords:
            if keyword in content:
                score += 0.1
        
        # 用户指令重要
        if role == 'user' and ('?' in content or '' in content):
            score += 0.1
        
        return min(1.0, score)
    
    def _prune_unimportant(self):
        """修剪不重要的消息"""
        # 保留重要消息
        important = [
            i for i, msg in enumerate(self.all_messages)
            if msg['importance'] > 0.7 or i in self.important_indices
        ]
        
        # 保留最近的 20 条
        recent = list(range(len(self.all_messages) - 20, len(self.all_messages)))
        
        # 合并保留索引
        keep_indices = set(important) | set(recent)
        
        # 过滤
        self.all_messages = [
            msg for i, msg in enumerate(self.all_messages)
            if i in keep_indices
        ]
        
        # 更新重要索引
        self.important_indices = set(range(len(self.all_messages)))
    
    def get_context(self) -> List[Dict]:
        """获取上下文"""
        return [
            {'role': m['role'], 'content': m['content']}
            for m in self.all_messages
        ]

三、长期记忆实现

3.1 向量记忆存储

# long_term_memory.py
from typing import List, Dict, Optional
import numpy as np
from datetime import datetime

class LongTermMemory:
    """长期记忆(向量存储)"""
    
    def __init__(self, embedding_model, vector_store):
        """
        初始化
        
        Args:
            embedding_model: Embedding 模型
            vector_store: 向量存储(如 FAISS、Chroma)
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.metadata_store = {}
    
    def add(
        self,
        content: str,
        metadata: Dict = None
    ):
        """
        添加记忆
        
        Args:
            content: 记忆内容
            metadata: 元数据
        """
        # 生成嵌入
        embedding = self.embedding_model.encode([content])[0]
        
        # 生成 ID
        import uuid
        memory_id = str(uuid.uuid4())
        
        # 存储元数据
        self.metadata_store[memory_id] = {
            'content': content,
            'created_at': datetime.now().isoformat(),
            'access_count': 0,
            'last_accessed': None,
            **(metadata or {})
        }
        
        # 存储向量
        self.vector_store.add(
            ids=[memory_id],
            embeddings=[embedding],
            metadatas=[self.metadata_store[memory_id]]
        )
    
    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict]:
        """
        检索记忆
        
        Args:
            query: 查询
            top_k: 返回数量
        
        Returns:
            相关记忆列表
        """
        # 生成查询嵌入
        query_embedding = self.embedding_model.encode([query])[0]
        
        # 向量检索
        results = self.vector_store.search(
            query_embeddings=[query_embedding],
            k=top_k
        )
        
        # 更新访问统计
        for result in results:
            memory_id = result['id']
            if memory_id in self.metadata_store:
                self.metadata_store[memory_id]['access_count'] += 1
                self.metadata_store[memory_id]['last_accessed'] = datetime.now().isoformat()
        
        return [
            {
                'id': r['id'],
                'content': self.metadata_store[r['id']]['content'],
                'score': r['score'],
                'metadata': self.metadata_store[r['id']]
            }
            for r in results
        ]
    
    def search_by_time(
        self,
        start_time: datetime = None,
        end_time: datetime = None,
        limit: int = 100
    ) -> List[Dict]:
        """按时间检索记忆"""
        results = []
        
        for memory_id, metadata in self.metadata_store.items():
            created_at = datetime.fromisoformat(metadata['created_at'])
            
            if start_time and created_at < start_time:
                continue
            if end_time and created_at > end_time:
                continue
            
            results.append({
                'id': memory_id,
                'content': metadata['content'],
                'created_at': created_at
            })
            
            if len(results) >= limit:
                break
        
        return results
    
    def forget(
        self,
        memory_ids: List[str]
    ):
        """删除记忆"""
        for memory_id in memory_ids:
            self.vector_store.delete(ids=[memory_id])
            del self.metadata_store[memory_id]

3.2 记忆分类存储

# categorized_memory.py
from typing import Dict, List

class CategorizedMemory:
    """分类记忆存储"""
    
    def __init__(self):
        self.memories = {
            'user_preferences': [],  # 用户偏好
            'task_context': [],      # 任务上下文
            'factual_knowledge': [], # 事实知识
            'conversation_history': []  # 对话历史
        }
    
    def add(
        self,
        category: str,
        content: str,
        metadata: Dict = None
    ):
        """添加记忆"""
        if category not in self.memories:
            raise ValueError(f"Unknown category: {category}")
        
        self.memories[category].append({
            'content': content,
            'metadata': metadata or {},
            'created_at': datetime.now().isoformat()
        })
    
    def get(
        self,
        category: str,
        limit: int = 10
    ) -> List[Dict]:
        """获取某类记忆"""
        return self.memories.get(category, [])[-limit:]
    
    def search_all(
        self,
        query: str,
        embedding_model,
        vector_stores: Dict[str, any]
    ) -> List[Dict]:
        """跨类别搜索"""
        all_results = []
        
        for category, store in vector_stores.items():
            results = store.search(query, top_k=3)
            for r in results:
                r['category'] = category
            all_results.extend(results)
        
        # 按分数排序
        all_results.sort(key=lambda x: x['score'], reverse=True)
        
        return all_results

四、记忆压缩

4.1 摘要压缩

# memory_compression.py
from typing import List, Dict

class MemoryCompressor:
    """记忆压缩器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def compress_conversation(
        self,
        messages: List[Dict],
        max_length: int = 500
    ) -> str:
        """
        压缩对话
        
        Args:
            messages: 对话消息列表
            max_length: 最大长度
        """
        # 构建对话文本
        conversation_text = '\n'.join([
            f"{m['role']}: {m['content']}"
            for m in messages
        ])
        
        prompt = f"""
请总结以下对话,保留关键信息:

{conversation_text}

请用简洁的语言总结对话要点,不超过{max_length}字。
"""
        
        summary = self.llm.generate(prompt)
        return summary
    
    def compress_with_hierarchy(
        self,
        messages: List[Dict],
        compression_levels: int = 3
    ) -> Dict:
        """
        分层压缩
        
        Args:
            messages: 消息列表
            compression_levels: 压缩层级
        """
        results = {
            'raw': messages,
            'summaries': []
        }
        
        current_messages = messages
        
        for level in range(compression_levels):
            # 压缩
            summary = self.compress_conversation(
                current_messages,
                max_length=200 * (level + 1)
            )
            
            results['summaries'].append({
                'level': level + 1,
                'summary': summary,
                'original_count': len(current_messages)
            })
            
            # 为下一轮准备
            current_messages = [{
                'role': 'system',
                'content': f"对话摘要(级别{level+1}): {summary}"
            }]
        
        return results

4.2 关键信息提取

# key_information_extraction.py
from typing import List, Dict

class KeyInformationExtractor:
    """关键信息提取器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def extract_facts(self, conversation: List[Dict]) -> List[Dict]:
        """提取事实信息"""
        conversation_text = '\n'.join([
            f"{m['role']}: {m['content']}"
            for m in conversation
        ])
        
        prompt = f"""
从以下对话中提取事实信息:

{conversation_text}

请提取所有事实性陈述,按以下 JSON 格式输出:
[
    {{"fact": "事实内容", "confidence": 0.0-1.0, "source": "user/assistant"}}
]
"""
        
        response = self.llm.generate(prompt)
        facts = self._parse_json(response)
        
        return facts
    
    def extract_preferences(self, conversation: List[Dict]) -> Dict:
        """提取用户偏好"""
        conversation_text = '\n'.join([
            f"{m['content']}"
            for m in conversation if m['role'] == 'user'
        ])
        
        prompt = f"""
从以下用户输入中提取偏好信息:

{conversation_text}

请提取用户的偏好、喜好、要求等,按以下 JSON 格式输出:
{{
    "preferences": [
        {{"category": "类别", "preference": "偏好内容", "strength": "strong/medium/weak"}}
    ]
}}
"""
        
        response = self.llm.generate(prompt)
        preferences = self._parse_json(response)
        
        return preferences
    
    def _parse_json(self, text: str) -> any:
        """解析 JSON"""
        import json
        import re
        
        # 提取 JSON 部分
        match = re.search(r'\{.*\}|\[.*\]', text, re.DOTALL)
        if match:
            return json.loads(match.group())
        return None

五、记忆检索优化

5.1 混合检索

# hybrid_memory_retrieval.py
from typing import List, Dict

class HybridMemoryRetriever:
    """混合记忆检索器"""
    
    def __init__(
        self,
        vector_retriever,
        keyword_retriever,
        time_retriever,
        weights: Dict[str, float] = None
    ):
        """
        初始化
        
        Args:
            vector_retriever: 向量检索器
            keyword_retriever: 关键词检索器
            time_retriever: 时间检索器
            weights: 权重配置
        """
        self.vector_retriever = vector_retriever
        self.keyword_retriever = keyword_retriever
        self.time_retriever = time_retriever
        self.weights = weights or {
            'vector': 0.5,
            'keyword': 0.3,
            'time': 0.2
        }
    
    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        time_filter: Dict = None
    ) -> List[Dict]:
        """
        混合检索
        
        Args:
            query: 查询
            top_k: 返回数量
            time_filter: 时间过滤
        """
        # 各检索器结果
        vector_results = self.vector_retriever.search(query, top_k=top_k * 2)
        keyword_results = self.keyword_retriever.search(query, top_k=top_k * 2)
        time_results = self.time_retriever.search(time_filter, top_k=top_k * 2)
        
        # 分数归一化
        vector_results = self._normalize_scores(vector_results, 'vector')
        keyword_results = self._normalize_scores(keyword_results, 'keyword')
        time_results = self._normalize_scores(time_results, 'time')
        
        # 合并结果
        all_results = {}
        
        for results, source in [
            (vector_results, 'vector'),
            (keyword_results, 'keyword'),
            (time_results, 'time')
        ]:
            for r in results:
                memory_id = r['id']
                if memory_id not in all_results:
                    all_results[memory_id] = {
                        'id': memory_id,
                        'content': r['content'],
                        'scores': {}
                    }
                all_results[memory_id]['scores'][source] = r['score']
        
        # 计算加权分数
        for memory in all_results.values():
            weighted_score = 0
            for source, weight in self.weights.items():
                weighted_score += memory['scores'].get(source, 0) * weight
            memory['combined_score'] = weighted_score
        
        # 排序
        sorted_results = sorted(
            all_results.values(),
            key=lambda x: x['combined_score'],
            reverse=True
        )
        
        return sorted_results[:top_k]
    
    def _normalize_scores(
        self,
        results: List[Dict],
        source: str
    ) -> List[Dict]:
        """分数归一化"""
        if not results:
            return []
        
        scores = [r['score'] for r in results]
        min_score = min(scores)
        max_score = max(scores)
        
        if max_score == min_score:
            for r in results:
                r['normalized_score'] = 0.5
        else:
            for r in results:
                r['normalized_score'] = (
                    (r['score'] - min_score) / (max_score - min_score)
                )
        
        return results

5.2 记忆关联检索

# associative_memory_retrieval.py
from typing import List, Dict, Set

class AssociativeMemoryRetriever:
    """关联记忆检索"""
    
    def __init__(self, base_retriever, memory_graph):
        """
        初始化
        
        Args:
            base_retriever: 基础检索器
            memory_graph: 记忆图(节点为记忆,边为关联)
        """
        self.base_retriever = base_retriever
        self.memory_graph = memory_graph
    
    def retrieve_with_association(
        self,
        query: str,
        top_k: int = 5,
        association_depth: int = 2
    ) -> List[Dict]:
        """
        带关联的检索
        
        Args:
            query: 查询
            top_k: 返回数量
            association_depth: 关联深度
        """
        # 基础检索
        initial_results = self.base_retriever.search(query, top_k=top_k)
        
        # 关联扩展
        expanded_results = self._expand_with_associations(
            initial_results,
            association_depth
        )
        
        # 去重和排序
        unique_results = {r['id']: r for r in expanded_results}
        sorted_results = sorted(
            unique_results.values(),
            key=lambda x: x.get('relevance_score', 0),
            reverse=True
        )
        
        return sorted_results[:top_k]
    
    def _expand_with_associations(
        self,
        results: List[Dict],
        depth: int
    ) -> List[Dict]:
        """关联扩展"""
        expanded = list(results)
        visited = {r['id'] for r in results}
        
        for current_depth in range(depth):
            new_memories = []
            
            for result in results:
                memory_id = result['id']
                
                # 获取关联记忆
                associations = self.memory_graph.get_associations(
                    memory_id,
                    top_k=3
                )
                
                for assoc in associations:
                    if assoc['id'] not in visited:
                        assoc['relevance_score'] = result.get('score', 0) * 0.8
                        new_memories.append(assoc)
                        visited.add(assoc['id'])
            
            expanded.extend(new_memories)
            results = new_memories
        
        return expanded

六、实战案例

6.1 完整记忆系统

# complete_memory_system.py
class CompleteMemorySystem:
    """完整记忆系统"""
    
    def __init__(self, config: Dict):
        """
        初始化
        
        Args:
            config: 配置字典
        """
        # 短期记忆
        self.short_term = ShortTermMemory(
            max_messages=config.get('short_term_max', 20)
        )
        
        # 长期记忆
        self.long_term = LongTermMemory(
            embedding_model=config['embedding_model'],
            vector_store=config['vector_store']
        )
        
        # 压缩器
        self.compressor = MemoryCompressor(config['llm'])
        
        # 提取器
        self.extractor = KeyInformationExtractor(config['llm'])
        
        # 检索器
        self.retriever = HybridMemoryRetriever(
            vector_retriever=self.long_term,
            keyword_retriever=config['keyword_retriever'],
            time_retriever=self.long_term
        )
    
    def process_interaction(
        self,
        user_input: str,
        assistant_response: str
    ):
        """处理交互"""
        # 添加到短期记忆
        self.short_term.add('user', user_input)
        self.short_term.add('assistant', assistant_response)
        
        # 提取关键信息
        conversation = self.short_term.get_context()
        
        # 提取事实
        facts = self.extractor.extract_facts(conversation)
        for fact in facts:
            self.long_term.add(
                fact['fact'],
                metadata={'type': 'fact', 'confidence': fact['confidence']}
            )
        
        # 提取偏好
        preferences = self.extractor.extract_preferences(conversation)
        for pref in preferences.get('preferences', []):
            self.long_term.add(
                pref['preference'],
                metadata={'type': 'preference', 'category': pref['category']}
            )
        
        # 检查是否需要压缩
        if self.short_term.get_token_count() > 3000:
            self._compress_and_archive()
    
    def _compress_and_archive(self):
        """压缩并归档"""
        # 获取当前对话
        messages = self.short_term.get_context()
        
        # 压缩
        summary = self.compressor.compress_conversation(messages)
        
        # 归档到长期记忆
        self.long_term.add(
            summary,
            metadata={'type': 'conversation_summary'}
        )
        
        # 清空短期记忆(保留最近 5 条)
        recent = list(self.short_term.messages)[-5:]
        self.short_term.clear()
        for msg in recent:
            self.short_term.add(msg.role, msg.content)
    
    def build_context(
        self,
        current_query: str,
        max_context_tokens: int = 3000
    ) -> str:
        """构建上下文"""
        # 检索相关记忆
        relevant_memories = self.retriever.retrieve(
            current_query,
            top_k=5
        )
        
        # 获取短期记忆
        short_term_context = self.short_term.get_context()
        
        # 构建完整上下文
        context_parts = []
        
        # 相关记忆
        if relevant_memories:
            context_parts.append("【相关记忆】")
            for mem in relevant_memories:
                context_parts.append(f"- {mem['content']}")
        
        # 短期记忆
        if short_term_context:
            context_parts.append("\n【最近对话】")
            for msg in short_term_context:
                context_parts.append(f"{msg['role']}: {msg['content']}")
        
        full_context = '\n'.join(context_parts)
        
        # 检查 Token 数
        if len(full_context) // 4 > max_context_tokens:
            # 需要截断
            full_context = full_context[:max_context_tokens * 4]
        
        return full_context

七、总结

7.1 核心要点

  1. 记忆分层

    • 短期记忆:快速访问,容量有限
    • 长期记忆:持久存储,可检索
    • 工作记忆:任务临时状态
  2. 记忆压缩

    • 摘要压缩:减少存储
    • 关键信息提取:保留要点
    • 分层存储:多级粒度
  3. 检索优化

    • 混合检索:向量 + 关键词 + 时间
    • 关联检索:图结构扩展
    • 排序融合:加权评分

7.2 最佳实践

  1. 及时压缩

    • 对话达到阈值时压缩
    • 保留重要信息
    • 归档历史摘要
  2. 智能检索

    • 根据查询选择检索策略
    • 平衡相关性和时效性
    • 去重和排序
  3. 持续优化

    • 监控记忆质量
    • 调整压缩策略
    • 优化检索参数

参考资料


分享这篇文章到:

上一篇文章
RocketMQ 容量规划与性能优化实战
下一篇文章
RocketMQ Controller 控制器详解与高可用