Skip to content
清晨的一缕阳光
返回

AI 应用成本控制实战

AI 应用成本控制实战

AI 应用的成本控制是规模化落地的关键挑战:如何降低 LLM 调用成本?如何优化 Token 使用?本文从成本构成、Token 优化、缓存策略、模型选择、资源调度和成本监控六个方面,详解 AI 应用成本控制的实战方法。

一、成本构成分析

1.1 成本结构

AI 应用成本结构:

┌─────────────────────────────────────┐
│ 1. LLM API 成本(60-80%)            │
│    - 输入 Token                      │
│    - 输出 Token                      │
│    - 模型溢价                        │
├─────────────────────────────────────┤
│ 2. 基础设施成本(10-20%)            │
│    - 计算资源                        │
│    - 存储资源                        │
│    - 网络资源                        │
├─────────────────────────────────────┤
│ 3. 运维成本(5-10%)                 │
│    - 监控告警                        │
│    - 日志存储                        │
│    - 人力成本                        │
├─────────────────────────────────────┤
│ 4. 其他成本(5-10%)                 │
│    - 向量数据库                      │
│    - 第三方服务                      │
│    - 数据成本                        │
└─────────────────────────────────────┘

1.2 成本计算

# cost_calculator.py
from typing import Dict, List

class AICostCalculator:
    """Estimate the USD cost of LLM calls and of complete queries.

    Prices are stored per 1,000 tokens, split into input (prompt) and output
    (completion) rates.
    """

    def __init__(self):
        # Model pricing in USD per 1K tokens.
        self.model_pricing = {
            'gpt-4': {'input': 0.03, 'output': 0.06},
            'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
            'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
            # Claude 3 Opus list price ($15 in / $75 out per 1M tokens);
            # output tokens are priced above input tokens.
            'claude-3': {'input': 0.015, 'output': 0.075},
            'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}
        }

    def calculate_llm_cost(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Return the USD cost of one LLM call.

        Args:
            model: Key into ``self.model_pricing``.
            input_tokens: Number of prompt tokens.
            output_tokens: Number of completion tokens.

        Returns:
            ``input_tokens/1000 * input_rate + output_tokens/1000 * output_rate``.
            NOTE: an unknown model yields 0.0 rather than an error, so
            unpriced models silently under-report spend.
        """
        pricing = self.model_pricing.get(model)
        if pricing is None:
            return 0.0

        input_cost = (input_tokens / 1000) * pricing['input']
        output_cost = (output_tokens / 1000) * pricing['output']
        return input_cost + output_cost

    def calculate_query_cost(
        self,
        query_details: Dict
    ) -> Dict:
        """Return a cost breakdown for one end-to-end query.

        Args:
            query_details: Must contain 'model', 'input_tokens' and
                'output_tokens'; may contain 'embedding_cost' and
                'vector_search_cost' (default 0).

        Returns:
            Dict with 'llm_cost', 'embedding_cost', 'vector_search_cost',
            'total_cost' and a 'breakdown' echoing model and token counts.
        """
        llm_cost = self.calculate_llm_cost(
            query_details['model'],
            query_details['input_tokens'],
            query_details['output_tokens']
        )

        # Non-LLM costs are supplied by the caller, already in USD.
        embedding_cost = query_details.get('embedding_cost', 0)
        vector_search_cost = query_details.get('vector_search_cost', 0)

        total_cost = llm_cost + embedding_cost + vector_search_cost

        return {
            'llm_cost': llm_cost,
            'embedding_cost': embedding_cost,
            'vector_search_cost': vector_search_cost,
            'total_cost': total_cost,
            'breakdown': {
                'input_tokens': query_details['input_tokens'],
                'output_tokens': query_details['output_tokens'],
                'model': query_details['model']
            }
        }

    def project_monthly_cost(
        self,
        avg_queries_per_day: int,
        avg_cost_per_query: float,
        days: int = 30
    ) -> float:
        """Project spend over a billing period.

        Args:
            avg_queries_per_day: Expected daily query volume.
            avg_cost_per_query: Average USD cost of one query.
            days: Length of the period in days (default 30).

        Returns:
            ``avg_queries_per_day * avg_cost_per_query * days``.
        """
        return avg_queries_per_day * avg_cost_per_query * days

二、Token 优化

2.1 Prompt 压缩

# prompt_compression.py
from typing import Dict, List

class PromptCompressor:
    """Shrink prompt context so it fits a token budget.

    Tokens are approximated by whitespace-delimited words. The same unit is
    used both for counting (``_count_tokens``) and for cutting
    (``_truncate_to_tokens``), so the "too long?" check and the final cut agree.
    """

    def __init__(self, llm):
        # llm: any object exposing generate(prompt) -> str.
        self.llm = llm

    def compress_context(
        self,
        context: str,
        max_tokens: int,
        query: str
    ) -> str:
        """Return ``context`` reduced to at most ``max_tokens`` word-tokens.

        Steps: drop duplicate lines, ask the LLM for a query-focused summary
        if still over budget, then hard-truncate as a final guarantee.

        Args:
            context: Raw context text.
            max_tokens: Budget in word-tokens.
            query: User question, used to focus the summary.
        """
        # 1. Remove duplicate lines.
        context = self._remove_redundancy(context)

        # 2. Still too long: summarize with the LLM.
        if self._count_tokens(context) > max_tokens:
            context = self._summarize_context(context, query, max_tokens)

        # 3. The summary may overshoot, so always enforce the hard limit.
        return self._truncate_to_tokens(context, max_tokens)

    def _remove_redundancy(self, context: str) -> str:
        """Drop lines whose stripped text already appeared earlier.

        The first occurrence (with its original whitespace) is kept. Because
        blank lines all strip to "", only the first blank line survives.
        """
        seen = set()
        unique_lines = []

        for line in context.split('\n'):
            # Key on the text itself rather than hash(): no collision risk.
            key = line.strip()
            if key not in seen:
                seen.add(key)
                unique_lines.append(line)

        return '\n'.join(unique_lines)

    def _summarize_context(
        self,
        context: str,
        query: str,
        max_tokens: int
    ) -> str:
        """Ask the LLM for a summary of ``context`` focused on ``query``."""
        prompt = f"""
请摘要以下内容,保留与问题最相关的信息:

问题:{query}

内容:
{context}

摘要(不超过{max_tokens} Token):
"""
        summary = self.llm.generate(prompt)
        return summary

    def _truncate_to_tokens(
        self,
        text: str,
        max_tokens: int
    ) -> str:
        """Cut ``text`` to ``max_tokens`` words, keeping both head and tail.

        Text already within budget is returned unchanged. Otherwise the first
        half and last half of the budget are joined with "...". A budget of
        0 or less yields "".
        """
        tokens = text.split()

        if len(tokens) <= max_tokens:
            return text
        if max_tokens <= 0:
            # tokens[-0:] would be the whole list, so handle this explicitly.
            return ''

        head_size = max_tokens // 2
        tail_size = max_tokens - head_size  # >= 1 here

        head = ' '.join(tokens[:head_size])
        tail = ' '.join(tokens[-tail_size:])

        return f"{head}...{tail}"

    def _count_tokens(self, text: str) -> int:
        """Estimate tokens as the number of whitespace-delimited words.

        This is a rough proxy (it undercounts text without spaces, such as
        CJK). It matches the unit used by ``_truncate_to_tokens``.
        """
        return len(text.split())

2.2 上下文优化

# context_optimization.py
from typing import Dict, List

class ContextOptimizer:
    """Select and trim retrieved contexts so they fit a prompt token budget.

    Contexts are ranked by cosine similarity between their embedding and the
    query embedding. The top ``max_contexts`` are kept, and each is trimmed
    sentence by sentence until the shared ``max_tokens`` budget is spent.
    Tokens are approximated by whitespace-delimited words.
    """

    def __init__(self, embedding_model):
        # Expected to expose encode(list_of_texts) -> sequence of vectors.
        self.embedding_model = embedding_model

    def optimize_context(
        self,
        query: str,
        contexts: List[Dict],
        max_contexts: int = 5,
        max_tokens: int = 2000
    ) -> List[Dict]:
        """Return the most relevant contexts, trimmed to fit ``max_tokens``.

        Args:
            query: User question.
            contexts: Dicts with at least a 'content' key.
            max_contexts: Maximum number of contexts to keep.
            max_tokens: Total word-token budget shared by all kept contexts.

        Returns:
            Copies of the selected context dicts, most relevant first. Each
            has a 'score' key, and trimmed ones also have 'compressed': True.
            A context is dropped once the budget is exhausted or if trimming
            leaves it empty.
        """
        # 1. Score relevance.
        scored_contexts = self._score_relevance(query, contexts)

        # 2. Keep the top-K by score.
        scored_contexts.sort(key=lambda x: x['score'], reverse=True)
        selected = scored_contexts[:max_contexts]

        # 3. Trim each context against the remaining shared budget.
        compressed = []
        remaining_tokens = max_tokens

        for ctx in selected:
            if remaining_tokens <= 0:
                break  # budget spent; anything more would overflow it
            compressed_ctx = self._compress_context(ctx, remaining_tokens)
            used_tokens = self._count_tokens(compressed_ctx['content'])
            if used_tokens == 0:
                continue  # nothing fit (e.g. every sentence exceeds the budget)
            compressed.append(compressed_ctx)
            remaining_tokens -= used_tokens

        return compressed

    def _score_relevance(
        self,
        query: str,
        contexts: List[Dict]
    ) -> List[Dict]:
        """Return copies of ``contexts`` with a cosine-similarity 'score'."""
        if not contexts:
            return []

        query_embedding = self.embedding_model.encode([query])[0]
        # One batched encode call instead of one call per context.
        ctx_embeddings = self.embedding_model.encode(
            [ctx['content'] for ctx in contexts]
        )

        return [
            {**ctx, 'score': self._cosine_similarity(query_embedding, emb)}
            for ctx, emb in zip(contexts, ctx_embeddings)
        ]

    def _compress_context(
        self,
        context: Dict,
        max_tokens: int
    ) -> Dict:
        """Greedily keep '.'-delimited sentences, in order, within ``max_tokens``.

        A context already within budget is returned unchanged. Otherwise a new
        dict is returned with the trimmed 'content' and 'compressed': True.
        """
        content = context['content']

        if self._count_tokens(content) <= max_tokens:
            return context

        selected_sentences = []
        used_tokens = 0

        for sentence in content.split('.'):
            # Strip so that re-joining with '. ' gives no double spaces, and
            # skip the empty piece after a trailing '.'.
            sentence = sentence.strip()
            if not sentence:
                continue
            sentence_tokens = self._count_tokens(sentence)
            if used_tokens + sentence_tokens <= max_tokens:
                selected_sentences.append(sentence)
                used_tokens += sentence_tokens

        return {
            **context,
            'content': '. '.join(selected_sentences),
            'compressed': True
        }

    def _cosine_similarity(
        self,
        vec1: List[float],
        vec2: List[float]
    ) -> float:
        """Return cosine similarity, or 0.0 if either vector has zero norm."""
        import numpy as np

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0  # avoid division by zero / NaN scores

        return float(np.dot(vec1, vec2) / (norm1 * norm2))

    def _count_tokens(self, text: str) -> int:
        """Estimate tokens as the number of whitespace-delimited words."""
        return len(text.split())

三、缓存策略

3.1 响应缓存

# response_cache.py
from typing import Dict, List, Optional
from datetime import datetime, timedelta
import hashlib

class ResponseCache:
    """In-memory exact-match response cache with a fixed time-to-live.

    Entries are keyed by a digest of (query, context, model). Hits and misses
    are counted so the hit rate can be monitored.
    """

    def __init__(self, ttl_hours: int = 24):
        self.cache: Dict[str, Dict] = {}
        self.ttl = timedelta(hours=ttl_hours)
        self.stats = {
            'hits': 0,
            'misses': 0
        }

    def generate_cache_key(
        self,
        query: str,
        context: Dict,
        model: str
    ) -> str:
        """Return a stable hex digest identifying (query, context, model).

        The triple is serialized as a JSON array with ``sort_keys=True``.
        Nested dicts therefore hash the same regardless of insertion order,
        and field boundaries cannot be confused: with a plain ':'-joined
        string, a query or model containing ':' could collide with another
        triple. Values JSON cannot encode fall back to ``str()``.
        """
        import json  # local import: this snippet's header does not import json

        key_data = json.dumps(
            [query, context, model],
            sort_keys=True,
            ensure_ascii=False,
            default=str
        )
        # MD5 is used only as a cache-key digest, not for security.
        return hashlib.md5(key_data.encode('utf-8')).hexdigest()

    def get(
        self,
        cache_key: str
    ) -> Optional[Dict]:
        """Return the cached response, or None if absent or expired.

        Expired entries are deleted on access. Every call records a hit or
        a miss.
        """
        entry = self.cache.get(cache_key)
        if entry is not None:
            if datetime.now() < entry['expires_at']:
                self.stats['hits'] += 1
                return entry['response']
            del self.cache[cache_key]

        self.stats['misses'] += 1
        return None

    def set(
        self,
        cache_key: str,
        response: Dict
    ):
        """Store ``response`` under ``cache_key``; it expires after ``self.ttl``."""
        now = datetime.now()  # one timestamp so created/expires are consistent
        self.cache[cache_key] = {
            'response': response,
            'created_at': now,
            'expires_at': now + self.ttl
        }

    def get_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any lookup."""
        total = self.stats['hits'] + self.stats['misses']
        if total == 0:
            return 0.0
        return self.stats['hits'] / total

    def cleanup_expired(self):
        """Delete every entry whose expiry time has passed."""
        now = datetime.now()
        expired_keys = [
            k for k, v in self.cache.items()
            if now >= v['expires_at']
        ]

        for key in expired_keys:
            del self.cache[key]

3.2 语义缓存

# semantic_cache.py
from typing import Dict, List, Optional
import numpy as np

class SemanticCache:
    """Cache LLM responses keyed by query meaning rather than exact text.

    A lookup embeds the query and returns the cached response whose stored
    query embedding is most similar, provided the cosine similarity reaches
    ``similarity_threshold``. Lookups are a linear scan over all entries.
    """

    def __init__(
        self,
        embedding_model,
        similarity_threshold: float = 0.95
    ):
        # embedding_model is expected to expose encode(list_of_texts) -> vectors.
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.cache: List[Dict] = []

    def get(
        self,
        query: str
    ) -> Optional[Dict]:
        """Return the response of the most similar cached query, or None.

        Only entries with similarity >= ``similarity_threshold`` qualify. On
        ties, the oldest entry wins.
        """
        if not self.cache:
            return None  # nothing to compare against; skip the embedding call

        query_embedding = self.embedding_model.encode([query])[0]

        best_response: Optional[Dict] = None
        best_similarity = float('-inf')
        for entry in self.cache:
            similarity = self._cosine_similarity(
                query_embedding,
                entry['embedding']
            )
            # Pick the best match, not merely the first one over the threshold.
            if similarity >= self.similarity_threshold and similarity > best_similarity:
                best_response = entry['response']
                best_similarity = similarity

        return best_response

    def set(
        self,
        query: str,
        response: Dict
    ):
        """Embed ``query`` and append it with its ``response`` to the cache.

        When the cache grows past 10000 entries, only the newest 5000 are kept.
        """
        # Local import: this snippet's header imports only typing and numpy.
        from datetime import datetime

        query_embedding = self.embedding_model.encode([query])[0]

        self.cache.append({
            'query': query,
            'embedding': query_embedding,
            'response': response,
            'created_at': datetime.now()
        })

        # Bound memory: drop the oldest half once over the limit.
        if len(self.cache) > 10000:
            self.cache = self.cache[-5000:]

    def _cosine_similarity(
        self,
        vec1: np.ndarray,
        vec2: np.ndarray
    ) -> float:
        """Return cosine similarity, or 0.0 if either vector has zero norm."""
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(dot_product / (norm1 * norm2))

四、模型选择优化

4.1 智能路由

# model_routing.py
from typing import Dict, List

class ModelRouter:
    """Route a query to a model by task class, budget and latency needs.

    Only models whose 'suitable_for' list contains the query's task class
    are candidates. Each candidate is scored as
    0.4 * cost term + 0.3 * latency term + 0.3 * quality term,
    and the highest score wins (first listed on ties).
    """

    def __init__(self):
        self.model_capabilities = {
            'gpt-4': {
                'quality': 'high',
                'cost_per_1k': 0.03,
                'latency_ms': 3000,
                'suitable_for': ['complex_reasoning', 'code_generation']
            },
            'gpt-4-turbo': {
                'quality': 'high',
                'cost_per_1k': 0.01,
                'latency_ms': 2000,
                'suitable_for': ['general', 'summarization']
            },
            'gpt-3.5-turbo': {
                'quality': 'medium',
                'cost_per_1k': 0.0005,
                'latency_ms': 1000,
                'suitable_for': ['simple_qa', 'chat']
            },
            'claude-3-haiku': {
                'quality': 'medium',
                'cost_per_1k': 0.00025,
                'latency_ms': 800,
                'suitable_for': ['simple_qa', 'extraction']
            }
        }

    def select_model(self, query: Dict) -> str:
        """Return the best-scoring suitable model name.

        Args:
            query: May carry 'complexity', 'task_type', 'budget'
                (default 'medium') and 'latency_requirement' (default 'normal').

        Returns:
            A key of ``model_capabilities``, or 'gpt-3.5-turbo' when no model
            lists the query's task class.
        """
        task_class = self._classify_query(query)
        budget = query.get('budget', 'medium')
        latency_requirement = query.get('latency_requirement', 'normal')

        winner = None
        winning_score = None
        for name, caps in self.model_capabilities.items():
            if task_class not in caps['suitable_for']:
                continue
            candidate_score = self._calculate_score(caps, budget, latency_requirement)
            # Strict '>' keeps the earliest model on a tie.
            if winning_score is None or candidate_score > winning_score:
                winner, winning_score = name, candidate_score

        return 'gpt-3.5-turbo' if winner is None else winner  # fallback default

    def _classify_query(self, query: Dict) -> str:
        """Map a query to 'complex_reasoning', 'simple_qa' or 'general'."""
        complexity = query.get('complexity', 'medium')
        task_type = query.get('task_type', 'general')

        if complexity == 'high' or task_type in ('code', 'reasoning'):
            return 'complex_reasoning'
        if complexity == 'low':
            return 'simple_qa'
        return 'general'

    def _calculate_score(
        self,
        capabilities: Dict,
        budget: str,
        latency_requirement: str
    ) -> float:
        """Return the weighted routing score for one model.

        The cost and latency terms are ``1 - actual / reference`` and can go
        negative when a model exceeds the reference for the requested
        budget or latency tier.
        """
        # Reference cost per 1K tokens for each budget tier.
        cost_reference = {'low': 0.001, 'medium': 0.01, 'high': 0.05}.get(budget, 0.01)
        cost_term = (1 - capabilities['cost_per_1k'] / cost_reference) * 0.4

        # Reference latency in ms for each latency tier.
        latency_reference = {'fast': 1000, 'normal': 3000, 'slow': 5000}.get(
            latency_requirement, 3000
        )
        latency_term = (1 - capabilities['latency_ms'] / latency_reference) * 0.3

        quality_value = {'low': 0.3, 'medium': 0.6, 'high': 1.0}.get(
            capabilities['quality'], 0.6
        )
        quality_term = quality_value * 0.3

        return 0.0 + cost_term + latency_term + quality_term

4.2 降级策略

# fallback_strategy.py
from typing import Dict, List, Optional

class FallbackStrategy:
    """Call models in a fixed preference order until one succeeds."""

    def __init__(self):
        # NOTE(review): by the pricing in get_cost_savings, 'gpt-4' (0.03)
        # costs more than 'gpt-4-turbo' (0.01), so the first fallback step
        # raises cost. Confirm this order is intended.
        self.model_hierarchy = [
            'gpt-4-turbo',      # preferred
            'gpt-4',            # fallback 1
            'gpt-3.5-turbo',    # fallback 2
            'claude-3-haiku'    # fallback 3
        ]

    def execute_with_fallback(
        self,
        query: Dict,
        llm_clients: Dict
    ) -> Dict:
        """Try each model in ``model_hierarchy`` that has a client.

        Args:
            query: Must contain 'prompt'; optional 'options' dict is passed
                as keyword arguments to ``generate``.
            llm_clients: Mapping of model name -> client with ``generate``.

        Returns:
            On success: {'success': True, 'result', 'model', 'attempted_models'}.
            On failure: {'success': False, 'error', 'attempted_models'}.
            'attempted_models' lists only models whose client was actually
            called; models without a client are skipped and not listed.
        """
        attempted: List[str] = []
        last_error: Optional[Exception] = None

        for model in self.model_hierarchy:
            if model not in llm_clients:
                continue

            attempted.append(model)
            try:
                result = llm_clients[model].generate(
                    query['prompt'],
                    **query.get('options', {})
                )
            except Exception as e:
                # Deliberately broad: any provider failure moves on to the
                # next model; the last error is reported if all fail.
                last_error = e
                continue

            return {
                'success': True,
                'result': result,
                'model': model,
                'attempted_models': attempted
            }

        # Every available model failed, or none had a client.
        if last_error is not None:
            error = str(last_error)
        else:
            error = 'no LLM client available for any model in the hierarchy'
        return {
            'success': False,
            'error': error,
            'attempted_models': attempted
        }

    def get_cost_savings(
        self,
        original_model: str,
        fallback_model: str,
        tokens: int
    ) -> float:
        """Return USD saved by running ``tokens`` on the fallback model.

        Uses a flat per-1K-token price. Unknown models default to 0.01
        (original) and 0.0005 (fallback). A negative result means the
        fallback costs more.
        """
        pricing = {
            'gpt-4': 0.03,
            'gpt-4-turbo': 0.01,
            'gpt-3.5-turbo': 0.0005,
            'claude-3-haiku': 0.00025
        }

        original_cost = (tokens / 1000) * pricing.get(original_model, 0.01)
        fallback_cost = (tokens / 1000) * pricing.get(fallback_model, 0.0005)

        return original_cost - fallback_cost

五、资源调度

5.1 批量处理

# batch_processing.py
from typing import Dict, List
import asyncio

class BatchOptimizer:
    """Collect individual requests and dispatch them to the backend in batches.

    A batch is flushed when ``max_batch_size`` requests are pending, or when
    ``max_wait_ms`` has elapsed since the first request of the batch arrived.
    Each ``submit`` call resolves with its request's result, or raises the
    error that the batch produced.
    """

    def __init__(
        self,
        max_batch_size: int = 20,
        max_wait_ms: int = 100
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.pending_queue: List[Dict] = []
        # Timer that flushes the current batch; cancelled by a size-triggered
        # flush so it cannot fire early on the *next* batch.
        self._timer_task = None
        # Strong references to fire-and-forget tasks; asyncio keeps only weak
        # ones, so unreferenced tasks may be garbage-collected mid-run.
        self._background_tasks: set = set()

    def _spawn(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def submit(self, request: Dict) -> Dict:
        """Queue ``request`` and wait for its result from the batch it joins."""
        # Local import: this snippet's header does not import datetime.
        from datetime import datetime

        future = asyncio.get_running_loop().create_future()

        self.pending_queue.append({
            'request': request,
            'future': future,
            'submitted_at': datetime.now()
        })

        # Flush when full; otherwise the first request starts the wait timer.
        if len(self.pending_queue) >= self.max_batch_size:
            self._spawn(self._execute_batch())
        elif len(self.pending_queue) == 1:
            self._timer_task = self._spawn(self._batch_timer())

        return await future

    async def _batch_timer(self):
        """Flush the pending batch after ``max_wait_ms``."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        # Detach first so _execute_batch does not cancel this running task.
        self._timer_task = None
        await self._execute_batch()

    async def _execute_batch(self):
        """Send all pending requests to ``_process_batch`` and resolve their futures.

        If ``_process_batch`` raises, every waiting future receives that
        exception. If it returns fewer results than requests, the unmatched
        futures get a RuntimeError instead of waiting forever.
        """
        if self._timer_task is not None:
            self._timer_task.cancel()
            self._timer_task = None

        if not self.pending_queue:
            return

        batch = self.pending_queue
        self.pending_queue = []

        requests = [item['request'] for item in batch]
        try:
            results = list(await self._process_batch(requests))
        except Exception as exc:
            for item in batch:
                if not item['future'].done():
                    item['future'].set_exception(exc)
            return

        for index, item in enumerate(batch):
            future = item['future']
            if future.done():
                continue  # e.g. the waiting caller was cancelled
            if index < len(results):
                future.set_result(results[index])
            else:
                future.set_exception(RuntimeError(
                    f"_process_batch returned {len(results)} results "
                    f"for {len(batch)} requests"
                ))

    async def _process_batch(
        self,
        requests: List[Dict]
    ) -> List[Dict]:
        """Process one batch; override to call a real batch API.

        Must return one result per request, in request order. This
        placeholder returns [], which makes every request fail.
        """
        # Use the LLM provider's batch API here; batching can cut cost
        # significantly.
        return []

5.2 资源池化

# resource_pooling.py
from typing import Dict, List, Optional
import asyncio

class LLMResourcePool:
    """Fixed-size pool of LLM client slots handed out through an asyncio queue.

    ``initialize`` assigns clients to slots round-robin. ``acquire`` waits
    until a slot is free, and ``release`` returns a slot by its id.
    """

    def __init__(self, pool_size: int = 10):
        self.pool_size = pool_size
        self.available: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        # Slots currently handed out, keyed by slot id.
        self.in_use: Dict = {}

    async def initialize(self, llm_clients: Dict):
        """Fill the pool with ``pool_size`` slots, cycling through ``llm_clients``.

        Raises:
            RuntimeError: if the pool already holds or has lent out slots.
                Re-filling the bounded queue would block once it reached
                maxsize.
            ValueError: if slots are requested but ``llm_clients`` is empty.
        """
        if self.available.qsize() or self.in_use:
            raise RuntimeError("resource pool is already initialized")
        if self.pool_size > 0 and not llm_clients:
            raise ValueError("llm_clients must contain at least one client")

        model_names = list(llm_clients)  # hoisted out of the loop
        for slot_id in range(self.pool_size):
            model = model_names[slot_id % len(model_names)]
            await self.available.put({
                'id': slot_id,
                'client': llm_clients[model],
                'model': model
            })

    async def acquire(self) -> Dict:
        """Wait for a free slot, mark it in use and return it."""
        resource = await self.available.get()
        self.in_use[resource['id']] = resource
        return resource

    async def release(self, resource_id: int):
        """Return slot ``resource_id`` to the pool; unknown ids are ignored."""
        resource = self.in_use.pop(resource_id, None)
        if resource is not None:
            await self.available.put(resource)

    def get_pool_stats(self) -> Dict:
        """Return total/available/in-use counts and utilization (0.0 if empty)."""
        in_use = len(self.in_use)
        return {
            'total': self.pool_size,
            'available': self.available.qsize(),
            'in_use': in_use,
            'utilization': in_use / self.pool_size if self.pool_size else 0.0
        }

六、成本监控

6.1 成本追踪

# cost_tracking.py
from typing import Dict, List
from datetime import datetime, timedelta

class CostTracker:
    """Record per-call LLM spend in memory and report it against daily budgets.

    Transaction timestamps are ISO-8601 strings from ``datetime.now()``
    (local, naive time).
    """

    def __init__(self):
        self.transactions: List[Dict] = []
        # Per-day budget overrides keyed by 'YYYY-MM-DD'; unset days use 100.0.
        self.daily_budgets: Dict[str, float] = {}

    def record_transaction(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: float,
        query_id: str = None
    ):
        """Append one LLM call with the current timestamp."""
        self.transactions.append({
            'timestamp': datetime.now().isoformat(),
            'model': model,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'cost': cost,
            'query_id': query_id
        })

    def get_daily_cost(self, date: datetime = None) -> float:
        """Return total cost of transactions on ``date``'s calendar day (default today)."""
        if date is None:
            date = datetime.now()

        date_str = date.strftime('%Y-%m-%d')
        # isoformat() begins with 'YYYY-MM-DD', so a prefix match selects the day.
        return sum(
            t['cost'] for t in self.transactions
            if t['timestamp'].startswith(date_str)
        )

    def check_budget(
        self,
        date: datetime = None
    ) -> Dict:
        """Compare a day's spend with its budget.

        Returns:
            Dict with 'date', 'spent', 'budget', 'remaining', 'utilization'
            (spent / budget) and 'over_budget'. With a zero budget,
            utilization is ``inf`` if anything was spent and 0.0 otherwise.
            Previously this case raised ZeroDivisionError.
        """
        if date is None:
            date = datetime.now()

        date_str = date.strftime('%Y-%m-%d')
        daily_cost = self.get_daily_cost(date)
        budget = self.daily_budgets.get(date_str, 100.0)

        if budget == 0:
            utilization = float('inf') if daily_cost > 0 else 0.0
        else:
            utilization = daily_cost / budget

        return {
            'date': date_str,
            'spent': daily_cost,
            'budget': budget,
            'remaining': budget - daily_cost,
            'utilization': utilization,
            'over_budget': daily_cost > budget
        }

    def get_cost_breakdown(
        self,
        start_date: datetime,
        end_date: datetime
    ) -> Dict:
        """Return total cost per model for transactions in [start_date, end_date].

        Bounds should be naive datetimes, because stored timestamps are naive.
        """
        breakdown: Dict[str, float] = {}

        for t in self.transactions:
            t_date = datetime.fromisoformat(t['timestamp'])
            if start_date <= t_date <= end_date:
                breakdown[t['model']] = breakdown.get(t['model'], 0) + t['cost']

        return breakdown

6.2 成本告警

# cost_alerting.py
from typing import Dict, List

class CostAlerter:
    """Evaluate threshold rules against current cost metrics and log alerts.

    Each rule compares one metric with a threshold using ``comparison``:
    'gt' (default: fire when value > threshold), 'ge', 'lt' or 'le'.
    Use 'lt' for "too low" rules such as a low cache hit rate.
    """

    # Supported rule comparisons: (metric value, threshold) -> fire?
    _COMPARATORS = {
        'gt': lambda value, threshold: value > threshold,
        'ge': lambda value, threshold: value >= threshold,
        'lt': lambda value, threshold: value < threshold,
        'le': lambda value, threshold: value <= threshold,
    }

    def __init__(self):
        self.alert_rules: List[Dict] = []
        self.alert_history: List[Dict] = []

    def add_rule(
        self,
        name: str,
        metric: str,
        threshold: float,
        severity: str,
        comparison: str = 'gt'
    ):
        """Register an alert rule.

        Raises:
            ValueError: if ``comparison`` is not one of 'gt', 'ge', 'lt', 'le'.
        """
        if comparison not in self._COMPARATORS:
            raise ValueError(
                f"unsupported comparison {comparison!r}; "
                f"expected one of {sorted(self._COMPARATORS)}"
            )
        self.alert_rules.append({
            'name': name,
            'metric': metric,
            'threshold': threshold,
            'severity': severity,
            'comparison': comparison
        })

    def check_alerts(
        self,
        current_metrics: Dict
    ) -> List[Dict]:
        """Return the alerts triggered by ``current_metrics`` and append them to history.

        Rules whose metric is missing (or None) are skipped. A value of 0 is
        still evaluated.
        """
        # Local import: this snippet's header does not import datetime.
        from datetime import datetime

        triggered = []

        for rule in self.alert_rules:
            metric_value = current_metrics.get(rule['metric'])
            if metric_value is None:
                continue

            # Rules appended directly without 'comparison' keep the old '>' meaning.
            compare = self._COMPARATORS[rule.get('comparison', 'gt')]
            if not compare(metric_value, rule['threshold']):
                continue

            alert = {
                'rule': rule['name'],
                'metric': rule['metric'],
                'value': metric_value,
                'threshold': rule['threshold'],
                'severity': rule['severity'],
                'timestamp': datetime.now().isoformat()
            }
            triggered.append(alert)
            self.alert_history.append(alert)

        return triggered

# Predefined alert rules; load with `for r in DEFAULT_COST_ALERTS: alerter.add_rule(**r)`.
DEFAULT_COST_ALERTS = [
    {
        'name': 'Daily Budget Exceeded',
        'metric': 'daily_cost',
        'threshold': 100.0,
        'severity': 'critical',
        'comparison': 'gt'
    },
    {
        'name': 'High Cost Per Query',
        'metric': 'avg_cost_per_query',
        'threshold': 0.05,
        'severity': 'warning',
        'comparison': 'gt'
    },
    {
        # Fires when the hit rate drops BELOW the threshold.
        'name': 'Low Cache Hit Rate',
        'metric': 'cache_hit_rate',
        'threshold': 0.3,
        'severity': 'warning',
        'comparison': 'lt'
    }
]

七、总结

7.1 核心要点

  1. Token 优化

    • Prompt 压缩
    • 上下文优化
    • 移除冗余
  2. 缓存策略

    • 响应缓存
    • 语义缓存
    • 多级缓存
  3. 模型选择

    • 智能路由
    • 降级策略
    • 成本感知

7.2 最佳实践

  1. 监控先行

    • 实时成本追踪
    • 预算告警
    • 成本分解
  2. 持续优化

    • 定期分析成本结构
    • 识别优化机会
    • A/B 测试验证
  3. 平衡质量与成本

    • 根据场景选择模型
    • 设置合理的缓存策略
    • 监控用户满意度

参考资料


分享这篇文章到:

上一篇文章
Redis 架构设计与核心概念
下一篇文章
RocketMQ 消息回溯与重置实战