Skip to content
清晨的一缕阳光
返回

Agent 反思与自修正机制

Agent 反思与自修正机制

自我反思和修正是高级 Agent 的核心能力。如何让 Agent 发现并纠正自己的错误?如何实现持续改进?本文深入解析 Agent 反思与自修正机制。

一、反思能力基础

1.1 反思层次

Agent 反思层次:

┌─────────────────────────────────────┐
│ 结果反思(Outcome Reflection)       │
│ - 评估最终结果质量                   │
│ - 对比预期与实际                     │
│ - 总结成功与失败                     │
├─────────────────────────────────────┤
│ 过程反思(Process Reflection)       │
│ - 回顾执行过程                       │
│ - 识别关键决策点                     │
│ - 分析决策合理性                     │
├─────────────────────────────────────┤
│ 元认知反思(Metacognitive)          │
│ - 反思思考方式                       │
│ - 评估认知局限                       │
│ - 改进思维模式                       │
└─────────────────────────────────────┘

1.2 反思触发条件

触发条件说明示例
任务完成常规反思每次任务后总结
错误检测被动反思发现错误后分析
低置信度预防反思结果不确定时验证
用户反馈外部触发根据反馈调整
周期性定时反思定期回顾总结

二、错误检测机制

2.1 一致性检查

# consistency_checker.py
from typing import Dict, List, Optional

class ConsistencyChecker:
    """一致性检查器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def check_internal_consistency(
        self,
        response: str
    ) -> Dict:
        """
        检查内部一致性
        
        Args:
            response: Agent 响应
        
        Returns:
            检查结果
        """
        # 提取关键陈述
        statements = self._extract_statements(response)
        
        # 检查陈述间是否矛盾
        contradictions = []
        
        for i, stmt1 in enumerate(statements):
            for stmt2 in statements[i+1:]:
                if self._is_contradictory(stmt1, stmt2):
                    contradictions.append({
                        'statement1': stmt1,
                        'statement2': stmt2,
                        'conflict_type': self._classify_conflict(stmt1, stmt2)
                    })
        
        return {
            'consistent': len(contradictions) == 0,
            'contradictions': contradictions,
            'statement_count': len(statements)
        }
    
    def check_external_consistency(
        self,
        response: str,
        context: Dict
    ) -> Dict:
        """
        检查外部一致性(与已知事实)
        
        Args:
            response: Agent 响应
            context: 上下文信息
        
        Returns:
            检查结果
        """
        statements = self._extract_statements(response)
        inconsistencies = []
        
        for stmt in statements:
            # 检查与上下文的冲突
            if 'context_facts' in context:
                for fact in context['context_facts']:
                    if self._contradicts_fact(stmt, fact):
                        inconsistencies.append({
                            'statement': stmt,
                            'conflicting_fact': fact,
                            'severity': 'high'
                        })
        
        return {
            'consistent': len(inconsistencies) == 0,
            'inconsistencies': inconsistencies
        }
    
    def _extract_statements(self, text: str) -> List[str]:
        """提取陈述"""
        import re
        
        # 简单实现:按句子分割
        sentences = re.split(r'[.!?。!?]+', text)
        return [s.strip() for s in sentences if len(s.strip()) > 10]
    
    def _is_contradictory(
        self,
        stmt1: str,
        stmt2: str
    ) -> bool:
        """检查是否矛盾"""
        # 使用 LLM 判断
        prompt = f"""
判断以下两个陈述是否矛盾:

陈述 1: {stmt1}
陈述 2: {stmt2}

如果矛盾,请回答"是";否则回答"否"。
"""
        response = self.llm.generate(prompt)
        return '' in response
    
    def _classify_conflict(
        self,
        stmt1: str,
        stmt2: str
    ) -> str:
        """分类冲突类型"""
        # 数值冲突、时间冲突、逻辑冲突等
        return 'logical'
    
    def _contradicts_fact(self, statement: str, fact: str) -> bool:
        """检查是否与事实矛盾"""
        prompt = f"""
判断以下陈述是否与已知事实矛盾:

陈述:{statement}
事实:{fact}

如果矛盾,请回答"是";否则回答"否"。
"""
        response = self.llm.generate(prompt)
        return '' in response

2.2 质量评估器

# quality_evaluator.py
from typing import Dict, List

class QualityEvaluator:
    """质量评估器"""
    
    def __init__(self, llm):
        self.llm = llm
        self.criteria = {
            'accuracy': '准确性',
            'completeness': '完整性',
            'consistency': '一致性',
            'relevance': '相关性',
            'clarity': '清晰性'
        }
    
    def evaluate(
        self,
        response: str,
        query: str,
        context: Dict = None
    ) -> Dict:
        """
        评估响应质量
        
        Args:
            response: Agent 响应
            query: 原始查询
            context: 上下文
        
        Returns:
            评估结果
        """
        scores = {}
        feedback = {}
        
        for criterion, name in self.criteria.items():
            score, comment = self._evaluate_criterion(
                criterion,
                response,
                query,
                context
            )
            scores[criterion] = score
            feedback[criterion] = comment
        
        # 综合评分
        overall_score = sum(scores.values()) / len(scores)
        
        return {
            'overall_score': overall_score,
            'criterion_scores': scores,
            'feedback': feedback,
            'needs_revision': overall_score < 0.7
        }
    
    def _evaluate_criterion(
        self,
        criterion: str,
        response: str,
        query: str,
        context: Dict
    ) -> tuple:
        """评估单个标准"""
        prompt = f"""
请评估以下响应的{self.criteria[criterion]}(0-1 分):

查询:{query}
响应:{response}

请给出分数(0-1 之间的小数)和简短评语。
格式:分数|评语
"""
        response_text = self.llm.generate(prompt)
        
        try:
            score_str, comment = response_text.split('|', 1)
            score = float(score_str.strip())
            score = max(0, min(1, score))
        except:
            score = 0.5
            comment = response_text
        
        return score, comment

三、自修正机制

3.1 迭代修正

# iterative_refinement.py
from typing import Dict, List, Optional

class IterativeRefiner:
    """迭代修正器"""
    
    def __init__(
        self,
        llm,
        max_iterations: int = 3,
        min_improvement: float = 0.1
    ):
        """
        初始化
        
        Args:
            llm: LLM 模型
            max_iterations: 最大迭代次数
            min_improvement: 最小改进幅度
        """
        self.llm = llm
        self.max_iterations = max_iterations
        self.min_improvement = min_improvement
        self.evaluator = QualityEvaluator(llm)
    
    def refine(
        self,
        initial_response: str,
        query: str,
        context: Dict = None
    ) -> Dict:
        """
        迭代修正
        
        Args:
            initial_response: 初始响应
            query: 查询
            context: 上下文
        
        Returns:
            修正结果
        """
        current_response = initial_response
        history = []
        
        for iteration in range(self.max_iterations):
            # 评估当前响应
            evaluation = self.evaluator.evaluate(
                current_response,
                query,
                context
            )
            
            history.append({
                'iteration': iteration,
                'response': current_response,
                'evaluation': evaluation
            })
            
            # 检查是否需要继续修正
            if not evaluation['needs_revision']:
                break
            
            if iteration > 0:
                prev_score = history[-2]['evaluation']['overall_score']
                improvement = evaluation['overall_score'] - prev_score
                
                if improvement < self.min_improvement:
                    # 改进不足,停止
                    break
            
            # 生成修正版本
            current_response = self._generate_revision(
                current_response,
                query,
                evaluation['feedback'],
                context
            )
        
        # 选择最佳版本
        best_version = max(
            history,
            key=lambda x: x['evaluation']['overall_score']
        )
        
        return {
            'final_response': best_version['response'],
            'iterations': len(history),
            'final_score': best_version['evaluation']['overall_score'],
            'history': history
        }
    
    def _generate_revision(
        self,
        current_response: str,
        query: str,
        feedback: Dict,
        context: Dict
    ) -> str:
        """生成修正版本"""
        feedback_text = '\n'.join([
            f"- {k}: {v}"
            for k, v in feedback.items()
        ])
        
        prompt = f"""
请根据以下反馈改进响应:

查询:{query}
当前响应:
{current_response}

改进建议:
{feedback_text}

请综合以上信息,给出改进后的完整响应。
"""
        return self.llm.generate(prompt)

3.2 自我批评

# self_critique.py
from typing import Dict, List

class SelfCritic:
    """自我批评器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def critique(
        self,
        response: str,
        query: str,
        criteria: List[str] = None
    ) -> Dict:
        """
        自我批评
        
        Args:
            response: 响应
            query: 查询
            criteria: 批评标准
        
        Returns:
            批评结果
        """
        criteria = criteria or [
            '逻辑错误',
            '事实错误',
            '遗漏信息',
            '表达不清',
            '改进建议'
        ]
        
        critique_result = {}
        
        for criterion in criteria:
            critique = self._critique_criterion(
                criterion,
                response,
                query
            )
            critique_result[criterion] = critique
        
        # 生成改进计划
        improvement_plan = self._generate_improvement_plan(
            response,
            critique_result
        )
        
        return {
            'critique': critique_result,
            'improvement_plan': improvement_plan
        }
    
    def _critique_criterion(
        self,
        criterion: str,
        response: str,
        query: str
    ) -> Dict:
        """批评单个标准"""
        prompt = f"""
请从"{criterion}"角度批评以下响应:

查询:{query}
响应:{response}

请指出存在的问题,按以下 JSON 格式输出:
{{
    "issues": [
        {{"description": "问题描述", "severity": "high/medium/low", "suggestion": "改进建议"}}
    ],
    "score": 0-10
}}
"""
        response_text = self.llm.generate(prompt)
        return self._parse_json(response_text)
    
    def _generate_improvement_plan(
        self,
        response: str,
        critique_result: Dict
    ) -> List[Dict]:
        """生成改进计划"""
        all_issues = []
        
        for criterion_critique in critique_result.values():
            if isinstance(criterion_critique, dict):
                all_issues.extend(criterion_critique.get('issues', []))
        
        # 按严重程度排序
        severity_order = {'high': 0, 'medium': 1, 'low': 2}
        all_issues.sort(
            key=lambda x: severity_order.get(x.get('severity', 'low'), 2)
        )
        
        return all_issues
    
    def _parse_json(self, text: str) -> Dict:
        """解析 JSON"""
        import json
        import re
        
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match:
            return json.loads(match.group())
        return {}

四、反思学习

4.1 经验积累

# experience_learning.py
from typing import Dict, List
from datetime import datetime

class ExperienceLearner:
    """经验学习器"""
    
    def __init__(self):
        self.experiences: List[Dict] = []
        self.patterns: Dict = {}
    
    def record_experience(
        self,
        task: str,
        approach: str,
        outcome: str,
        success: bool,
        lessons: List[str]
    ):
        """
        记录经验
        
        Args:
            task: 任务
            approach: 方法
            outcome: 结果
            success: 是否成功
            lessons: 经验教训
        """
        experience = {
            'task': task,
            'approach': approach,
            'outcome': outcome,
            'success': success,
            'lessons': lessons,
            'timestamp': datetime.now().isoformat()
        }
        
        self.experiences.append(experience)
        
        # 更新模式
        self._update_patterns(experience)
    
    def _update_patterns(self, experience: Dict):
        """更新模式"""
        task_type = self._classify_task(experience['task'])
        
        if task_type not in self.patterns:
            self.patterns[task_type] = {
                'successful_approaches': [],
                'failed_approaches': [],
                'common_lessons': []
            }
        
        if experience['success']:
            self.patterns[task_type]['successful_approaches'].append(
                experience['approach']
            )
        else:
            self.patterns[task_type]['failed_approaches'].append(
                experience['approach']
            )
        
        self.patterns[task_type]['common_lessons'].extend(
            experience['lessons']
        )
    
    def _classify_task(self, task: str) -> str:
        """分类任务"""
        task_keywords = {
            'analysis': ['分析', '评估', '诊断'],
            'creation': ['创建', '编写', '设计'],
            'modification': ['修改', '优化', '改进'],
            'research': ['研究', '调研', '查找']
        }
        
        for task_type, keywords in task_keywords.items():
            if any(kw in task for kw in keywords):
                return task_type
        
        return 'general'
    
    def get_recommendations(self, task: str) -> Dict:
        """获取推荐方法"""
        task_type = self._classify_task(task)
        
        if task_type not in self.patterns:
            return {'recommended_approach': '无历史经验'}
        
        patterns = self.patterns[task_type]
        
        # 推荐成功的方法
        if patterns['successful_approaches']:
            recommended = patterns['successful_approaches'][0]
        else:
            recommended = '无成功经验'
        
        # 避免失败的方法
        avoid = patterns['failed_approaches'][:3]
        
        # 经验教训
        lessons = patterns['common_lessons'][:5]
        
        return {
            'recommended_approach': recommended,
            'approaches_to_avoid': avoid,
            'lessons_learned': lessons
        }

4.2 模式识别

# pattern_recognition.py
from typing import Dict, List
from collections import Counter

class PatternRecognizer:
    """模式识别器"""
    
    def __init__(self):
        self.error_patterns = Counter()
        self.success_patterns = Counter()
    
    def analyze_failures(
        self,
        failures: List[Dict]
    ) -> Dict:
        """
        分析失败模式
        
        Args:
            failures: 失败记录列表
        
        Returns:
            分析结果
        """
        # 统计错误类型
        for failure in failures:
            error_type = failure.get('error_type', 'unknown')
            self.error_patterns[error_type] += 1
        
        # 找出常见模式
        common_patterns = self.error_patterns.most_common(5)
        
        # 生成改进建议
        recommendations = []
        
        for error_type, count in common_patterns:
            recommendation = self._generate_recommendation(
                error_type,
                count,
                len(failures)
            )
            recommendations.append(recommendation)
        
        return {
            'common_error_types': common_patterns,
            'recommendations': recommendations,
            'total_failures': len(failures)
        }
    
    def analyze_successes(
        self,
        successes: List[Dict]
    ) -> Dict:
        """分析成功模式"""
        for success in successes:
            approach = success.get('approach', 'unknown')
            self.success_patterns[approach] += 1
        
        common_patterns = self.success_patterns.most_common(5)
        
        return {
            'common_success_approaches': common_patterns,
            'total_successes': len(successes)
        }
    
    def _generate_recommendation(
        self,
        error_type: str,
        count: int,
        total: int
    ) -> Dict:
        """生成改进建议"""
        frequency = count / total
        
        if frequency > 0.3:
            priority = 'high'
        elif frequency > 0.1:
            priority = 'medium'
        else:
            priority = 'low'
        
        suggestion = self._get_suggestion_for_error(error_type)
        
        return {
            'error_type': error_type,
            'frequency': frequency,
            'priority': priority,
            'suggestion': suggestion
        }
    
    def _get_suggestion_for_error(self, error_type: str) -> str:
        """获取错误类型的建议"""
        suggestions = {
            'factual_error': '加强事实核查,使用可靠信息源',
            'logical_error': '加强逻辑推理,逐步验证',
            'completeness_error': '使用检查清单,确保覆盖全面',
            'clarity_error': '简化表达,使用结构化格式',
            'context_error': '更好地理解上下文和用户需求'
        }
        
        return suggestions.get(error_type, '需要进一步分析')

五、实战案例

5.1 代码生成自修正

# code_self_correction.py
class CodeSelfCorrector:
    """代码自修正器"""
    
    def __init__(self, llm):
        self.llm = llm
        self.consistency_checker = ConsistencyChecker(llm)
        self.iterative_refiner = IterativeRefiner(llm)
    
    def generate_and_correct(
        self,
        requirement: str,
        language: str = 'Python'
    ) -> Dict:
        """
        生成并自修正代码
        
        Args:
            requirement: 需求描述
            language: 编程语言
        """
        # 1. 初始生成
        initial_code = self._generate_code(requirement, language)
        
        # 2. 语法检查
        syntax_check = self._check_syntax(initial_code, language)
        
        if not syntax_check['valid']:
            # 语法错误,修正
            initial_code = self._fix_syntax(
                initial_code,
                syntax_check['errors'],
                language
            )
        
        # 3. 逻辑检查
        logic_check = self._check_logic(initial_code, requirement)
        
        # 4. 迭代修正
        refined_result = self.iterative_refiner.refine(
            initial_code,
            f"生成{language}代码:{requirement}",
            {'logic_check': logic_check}
        )
        
        # 5. 最终验证
        final_verification = self._verify_code(
            refined_result['final_response'],
            requirement
        )
        
        return {
            'code': refined_result['final_response'],
            'iterations': refined_result['iterations'],
            'final_score': refined_result['final_score'],
            'verification': final_verification
        }
    
    def _generate_code(self, requirement: str, language: str) -> str:
        """生成代码"""
        prompt = f"""
请编写{language}代码实现以下需求:

需求:{requirement}

要求:
1. 代码简洁清晰
2. 包含必要的注释
3. 考虑边界情况
4. 包含错误处理
"""
        return self.llm.generate(prompt)
    
    def _check_syntax(self, code: str, language: str) -> Dict:
        """检查语法"""
        if language == 'Python':
            try:
                compile(code, '<string>', 'exec')
                return {'valid': True, 'errors': []}
            except SyntaxError as e:
                return {
                    'valid': False,
                    'errors': [str(e)]
                }
        return {'valid': True, 'errors': []}
    
    def _fix_syntax(
        self,
        code: str,
        errors: List[str],
        language: str
    ) -> str:
        """修复语法错误"""
        prompt = f"""
请修复以下代码的语法错误:

代码:
{code}

错误:
{chr(10).join(errors)}

请只输出修复后的完整代码,不要解释。
"""
        return self.llm.generate(prompt)
    
    def _check_logic(self, code: str, requirement: str) -> Dict:
        """检查逻辑"""
        prompt = f"""
请检查以下代码是否满足需求:

需求:{requirement}

代码:
{code}

请分析:
1. 代码是否实现了需求?
2. 是否有逻辑错误?
3. 是否有遗漏的边界情况?

按 JSON 格式输出:
{{
    "implements_requirement": true/false,
    "logic_errors": ["错误列表"],
    "missing_cases": ["遗漏情况"]
}}
"""
        return self._parse_json(self.llm.generate(prompt))
    
    def _verify_code(self, code: str, requirement: str) -> Dict:
        """验证代码"""
        # 生成测试用例
        test_cases = self._generate_test_cases(requirement)
        
        # 运行测试(简化)
        passed = 0
        failed = 0
        
        return {
            'test_cases': len(test_cases),
            'passed': passed,
            'failed': failed
        }
    
    def _generate_test_cases(self, requirement: str) -> List[Dict]:
        """生成测试用例"""
        prompt = f"""
为以下需求生成测试用例:

需求:{requirement}

请按 JSON 格式输出测试用例列表:
[
    {{"input": "输入", "expected": "期望输出"}}
]
"""
        return self._parse_json(self.llm.generate(prompt))
    
    def _parse_json(self, text: str) -> Dict:
        """解析 JSON"""
        import json
        import re
        
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match:
            return json.loads(match.group())
        return {}

5.2 写作自修正

# writing_self_correction.py
class WritingSelfCorrector:
    """写作自修正器"""
    
    def __init__(self, llm):
        self.llm = llm
        self.quality_evaluator = QualityEvaluator(llm)
        self.self_critic = SelfCritic(llm)
    
    def write_and_correct(
        self,
        topic: str,
        style: str = 'professional',
        length: int = 1000
    ) -> Dict:
        """
        写作并自修正
        
        Args:
            topic: 主题
            style: 风格
            length: 目标长度
        """
        # 1. 初始写作
        initial_draft = self._write_draft(topic, style, length)
        
        # 2. 自我批评
        critique = self.self_critic.critique(
            initial_draft,
            f"写一篇关于{topic}的文章"
        )
        
        # 3. 质量评估
        evaluation = self.quality_evaluator.evaluate(
            initial_draft,
            f"写一篇关于{topic}的文章"
        )
        
        # 4. 修正
        revised = self._revise_writing(
            initial_draft,
            critique,
            evaluation
        )
        
        # 5. 最终评估
        final_evaluation = self.quality_evaluator.evaluate(
            revised,
            f"写一篇关于{topic}的文章"
        )
        
        return {
            'initial_draft': initial_draft,
            'revised': revised,
            'initial_score': evaluation['overall_score'],
            'final_score': final_evaluation['overall_score'],
            'improvement': final_evaluation['overall_score'] - evaluation['overall_score']
        }
    
    def _write_draft(
        self,
        topic: str,
        style: str,
        length: int
    ) -> str:
        """写初稿"""
        prompt = f"""
请写一篇关于"{topic}"的文章。

风格:{style}
目标长度:约{length}

要求:
1. 结构清晰(引言、正文、结论)
2. 论点明确,论据充分
3. 语言流畅,表达准确
"""
        return self.llm.generate(prompt)
    
    def _revise_writing(
        self,
        draft: str,
        critique: Dict,
        evaluation: Dict
    ) -> str:
        """修正文章"""
        feedback_summary = []
        
        for criterion, comment in evaluation['feedback'].items():
            if isinstance(comment, str):
                feedback_summary.append(f"- {criterion}: {comment}")
        
        prompt = f"""
请根据以下批评和评估改进文章:

原文:
{draft}

批评意见:
{critique}

评估反馈:
{chr(10).join(feedback_summary)}

请综合以上信息,给出改进后的完整文章。
"""
        return self.llm.generate(prompt)

六、总结

6.1 核心要点

  1. 错误检测

    • 一致性检查
    • 质量评估
    • 事实核查
  2. 自修正机制

    • 迭代修正
    • 自我批评
    • 反馈驱动
  3. 反思学习

    • 经验积累
    • 模式识别
    • 持续改进

6.2 最佳实践

  1. 多层次反思

    • 结果反思
    • 过程反思
    • 元认知反思
  2. 及时修正

    • 发现问题立即修正
    • 设置修正次数上限
    • 避免过度修正
  3. 持续学习

    • 记录经验教训
    • 识别成功模式
    • 避免重复错误

参考资料


分享这篇文章到:

上一篇文章
Redis 消息队列实现方案
下一篇文章
RocketMQ POP 消费模式详解与实战