Redis 计数器实战
计数器是 Redis 最常见的应用场景之一。凭借原子操作和高性能,Redis 计数器广泛应用于限流、统计、分布式 ID 等场景。本文将深入计数器的各种实现方案。
一、基础计数器
1.1 原子操作
import redis
class Counter:
def __init__(self, redis_client, key):
self.redis = redis_client
self.key = key
def increment(self, amount=1):
"""自增"""
return self.redis.incrby(self.key, amount)
def decrement(self, amount=1):
"""自减"""
return self.redis.decrby(self.key, amount)
def get(self):
"""获取当前值"""
value = self.redis.get(self.key)
return int(value) if value else 0
def reset(self):
"""重置"""
self.redis.delete(self.key)
# 使用示例
counter = Counter(redis, "page:view:1001")
# 增加计数
count = counter.increment() # +1
count = counter.increment(5) # +5
# 获取当前值
current = counter.get()
# 重置
counter.reset()
1.2 带过期时间的计数器
class ExpiringCounter:
def __init__(self, redis_client, key, ttl=86400):
self.redis = redis_client
self.key = key
self.ttl = ttl
def increment(self, amount=1):
"""自增并设置过期时间"""
pipe = self.redis.pipeline()
# 自增
pipe.incrby(self.key, amount)
# 设置过期时间(仅当 key 不存在时)
pipe.expire(self.key, self.ttl)
pipe.execute()
return self.get()
def get(self):
"""获取当前值"""
value = self.redis.get(self.key)
return int(value) if value else 0
# 使用示例
# 每日计数器
daily_counter = ExpiringCounter(redis, "daily:visits", ttl=86400)
# 增加访问数
visits = daily_counter.increment()
1.3 多位计数器
class MultiCounter:
def __init__(self, redis_client, prefix):
self.redis = redis_client
self.prefix = prefix
def get_key(self, *parts):
"""生成 Key"""
return f"{self.prefix}:{':'.join(str(p) for p in parts)}"
def increment(self, *parts, amount=1):
"""增加计数"""
key = self.get_key(*parts)
return self.redis.incrby(key, amount)
def get(self, *parts):
"""获取计数"""
key = self.get_key(*parts)
value = self.redis.get(key)
return int(value) if value else 0
def get_all(self, *prefix_parts):
"""获取所有子计数"""
pattern = f"{self.get_key(*prefix_parts)}:*"
keys = self.redis.keys(pattern)
result = {}
for key in keys:
value = self.redis.get(key)
result[key.decode()] = int(value) if value else 0
return result
# 使用示例
mc = MultiCounter(redis, "article")
# 文章 1001 的阅读量
mc.increment(1001, 'views')
# 文章 1001 的点赞数
mc.increment(1001, 'likes')
# 获取文章 1001 的所有统计
stats = mc.get_all(1001)
# {'article:1001:views': 1000, 'article:1001:likes': 100}
二、限流器
2.1 固定窗口限流
class FixedWindowRateLimiter:
def __init__(self, redis_client, key_prefix, max_requests, window_seconds):
self.redis = redis_client
self.key_prefix = key_prefix
self.max_requests = max_requests
self.window = window_seconds
def _get_key(self, identifier):
"""生成 Key"""
window_id = int(time.time() // self.window)
return f"{self.key_prefix}:{identifier}:{window_id}"
def is_allowed(self, identifier):
"""检查是否允许请求"""
key = self._get_key(identifier)
# 原子自增
current = self.redis.incr(key)
# 首次请求设置过期时间
if current == 1:
self.redis.expire(key, self.window)
return current <= self.max_requests
# 使用示例
# 每分钟最多 100 次请求
limiter = FixedWindowRateLimiter(redis, "rate:user", max_requests=100, window_seconds=60)
# 检查请求
if limiter.is_allowed("user:1001"):
process_request()
else:
return "Too many requests", 429
2.2 滑动窗口限流
class SlidingWindowRateLimiter:
def __init__(self, redis_client, key_prefix, max_requests, window_seconds):
self.redis = redis_client
self.key_prefix = key_prefix
self.max_requests = max_requests
self.window = window_seconds
def _get_key(self, identifier):
"""生成 Key"""
return f"{self.key_prefix}:{identifier}"
def is_allowed(self, identifier):
"""检查是否允许请求"""
key = self._get_key(identifier)
now = time.time()
window_start = now - self.window
pipe = self.redis.pipeline()
# 删除窗口外的请求
pipe.zremrangebyscore(key, 0, window_start)
# 添加当前请求
pipe.zadd(key, {str(now): now})
# 设置过期时间
pipe.expire(key, self.window)
# 统计窗口内请求数
pipe.zcard(key)
results = pipe.execute()
count = results[3]
return count <= self.max_requests
# 使用示例
# 每分钟最多 100 次请求(滑动窗口)
limiter = SlidingWindowRateLimiter(redis, "rate:user", max_requests=100, window_seconds=60)
if limiter.is_allowed("user:1001"):
process_request()
else:
return "Too many requests", 429
2.3 令牌桶限流
class TokenBucketRateLimiter:
def __init__(self, redis_client, key_prefix, capacity, refill_rate):
self.redis = redis_client
self.key_prefix = key_prefix
self.capacity = capacity # 桶容量
self.refill_rate = refill_rate # 每秒补充的令牌数
def _get_key(self, identifier):
"""生成 Key"""
return f"{self.key_prefix}:{identifier}"
def is_allowed(self, identifier, tokens=1):
"""检查是否允许请求"""
key = self._get_key(identifier)
now = time.time()
lua_script = """
local key = KEYS[1]
local capacity = tonumber(ARGV[1])
local refill_rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
-- 获取当前令牌数和时间戳
local bucket = redis.call('HMGET', key, 'tokens', 'last_refill')
local tokens = tonumber(bucket[1]) or capacity
local last_refill = tonumber(bucket[2]) or now
-- 计算应补充的令牌数
local elapsed = now - last_refill
local refill = elapsed * refill_rate
tokens = math.min(capacity, tokens + refill)
-- 检查是否有足够令牌
if tokens >= requested then
tokens = tokens - requested
redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', key, math.ceil(capacity / refill_rate))
return 1
else
redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
return 0
end
"""
result = self.redis.eval(
lua_script, 1, key,
self.capacity, self.refill_rate, now, tokens
)
return result == 1
# 使用示例
# 容量 100 个令牌,每秒补充 10 个
limiter = TokenBucketRateLimiter(redis, "rate:user", capacity=100, refill_rate=10)
if limiter.is_allowed("user:1001", tokens=1):
process_request()
else:
return "Too many requests", 429
三、去重计数器
3.1 HyperLogLog 去重
class UniqueCounter:
def __init__(self, redis_client, key):
self.redis = redis_client
self.key = key
def add(self, *items):
"""添加元素"""
self.redis.pfadd(self.key, *items)
def count(self):
"""获取去重后的数量"""
return self.redis.pfcount(self.key)
def merge(self, *other_keys):
"""合并多个 HyperLogLog"""
self.redis.pfmerge(self.key, *other_keys)
# 使用示例
# 统计页面 UV
uv_counter = UniqueCounter(redis, "page:1001:uv")
# 添加用户访问
uv_counter.add("user:1001", "user:1002", "user:1003")
# 获取 UV 数量
uv_count = uv_counter.count()
# 合并多页面 UV
page1_uv = UniqueCounter(redis, "page:1001:uv")
page2_uv = UniqueCounter(redis, "page:1002:uv")
total_uv = UniqueCounter(redis, "site:total:uv")
total_uv.merge(page1_uv.key, page2_uv.key)
total_count = total_uv.count()
3.2 Set 去重
class SetUniqueCounter:
def __init__(self, redis_client, key):
self.redis = redis_client
self.key = key
def add(self, *items):
"""添加元素"""
return self.redis.sadd(self.key, *items)
def count(self):
"""获取去重后的数量"""
return self.redis.scard(self.key)
def is_member(self, item):
"""检查是否已存在"""
return self.redis.sismember(self.key, item)
def get_all(self):
"""获取所有元素"""
return self.redis.smembers(self.key)
# 使用示例
# 签到统计
checkin_counter = SetUniqueCounter(redis, "user:1001:checkin:2024-01")
# 签到
checkin_counter.add("2024-01-01")
checkin_counter.add("2024-01-02")
# 获取签到天数
days = checkin_counter.count()
# 检查是否已签到
is_checkin = checkin_counter.is_member("2024-01-01")
3.3 Bitmap 去重
class BitmapCounter:
def __init__(self, redis_client, key):
self.redis = redis_client
self.key = key
def set(self, offset):
"""设置位"""
return self.redis.setbit(self.key, offset, 1)
def get(self, offset):
"""获取位"""
return self.redis.getbit(self.key, offset)
def count(self):
"""统计 1 的数量"""
return self.redis.bitcount(self.key)
# 使用示例
# 用户签到(每月)
bitmap = BitmapCounter(redis, "user:1001:checkin:2024-01")
# 第 1 天签到
bitmap.set(0)
# 第 2 天签到
bitmap.set(1)
# 获取签到天数
days = bitmap.count()
# 检查第 1 天是否签到
is_checkin = bitmap.get(0)
四、分布式 ID 生成器
4.1 基于 INCR 的 ID 生成器
class IDGenerator:
def __init__(self, redis_client, prefix):
self.redis = redis_client
self.prefix = prefix
def generate(self):
"""生成唯一 ID"""
key = f"{self.prefix}:id"
return self.redis.incr(key)
def generate_batch(self, count=100):
"""批量生成 ID"""
key = f"{self.prefix}:id"
return self.redis.incrby(key, count) - count + 1
# 使用示例
id_gen = IDGenerator(redis, "order")
# 生成订单 ID
order_id = id_gen.generate() # 1001
# 批量生成
start_id = id_gen.generate_batch(100) # 1002-1101
4.2 带日期的 ID 生成器
from datetime import datetime
class DateIDGenerator:
def __init__(self, redis_client, prefix):
self.redis = redis_client
self.prefix = prefix
def generate(self):
"""生成带日期的 ID"""
date = datetime.now().strftime('%Y%m%d')
key = f"{self.prefix}:id:{date}"
# 原子自增
seq = self.redis.incr(key)
# 设置过期时间(保留 30 天)
self.redis.expire(key, 86400 * 30)
# 生成 ID:日期 + 序号
return f"{date}{seq:06d}"
# 使用示例
id_gen = DateIDGenerator(redis, "order")
# 生成订单 ID
order_id = id_gen.generate() # 20240101000001
4.3 雪花算法 + Redis
class SnowflakeIDGenerator:
def __init__(self, redis_client, worker_id, prefix="id"):
self.redis = redis_client
self.worker_id = worker_id & 0x3FF # 10 位
self.prefix = prefix
# 起始时间戳(2024-01-01)
self.start_timestamp = 1704067200000
# 位数分配
self.timestamp_bits = 41
self.worker_bits = 10
self.sequence_bits = 12
self.sequence = 0
self.last_timestamp = -1
def _next_timestamp(self):
"""获取下一个时间戳"""
timestamp = int(time.time() * 1000)
if timestamp < self.last_timestamp:
raise Exception("Clock moved backwards")
return timestamp
def _wait_next_millis(self, last_timestamp):
"""等待到下一毫秒"""
timestamp = self._next_timestamp()
while timestamp <= last_timestamp:
timestamp = self._next_timestamp()
return timestamp
def generate(self):
"""生成雪花 ID"""
timestamp = self._next_timestamp()
# 时间戳相同,序号 +1
if timestamp == self.last_timestamp:
self.sequence = (self.sequence + 1) & 0xFFF
if self.sequence == 0:
timestamp = self._wait_next_millis(self.last_timestamp)
else:
self.sequence = 0
self.last_timestamp = timestamp
# 生成 ID
timestamp_offset = timestamp - self.start_timestamp
worker_id_offset = self.worker_id << 12
return (timestamp_offset << 22) | worker_id_offset | self.sequence
# 使用示例
id_gen = SnowflakeIDGenerator(redis, worker_id=1)
# 生成 ID
snowflake_id = id_gen.generate() # 1704067200000001
五、统计计数器
5.1 多维度统计
class StatsCounter:
def __init__(self, redis_client, prefix):
self.redis = redis_client
self.prefix = prefix
def increment(self, metric, dimension, value=1):
"""增加统计"""
key = f"{self.prefix}:{metric}:{dimension}"
return self.redis.incrby(key, value)
def get(self, metric, dimension):
"""获取统计"""
key = f"{self.prefix}:{metric}:{dimension}"
value = self.redis.get(key)
return int(value) if value else 0
def get_all(self, metric):
"""获取所有维度的统计"""
pattern = f"{self.prefix}:{metric}:*"
keys = self.redis.keys(pattern)
result = {}
for key in keys:
dimension = key.decode().split(':')[-1]
value = self.redis.get(key)
result[dimension] = int(value) if value else 0
return result
# 使用示例
stats = StatsCounter(redis, "article")
# 文章阅读量
stats.increment("views", "1001")
# 文章点赞数
stats.increment("likes", "1001")
# 获取文章 1001 的所有统计
article_stats = stats.get_all("1001")
# {'views': 1000, 'likes': 100}
5.2 时间段统计
class TimeStatsCounter:
def __init__(self, redis_client, prefix):
self.redis = redis_client
self.prefix = prefix
def _get_period_key(self, metric, period):
"""生成时间段 Key"""
now = datetime.now()
if period == 'hour':
time_str = now.strftime('%Y%m%d%H')
elif period == 'day':
time_str = now.strftime('%Y%m%d')
elif period == 'week':
time_str = now.strftime('%Y%W')
elif period == 'month':
time_str = now.strftime('%Y%m')
else:
time_str = now.strftime('%Y%m%d')
return f"{self.prefix}:{metric}:{period}:{time_str}"
def increment(self, metric, period='day', value=1):
"""增加统计"""
key = self._get_period_key(metric, period)
return self.redis.incrby(key, value)
def get(self, metric, period='day'):
"""获取统计"""
key = self._get_period_key(metric, period)
value = self.redis.get(key)
return int(value) if value else 0
# 使用示例
time_stats = TimeStatsCounter(redis, "site")
# 今日访问量
time_stats.increment("visits", period='day')
# 本月访问量
time_stats.increment("visits", period='month')
# 获取今日统计
today_visits = time_stats.get("visits", period='day')
六、最佳实践
6.1 性能优化
# 批量操作
def batch_increment(counters):
"""批量增加计数"""
pipe = redis.pipeline()
for key, value in counters.items():
pipe.incrby(key, value)
pipe.execute()
# 使用示例
batch_increment({
"article:1001:views": 1,
"article:1002:views": 1,
"article:1003:views": 1
})
6.2 内存优化
# 设置过期时间
def increment_with_expiry(key, amount=1, ttl=86400):
"""增加计数并设置过期时间"""
pipe = redis.pipeline()
pipe.incrby(key, amount)
pipe.expire(key, ttl)
pipe.execute()
# 使用示例
increment_with_expiry("temp:counter", ttl=3600)
6.3 监控指标
# 监控计数器大小
def get_counter_size(pattern):
"""获取计数器数量"""
cursor = 0
count = 0
while True:
cursor, keys = redis.scan(cursor, match=pattern, count=100)
count += len(keys)
if cursor == 0:
break
return count
# 使用示例
counter_count = get_counter_size("article:*:views")
总结
Redis 计数器核心要点:
| 场景 | 实现方式 | 特点 |
|---|---|---|
| 基础计数 | INCR/INCRBY | 原子操作 |
| 限流 | 固定窗口/滑动窗口/令牌桶 | 多种策略 |
| 去重 | HyperLogLog/Set/Bitmap | 不同精度 |
| 分布式 ID | INCR/雪花算法 | 唯一性保证 |
| 统计 | 多维度/时间段 | 灵活统计 |
最佳实践:
- 使用原子操作保证准确性
- 设置合理的过期时间
- 批量操作减少网络往返
- 选择合适的去重方案
- 监控计数器大小
- 定期清理过期数据
掌握计数器实现,构建高性能统计系统!