AI 应用成本控制实战
AI 应用的成本控制是规模化应用的关键挑战。如何降低 LLM 调用成本?如何优化 Token 使用?本文详解 AI 应用成本控制的实战方法。
一、成本构成分析
1.1 成本结构
AI 应用成本结构:
┌─────────────────────────────────────┐
│ 1. LLM API 成本(60-80%) │
│ - 输入 Token │
│ - 输出 Token │
│ - 模型溢价 │
├─────────────────────────────────────┤
│ 2. 基础设施成本(10-20%) │
│ - 计算资源 │
│ - 存储资源 │
│ - 网络资源 │
├─────────────────────────────────────┤
│ 3. 运维成本(5-10%) │
│ - 监控告警 │
│ - 日志存储 │
│ - 人力成本 │
├─────────────────────────────────────┤
│ 4. 其他成本(5-10%) │
│ - 向量数据库 │
│ - 第三方服务 │
│ - 数据成本 │
└─────────────────────────────────────┘
1.2 成本计算
# cost_calculator.py
from typing import Dict, List
class AICostCalculator:
    """Estimate LLM API spend from token counts.

    Pricing is expressed per 1,000 tokens, with separate input/output rates.
    """

    def __init__(self):
        # Static rate table (USD per 1K tokens).
        # NOTE(review): 'claude-3' lists output cheaper than input, which is
        # unusual for published LLM pricing — confirm these figures.
        self.model_pricing = {
            'gpt-4': {'input': 0.03, 'output': 0.06},
            'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
            'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
            'claude-3': {'input': 0.03, 'output': 0.015},
            'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
        }

    def calculate_llm_cost(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Return the cost of one LLM call; 0.0 for unknown models."""
        rates = self.model_pricing.get(model)
        if rates is None:
            return 0.0
        return ((input_tokens / 1000) * rates['input']
                + (output_tokens / 1000) * rates['output'])

    def calculate_query_cost(
        self,
        query_details: Dict
    ) -> Dict:
        """Break a single query's cost into LLM / embedding / search parts.

        query_details must carry 'model', 'input_tokens', 'output_tokens';
        'embedding_cost' and 'vector_search_cost' default to 0.
        """
        llm_cost = self.calculate_llm_cost(
            query_details['model'],
            query_details['input_tokens'],
            query_details['output_tokens'],
        )
        embedding_cost = query_details.get('embedding_cost', 0)
        vector_search_cost = query_details.get('vector_search_cost', 0)
        return {
            'llm_cost': llm_cost,
            'embedding_cost': embedding_cost,
            'vector_search_cost': vector_search_cost,
            'total_cost': llm_cost + embedding_cost + vector_search_cost,
            'breakdown': {
                'input_tokens': query_details['input_tokens'],
                'output_tokens': query_details['output_tokens'],
                'model': query_details['model'],
            },
        }

    def project_monthly_cost(
        self,
        avg_queries_per_day: int,
        avg_cost_per_query: float
    ) -> float:
        """Project a 30-day cost from daily volume and average query cost."""
        return avg_queries_per_day * avg_cost_per_query * 30
二、Token 优化
2.1 Prompt 压缩
# prompt_compression.py
from typing import Dict, List
class PromptCompressor:
    """Shrink prompt context to fit a token budget.

    Pipeline: drop duplicate lines, then (only if still over budget) ask the
    LLM for a query-focused summary, then hard-truncate keeping head + tail.
    """

    def __init__(self, llm):
        # llm must expose generate(prompt) -> str; only used when summarizing.
        self.llm = llm

    def compress_context(
        self,
        context: str,
        max_tokens: int,
        query: str
    ) -> str:
        """Return the context compressed to roughly max_tokens."""
        deduped = self._remove_redundancy(context)
        if self._count_tokens(deduped) > max_tokens:
            deduped = self._summarize_context(deduped, query, max_tokens)
        return self._truncate_to_tokens(deduped, max_tokens)

    def _remove_redundancy(self, context: str) -> str:
        """Drop repeated lines, comparing with surrounding whitespace ignored."""
        seen = set()
        kept = []
        for raw_line in context.split('\n'):
            fingerprint = hash(raw_line.strip())
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            kept.append(raw_line)
        return '\n'.join(kept)

    def _summarize_context(
        self,
        context: str,
        query: str,
        max_tokens: int
    ) -> str:
        """Ask the LLM for a summary focused on the given query."""
        prompt = f"""
请摘要以下内容,保留与问题最相关的信息:
问题:{query}
内容:
{context}
摘要(不超过{max_tokens} Token):
"""
        return self.llm.generate(prompt)

    def _truncate_to_tokens(
        self,
        text: str,
        max_tokens: int
    ) -> str:
        """Hard-truncate to max_tokens 'tokens' (whitespace-split words),
        keeping the first half and the last half joined by '...'."""
        words = text.split()
        if len(words) <= max_tokens:
            return text
        head_len = max_tokens // 2
        tail_len = max_tokens - head_len
        head = ' '.join(words[:head_len])
        tail = ' '.join(words[-tail_len:])
        return f"{head}...{tail}"

    def _count_tokens(self, text: str) -> int:
        """Crude token estimate.

        NOTE(review): words // 4 undercounts heavily — common heuristics are
        ~chars/4 or ~words*1.3 tokens; it also disagrees with
        _truncate_to_tokens, which treats one word as one token. Confirm intent.
        """
        return len(text.split()) // 4
2.2 上下文优化
# context_optimization.py
from typing import Dict, List
class ContextOptimizer:
    """Rank retrieved contexts by relevance and trim them to a token budget."""

    def __init__(self, embedding_model):
        # embedding_model must expose encode(list[str]) -> list of vectors.
        self.embedding_model = embedding_model

    def optimize_context(
        self,
        query: str,
        contexts: List[Dict],
        max_contexts: int = 5,
        max_tokens: int = 2000
    ) -> List[Dict]:
        """Return up to max_contexts entries, most relevant first, each
        compressed against the remaining token budget."""
        ranked = sorted(
            self._score_relevance(query, contexts),
            key=lambda entry: entry['score'],
            reverse=True,
        )
        budget = max_tokens
        trimmed = []
        for entry in ranked[:max_contexts]:
            fitted = self._compress_context(entry, budget)
            trimmed.append(fitted)
            budget -= self._count_tokens(fitted['content'])
        return trimmed

    def _score_relevance(
        self,
        query: str,
        contexts: List[Dict]
    ) -> List[Dict]:
        """Attach a cosine-similarity 'score' (vs. the query) to each context."""
        query_vec = self.embedding_model.encode([query])[0]
        scored = []
        for entry in contexts:
            entry_vec = self.embedding_model.encode([entry['content']])[0]
            scored.append({
                **entry,
                'score': self._cosine_similarity(query_vec, entry_vec),
            })
        return scored

    def _compress_context(
        self,
        context: Dict,
        max_tokens: int
    ) -> Dict:
        """Trim one context by keeping '.'-separated sentences that fit.

        Sentences that would overflow are skipped (not truncated); later,
        shorter sentences may still be admitted.
        """
        body = context['content']
        if self._count_tokens(body) <= max_tokens:
            return context
        picked = []
        spent = 0
        for sentence in body.split('.'):
            cost = self._count_tokens(sentence)
            if spent + cost > max_tokens:
                continue
            picked.append(sentence)
            spent += cost
        return {
            **context,
            'content': '. '.join(picked),
            'compressed': True,
        }

    def _cosine_similarity(
        self,
        vec1: List[float],
        vec2: List[float]
    ) -> float:
        """Cosine similarity. NOTE(review): no zero-norm guard here, unlike
        SemanticCache's version — a zero vector divides by zero."""
        import numpy as np
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        return np.dot(vec1, vec2) / denom

    def _count_tokens(self, text: str) -> int:
        """Crude token estimate (whitespace words // 4); see PromptCompressor."""
        return len(text.split()) // 4
三、缓存策略
3.1 响应缓存
# response_cache.py
from typing import Dict, List, Optional
from datetime import datetime, timedelta
import hashlib
class ResponseCache:
    """Exact-match, TTL-based in-memory cache for LLM responses."""

    def __init__(self, ttl_hours: int = 24):
        self.cache: Dict[str, Dict] = {}
        self.ttl = timedelta(hours=ttl_hours)
        # Hit/miss counters back get_hit_rate().
        self.stats = {
            'hits': 0,
            'misses': 0,
        }

    def generate_cache_key(
        self,
        query: str,
        context: Dict,
        model: str
    ) -> str:
        """Derive a deterministic key from query + sorted context items + model."""
        material = f"{query}:{str(sorted(context.items()))}:{model}"
        return hashlib.md5(material.encode()).hexdigest()

    def get(
        self,
        cache_key: str
    ) -> Optional[Dict]:
        """Return the cached response, or None on miss or expiry.

        An expired entry is evicted and counted as a miss.
        """
        entry = self.cache.get(cache_key)
        if entry is not None:
            if datetime.now() < entry['expires_at']:
                self.stats['hits'] += 1
                return entry['response']
            del self.cache[cache_key]
        self.stats['misses'] += 1
        return None

    def set(
        self,
        cache_key: str,
        response: Dict
    ):
        """Store a response with creation and expiry timestamps."""
        now = datetime.now()
        self.cache[cache_key] = {
            'response': response,
            'created_at': now,
            'expires_at': now + self.ttl,
        }

    def get_hit_rate(self) -> float:
        """Hits / total lookups; 0.0 before any lookup has happened."""
        lookups = self.stats['hits'] + self.stats['misses']
        return self.stats['hits'] / lookups if lookups else 0.0

    def cleanup_expired(self):
        """Evict every entry whose expiry time has passed."""
        now = datetime.now()
        stale = [key for key, entry in self.cache.items()
                 if now >= entry['expires_at']]
        for key in stale:
            del self.cache[key]
3.2 语义缓存
# semantic_cache.py
from typing import Dict, List, Optional
import numpy as np
class SemanticCache:
    """Similarity-based cache: a lookup hits when the query's embedding is
    close enough (cosine >= similarity_threshold) to a cached query's.
    """

    def __init__(
        self,
        embedding_model,
        similarity_threshold: float = 0.95
    ):
        # embedding_model must expose encode(list[str]) -> list of vectors.
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        # Linear-scan store, bounded in set(). Each entry holds the query,
        # its embedding, the cached response and a creation timestamp.
        self.cache: List[Dict] = []

    def get(
        self,
        query: str
    ) -> Optional[Dict]:
        """Return the first cached response similar enough to query, else None."""
        query_embedding = self.embedding_model.encode([query])[0]
        for entry in self.cache:
            similarity = self._cosine_similarity(
                query_embedding,
                entry['embedding']
            )
            if similarity >= self.similarity_threshold:
                return entry['response']
        return None

    def set(
        self,
        query: str,
        response: Dict
    ):
        """Cache a response keyed by the query's embedding."""
        # FIX: 'datetime' was referenced below but never imported in this
        # module (only typing/numpy were), so every set() raised NameError.
        from datetime import datetime
        query_embedding = self.embedding_model.encode([query])[0]
        self.cache.append({
            'query': query,
            'embedding': query_embedding,
            'response': response,
            'created_at': datetime.now()
        })
        # Bound memory: once past 10000 entries, keep only the newest 5000.
        if len(self.cache) > 10000:
            self.cache = self.cache[-5000:]

    def _cosine_similarity(
        self,
        vec1: np.ndarray,
        vec2: np.ndarray
    ) -> float:
        """Cosine similarity with a zero-vector guard (returns 0.0)."""
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)
四、模型选择优化
4.1 智能路由
# model_routing.py
from typing import Dict, List
class ModelRouter:
    """Route a query to the best model for its type, budget and latency need."""

    def __init__(self):
        # Static capability table; cost is USD per 1K tokens.
        self.model_capabilities = {
            'gpt-4': {
                'quality': 'high',
                'cost_per_1k': 0.03,
                'latency_ms': 3000,
                'suitable_for': ['complex_reasoning', 'code_generation'],
            },
            'gpt-4-turbo': {
                'quality': 'high',
                'cost_per_1k': 0.01,
                'latency_ms': 2000,
                'suitable_for': ['general', 'summarization'],
            },
            'gpt-3.5-turbo': {
                'quality': 'medium',
                'cost_per_1k': 0.0005,
                'latency_ms': 1000,
                'suitable_for': ['simple_qa', 'chat'],
            },
            'claude-3-haiku': {
                'quality': 'medium',
                'cost_per_1k': 0.00025,
                'latency_ms': 800,
                'suitable_for': ['simple_qa', 'extraction'],
            },
        }

    def select_model(self, query: Dict) -> str:
        """Pick the highest-scoring suitable model; default to gpt-3.5-turbo."""
        query_type = self._classify_query(query)
        budget = query.get('budget', 'medium')
        latency_requirement = query.get('latency_requirement', 'normal')
        scored = [
            (name, self._calculate_score(caps, budget, latency_requirement))
            for name, caps in self.model_capabilities.items()
            if query_type in caps['suitable_for']
        ]
        if not scored:
            return 'gpt-3.5-turbo'  # fallback when no model matches the type
        return max(scored, key=lambda pair: pair[1])[0]

    def _classify_query(self, query: Dict) -> str:
        """Map complexity/task hints onto a coarse query type."""
        complexity = query.get('complexity', 'medium')
        task_type = query.get('task_type', 'general')
        if complexity == 'high' or task_type in ('code', 'reasoning'):
            return 'complex_reasoning'
        if complexity == 'low':
            return 'simple_qa'
        return 'general'

    def _calculate_score(
        self,
        capabilities: Dict,
        budget: str,
        latency_requirement: str
    ) -> float:
        """Weighted score: cost 40%, latency 30%, quality 30%.

        Cost/latency terms go negative when a model exceeds the budget or
        latency target, penalizing it rather than excluding it outright.
        """
        budget_ceiling = {
            'low': 0.001,
            'medium': 0.01,
            'high': 0.05,
        }
        latency_ceiling = {
            'fast': 1000,
            'normal': 3000,
            'slow': 5000,
        }
        quality_value = {
            'low': 0.3,
            'medium': 0.6,
            'high': 1.0,
        }
        cost_term = 1 - (
            capabilities['cost_per_1k'] / budget_ceiling.get(budget, 0.01)
        )
        latency_term = 1 - (
            capabilities['latency_ms']
            / latency_ceiling.get(latency_requirement, 3000)
        )
        quality_term = quality_value.get(capabilities['quality'], 0.6)
        return cost_term * 0.4 + latency_term * 0.3 + quality_term * 0.3
4.2 降级策略
# fallback_strategy.py
from typing import Dict, List, Optional
class FallbackStrategy:
    """Try models in a fixed preference order until one succeeds."""

    def __init__(self):
        # Best-first; later entries are progressively cheaper fallbacks.
        self.model_hierarchy = [
            'gpt-4-turbo',     # preferred
            'gpt-4',           # fallback 1
            'gpt-3.5-turbo',   # fallback 2
            'claude-3-haiku'   # fallback 3
        ]

    def execute_with_fallback(
        self,
        query: Dict,
        llm_clients: Dict
    ) -> Dict:
        """Call each available client in hierarchy order.

        Returns {'success': True, 'result', 'model', 'attempted_models'} on
        the first success, or {'success': False, 'error',
        'attempted_models'} when every model failed or none was available.
        """
        last_error = None
        for position, model in enumerate(self.model_hierarchy):
            client = llm_clients.get(model)
            if client is None:
                continue
            try:
                output = client.generate(
                    query['prompt'],
                    **query.get('options', {})
                )
            except Exception as exc:  # any failure: move to the next model
                last_error = exc
                continue
            return {
                'success': True,
                'result': output,
                'model': model,
                'attempted_models': self.model_hierarchy[:position + 1],
            }
        return {
            'success': False,
            'error': str(last_error),
            'attempted_models': self.model_hierarchy,
        }

    def get_cost_savings(
        self,
        original_model: str,
        fallback_model: str,
        tokens: int
    ) -> float:
        """Cost delta (original - fallback) for `tokens`, per-1K pricing.

        Unknown models fall back to 0.01 (original) / 0.0005 (fallback).
        """
        pricing = {
            'gpt-4': 0.03,
            'gpt-4-turbo': 0.01,
            'gpt-3.5-turbo': 0.0005,
            'claude-3-haiku': 0.00025,
        }
        original_cost = (tokens / 1000) * pricing.get(original_model, 0.01)
        fallback_cost = (tokens / 1000) * pricing.get(fallback_model, 0.0005)
        return original_cost - fallback_cost
五、资源调度
5.1 批量处理
# batch_processing.py
from typing import Dict, List
import asyncio
class BatchOptimizer:
    """Coalesce individual requests into batches to exploit batch APIs."""

    def __init__(
        self,
        max_batch_size: int = 20,
        max_wait_ms: int = 100
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        # Items: {'request', 'future', 'submitted_at'}
        self.pending_queue: List[Dict] = []

    async def submit(self, request: Dict) -> Dict:
        """Queue a request and await its batched result.

        A batch fires immediately once max_batch_size requests are queued,
        or max_wait_ms after the first queued request — whichever is first.
        """
        # FIX: 'datetime' was used below but never imported in this module
        # (only typing/asyncio were), so every submit() raised NameError.
        from datetime import datetime
        future = asyncio.Future()
        self.pending_queue.append({
            'request': request,
            'future': future,
            'submitted_at': datetime.now()
        })
        if len(self.pending_queue) >= self.max_batch_size:
            asyncio.create_task(self._execute_batch())
        elif len(self.pending_queue) == 1:
            # First request of a fresh batch: start the flush timer.
            asyncio.create_task(self._batch_timer())
        return await future

    async def _batch_timer(self):
        """Flush whatever is queued after the configured wait."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        await self._execute_batch()

    async def _execute_batch(self):
        """Drain the queue, process it as one batch, resolve the waiters."""
        if not self.pending_queue:
            return
        batch = self.pending_queue[:]
        self.pending_queue = []
        requests = [item['request'] for item in batch]
        results = await self._process_batch(requests)
        for item, result in zip(batch, results):
            if not item['future'].done():
                item['future'].set_result(result)
        # FIX: if _process_batch returns fewer results than requests (the
        # stub below returns []), the unmatched futures were never resolved
        # and their submit() calls hung forever. Resolve them with None.
        for item in batch[len(results):]:
            if not item['future'].done():
                item['future'].set_result(None)

    async def _process_batch(
        self,
        requests: List[Dict]
    ) -> List[Dict]:
        """Placeholder: call the provider's batch API here.

        Batch endpoints can significantly cut per-call cost.
        """
        return []
5.2 资源池化
# resource_pooling.py
from typing import Dict, List, Optional
import asyncio
class LLMResourcePool:
    """Fixed-size async pool of LLM clients, handed out round-robin."""

    def __init__(self, pool_size: int = 10):
        self.pool_size = pool_size
        # Idle resources wait here; acquire() blocks while it is empty.
        self.available: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        # Checked-out resources, keyed by resource id.
        self.in_use: Dict = {}

    async def initialize(self, llm_clients: Dict):
        """Fill the pool, cycling through the supplied clients round-robin."""
        client_keys = list(llm_clients.keys())
        for slot in range(self.pool_size):
            model_name = client_keys[slot % len(client_keys)]
            await self.available.put({
                'id': slot,
                'client': llm_clients[model_name],
                'model': model_name,
            })

    async def acquire(self) -> Dict:
        """Check out an idle resource (waits until one is free)."""
        resource = await self.available.get()
        self.in_use[resource['id']] = resource
        return resource

    async def release(self, resource_id: int):
        """Return a checked-out resource to the pool; unknown ids are ignored."""
        resource = self.in_use.pop(resource_id, None)
        if resource is not None:
            await self.available.put(resource)

    def get_pool_stats(self) -> Dict:
        """Snapshot of pool occupancy and utilization ratio."""
        busy = len(self.in_use)
        return {
            'total': self.pool_size,
            'available': self.available.qsize(),
            'in_use': busy,
            'utilization': busy / self.pool_size,
        }
六、成本监控
6.1 成本追踪
# cost_tracking.py
from typing import Dict, List
from datetime import datetime, timedelta
class CostTracker:
    """In-memory ledger of per-call LLM spend with daily budget checks."""

    def __init__(self):
        # One dict per call: ISO timestamp, model, token counts, cost, query id.
        self.transactions: List[Dict] = []
        # Optional per-day budgets keyed 'YYYY-MM-DD'; default is 100.0.
        self.daily_budgets: Dict[str, float] = {}

    def record_transaction(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: float,
        query_id: str = None
    ):
        """Append one LLM call's cost to the ledger."""
        self.transactions.append({
            'timestamp': datetime.now().isoformat(),
            'model': model,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'cost': cost,
            'query_id': query_id,
        })

    def get_daily_cost(self, date: datetime = None) -> float:
        """Total spend for the given calendar day (default: today)."""
        if date is None:
            date = datetime.now()
        day_prefix = date.strftime('%Y-%m-%d')
        return sum(
            txn['cost'] for txn in self.transactions
            if txn['timestamp'].startswith(day_prefix)
        )

    def check_budget(
        self,
        date: datetime = None
    ) -> Dict:
        """Compare the day's spend against its budget (default 100.0)."""
        if date is None:
            date = datetime.now()
        day = date.strftime('%Y-%m-%d')
        spent = self.get_daily_cost(date)
        budget = self.daily_budgets.get(day, 100.0)
        return {
            'date': day,
            'spent': spent,
            'budget': budget,
            'remaining': budget - spent,
            'utilization': spent / budget,
            'over_budget': spent > budget,
        }

    def get_cost_breakdown(
        self,
        start_date: datetime,
        end_date: datetime
    ) -> Dict:
        """Per-model cost totals over [start_date, end_date], inclusive."""
        totals: Dict[str, float] = {}
        for txn in self.transactions:
            when = datetime.fromisoformat(txn['timestamp'])
            if start_date <= when <= end_date:
                totals[txn['model']] = totals.get(txn['model'], 0) + txn['cost']
        return totals
6.2 成本告警
# cost_alerting.py
from typing import Dict, List
class CostAlerter:
    """Threshold-based alerting over cost metrics."""

    def __init__(self):
        self.alert_rules: List[Dict] = []
        # Every triggered alert is also appended here for auditing.
        self.alert_history: List[Dict] = []

    def add_rule(
        self,
        name: str,
        metric: str,
        threshold: float,
        severity: str
    ):
        """Register a rule that fires when the metric's value > threshold."""
        self.alert_rules.append({
            'name': name,
            'metric': metric,
            'threshold': threshold,
            'severity': severity
        })

    def check_alerts(
        self,
        current_metrics: Dict
    ) -> List[Dict]:
        """Evaluate every rule against current_metrics; return fired alerts.

        All rules fire on 'value > threshold'. NOTE(review): a rule meant to
        catch a LOW value (e.g. a low cache hit rate) cannot be expressed
        with this fixed comparison direction — confirm intended semantics.
        """
        # FIX: 'datetime' was used below but never imported in this module.
        from datetime import datetime
        triggered = []
        for rule in self.alert_rules:
            metric_value = current_metrics.get(rule['metric'])
            # FIX: the original truthiness test ('if metric_value and ...')
            # silently skipped legitimate zero values; test presence instead.
            if metric_value is None:
                continue
            if metric_value > rule['threshold']:
                alert = {
                    'rule': rule['name'],
                    'metric': rule['metric'],
                    'value': metric_value,
                    'threshold': rule['threshold'],
                    'severity': rule['severity'],
                    'timestamp': datetime.now().isoformat()
                }
                triggered.append(alert)
                self.alert_history.append(alert)
        return triggered
# Predefined alert rules, ready to feed into CostAlerter.
DEFAULT_COST_ALERTS = [
    # Fires when the day's total spend passes the daily budget.
    {
        'name': 'Daily Budget Exceeded',
        'metric': 'daily_cost',
        'threshold': 100.0,
        'severity': 'critical'
    },
    # Fires when the average per-query cost climbs above 5 cents.
    {
        'name': 'High Cost Per Query',
        'metric': 'avg_cost_per_query',
        'threshold': 0.05,
        'severity': 'warning'
    },
    # NOTE(review): CostAlerter.check_alerts fires only on value > threshold,
    # so this rule triggers on a HIGH hit rate, not a low one — confirm intent.
    {
        'name': 'Low Cache Hit Rate',
        'metric': 'cache_hit_rate',
        'threshold': 0.3,
        'severity': 'warning'
    }
]
七、总结
7.1 核心要点
-
Token 优化
- Prompt 压缩
- 上下文优化
- 移除冗余
-
缓存策略
- 响应缓存
- 语义缓存
- 多级缓存
-
模型选择
- 智能路由
- 降级策略
- 成本感知
7.2 最佳实践
-
监控先行
- 实时成本追踪
- 预算告警
- 成本分解
-
持续优化
- 定期分析成本结构
- 识别优化机会
- A/B 测试验证
-
平衡质量与成本
- 根据场景选择模型
- 设置合理的缓存策略
- 监控用户满意度
参考资料