Skip to content
清晨的一缕阳光
返回

Redis 计数器实战

Redis 计数器实战

计数器是 Redis 最常见的应用场景之一。凭借原子操作和高性能,Redis 计数器广泛应用于限流、统计、分布式 ID 等场景。本文将深入计数器的各种实现方案。

一、基础计数器

1.1 原子操作

import redis

class Counter:
    def __init__(self, redis_client, key):
        self.redis = redis_client
        self.key = key
    
    def increment(self, amount=1):
        """自增"""
        return self.redis.incrby(self.key, amount)
    
    def decrement(self, amount=1):
        """自减"""
        return self.redis.decrby(self.key, amount)
    
    def get(self):
        """获取当前值"""
        value = self.redis.get(self.key)
        return int(value) if value else 0
    
    def reset(self):
        """重置"""
        self.redis.delete(self.key)

# 使用示例
counter = Counter(redis, "page:view:1001")

# 增加计数
count = counter.increment()  # +1
count = counter.increment(5)  # +5

# 获取当前值
current = counter.get()

# 重置
counter.reset()

1.2 带过期时间的计数器

class ExpiringCounter:
    def __init__(self, redis_client, key, ttl=86400):
        self.redis = redis_client
        self.key = key
        self.ttl = ttl
    
    def increment(self, amount=1):
        """自增并设置过期时间"""
        pipe = self.redis.pipeline()
        
        # 自增
        pipe.incrby(self.key, amount)
        
        # 设置过期时间(仅当 key 不存在时)
        pipe.expire(self.key, self.ttl)
        
        pipe.execute()
        
        return self.get()
    
    def get(self):
        """获取当前值"""
        value = self.redis.get(self.key)
        return int(value) if value else 0

# 使用示例
# 每日计数器
daily_counter = ExpiringCounter(redis, "daily:visits", ttl=86400)

# 增加访问数
visits = daily_counter.increment()

1.3 多位计数器

class MultiCounter:
    def __init__(self, redis_client, prefix):
        self.redis = redis_client
        self.prefix = prefix
    
    def get_key(self, *parts):
        """生成 Key"""
        return f"{self.prefix}:{':'.join(str(p) for p in parts)}"
    
    def increment(self, *parts, amount=1):
        """增加计数"""
        key = self.get_key(*parts)
        return self.redis.incrby(key, amount)
    
    def get(self, *parts):
        """获取计数"""
        key = self.get_key(*parts)
        value = self.redis.get(key)
        return int(value) if value else 0
    
    def get_all(self, *prefix_parts):
        """获取所有子计数"""
        pattern = f"{self.get_key(*prefix_parts)}:*"
        keys = self.redis.keys(pattern)
        
        result = {}
        for key in keys:
            value = self.redis.get(key)
            result[key.decode()] = int(value) if value else 0
        
        return result

# 使用示例
mc = MultiCounter(redis, "article")

# 文章 1001 的阅读量
mc.increment(1001, 'views')

# 文章 1001 的点赞数
mc.increment(1001, 'likes')

# 获取文章 1001 的所有统计
stats = mc.get_all(1001)
# {'article:1001:views': 1000, 'article:1001:likes': 100}

二、限流器

2.1 固定窗口限流

class FixedWindowRateLimiter:
    def __init__(self, redis_client, key_prefix, max_requests, window_seconds):
        self.redis = redis_client
        self.key_prefix = key_prefix
        self.max_requests = max_requests
        self.window = window_seconds
    
    def _get_key(self, identifier):
        """生成 Key"""
        window_id = int(time.time() // self.window)
        return f"{self.key_prefix}:{identifier}:{window_id}"
    
    def is_allowed(self, identifier):
        """检查是否允许请求"""
        key = self._get_key(identifier)
        
        # 原子自增
        current = self.redis.incr(key)
        
        # 首次请求设置过期时间
        if current == 1:
            self.redis.expire(key, self.window)
        
        return current <= self.max_requests

# 使用示例
# 每分钟最多 100 次请求
limiter = FixedWindowRateLimiter(redis, "rate:user", max_requests=100, window_seconds=60)

# 检查请求
if limiter.is_allowed("user:1001"):
    process_request()
else:
    return "Too many requests", 429

2.2 滑动窗口限流

class SlidingWindowRateLimiter:
    def __init__(self, redis_client, key_prefix, max_requests, window_seconds):
        self.redis = redis_client
        self.key_prefix = key_prefix
        self.max_requests = max_requests
        self.window = window_seconds
    
    def _get_key(self, identifier):
        """生成 Key"""
        return f"{self.key_prefix}:{identifier}"
    
    def is_allowed(self, identifier):
        """检查是否允许请求"""
        key = self._get_key(identifier)
        now = time.time()
        window_start = now - self.window
        
        pipe = self.redis.pipeline()
        
        # 删除窗口外的请求
        pipe.zremrangebyscore(key, 0, window_start)
        
        # 添加当前请求
        pipe.zadd(key, {str(now): now})
        
        # 设置过期时间
        pipe.expire(key, self.window)
        
        # 统计窗口内请求数
        pipe.zcard(key)
        
        results = pipe.execute()
        count = results[3]
        
        return count <= self.max_requests

# 使用示例
# 每分钟最多 100 次请求(滑动窗口)
limiter = SlidingWindowRateLimiter(redis, "rate:user", max_requests=100, window_seconds=60)

if limiter.is_allowed("user:1001"):
    process_request()
else:
    return "Too many requests", 429

2.3 令牌桶限流

class TokenBucketRateLimiter:
    def __init__(self, redis_client, key_prefix, capacity, refill_rate):
        self.redis = redis_client
        self.key_prefix = key_prefix
        self.capacity = capacity  # 桶容量
        self.refill_rate = refill_rate  # 每秒补充的令牌数
    
    def _get_key(self, identifier):
        """生成 Key"""
        return f"{self.key_prefix}:{identifier}"
    
    def is_allowed(self, identifier, tokens=1):
        """检查是否允许请求"""
        key = self._get_key(identifier)
        now = time.time()
        
        lua_script = """
        local key = KEYS[1]
        local capacity = tonumber(ARGV[1])
        local refill_rate = tonumber(ARGV[2])
        local now = tonumber(ARGV[3])
        local requested = tonumber(ARGV[4])
        
        -- 获取当前令牌数和时间戳
        local bucket = redis.call('HMGET', key, 'tokens', 'last_refill')
        local tokens = tonumber(bucket[1]) or capacity
        local last_refill = tonumber(bucket[2]) or now
        
        -- 计算应补充的令牌数
        local elapsed = now - last_refill
        local refill = elapsed * refill_rate
        tokens = math.min(capacity, tokens + refill)
        
        -- 检查是否有足够令牌
        if tokens >= requested then
            tokens = tokens - requested
            redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
            redis.call('EXPIRE', key, math.ceil(capacity / refill_rate))
            return 1
        else
            redis.call('HMSET', key, 'tokens', tokens, 'last_refill', now)
            return 0
        end
        """
        
        result = self.redis.eval(
            lua_script, 1, key,
            self.capacity, self.refill_rate, now, tokens
        )
        
        return result == 1

# 使用示例
# 容量 100 个令牌,每秒补充 10 个
limiter = TokenBucketRateLimiter(redis, "rate:user", capacity=100, refill_rate=10)

if limiter.is_allowed("user:1001", tokens=1):
    process_request()
else:
    return "Too many requests", 429

三、去重计数器

3.1 HyperLogLog 去重

class UniqueCounter:
    def __init__(self, redis_client, key):
        self.redis = redis_client
        self.key = key
    
    def add(self, *items):
        """添加元素"""
        self.redis.pfadd(self.key, *items)
    
    def count(self):
        """获取去重后的数量"""
        return self.redis.pfcount(self.key)
    
    def merge(self, *other_keys):
        """合并多个 HyperLogLog"""
        self.redis.pfmerge(self.key, *other_keys)

# 使用示例
# 统计页面 UV
uv_counter = UniqueCounter(redis, "page:1001:uv")

# 添加用户访问
uv_counter.add("user:1001", "user:1002", "user:1003")

# 获取 UV 数量
uv_count = uv_counter.count()

# 合并多页面 UV
page1_uv = UniqueCounter(redis, "page:1001:uv")
page2_uv = UniqueCounter(redis, "page:1002:uv")
total_uv = UniqueCounter(redis, "site:total:uv")

total_uv.merge(page1_uv.key, page2_uv.key)
total_count = total_uv.count()

3.2 Set 去重

class SetUniqueCounter:
    def __init__(self, redis_client, key):
        self.redis = redis_client
        self.key = key
    
    def add(self, *items):
        """添加元素"""
        return self.redis.sadd(self.key, *items)
    
    def count(self):
        """获取去重后的数量"""
        return self.redis.scard(self.key)
    
    def is_member(self, item):
        """检查是否已存在"""
        return self.redis.sismember(self.key, item)
    
    def get_all(self):
        """获取所有元素"""
        return self.redis.smembers(self.key)

# 使用示例
# 签到统计
checkin_counter = SetUniqueCounter(redis, "user:1001:checkin:2024-01")

# 签到
checkin_counter.add("2024-01-01")
checkin_counter.add("2024-01-02")

# 获取签到天数
days = checkin_counter.count()

# 检查是否已签到
is_checkin = checkin_counter.is_member("2024-01-01")

3.3 Bitmap 去重

class BitmapCounter:
    def __init__(self, redis_client, key):
        self.redis = redis_client
        self.key = key
    
    def set(self, offset):
        """设置位"""
        return self.redis.setbit(self.key, offset, 1)
    
    def get(self, offset):
        """获取位"""
        return self.redis.getbit(self.key, offset)
    
    def count(self):
        """统计 1 的数量"""
        return self.redis.bitcount(self.key)

# 使用示例
# 用户签到(每月)
bitmap = BitmapCounter(redis, "user:1001:checkin:2024-01")

# 第 1 天签到
bitmap.set(0)

# 第 2 天签到
bitmap.set(1)

# 获取签到天数
days = bitmap.count()

# 检查第 1 天是否签到
is_checkin = bitmap.get(0)

四、分布式 ID 生成器

4.1 基于 INCR 的 ID 生成器

class IDGenerator:
    def __init__(self, redis_client, prefix):
        self.redis = redis_client
        self.prefix = prefix
    
    def generate(self):
        """生成唯一 ID"""
        key = f"{self.prefix}:id"
        return self.redis.incr(key)
    
    def generate_batch(self, count=100):
        """批量生成 ID"""
        key = f"{self.prefix}:id"
        return self.redis.incrby(key, count) - count + 1

# 使用示例
id_gen = IDGenerator(redis, "order")

# 生成订单 ID
order_id = id_gen.generate()  # 1001

# 批量生成
start_id = id_gen.generate_batch(100)  # 1002-1101

4.2 带日期的 ID 生成器

from datetime import datetime

class DateIDGenerator:
    def __init__(self, redis_client, prefix):
        self.redis = redis_client
        self.prefix = prefix
    
    def generate(self):
        """生成带日期的 ID"""
        date = datetime.now().strftime('%Y%m%d')
        key = f"{self.prefix}:id:{date}"
        
        # 原子自增
        seq = self.redis.incr(key)
        
        # 设置过期时间(保留 30 天)
        self.redis.expire(key, 86400 * 30)
        
        # 生成 ID:日期 + 序号
        return f"{date}{seq:06d}"

# 使用示例
id_gen = DateIDGenerator(redis, "order")

# 生成订单 ID
order_id = id_gen.generate()  # 20240101000001

4.3 雪花算法 + Redis

class SnowflakeIDGenerator:
    def __init__(self, redis_client, worker_id, prefix="id"):
        self.redis = redis_client
        self.worker_id = worker_id & 0x3FF  # 10 位
        self.prefix = prefix
        
        # 起始时间戳(2024-01-01)
        self.start_timestamp = 1704067200000
        
        # 位数分配
        self.timestamp_bits = 41
        self.worker_bits = 10
        self.sequence_bits = 12
        
        self.sequence = 0
        self.last_timestamp = -1
    
    def _next_timestamp(self):
        """获取下一个时间戳"""
        timestamp = int(time.time() * 1000)
        if timestamp < self.last_timestamp:
            raise Exception("Clock moved backwards")
        return timestamp
    
    def _wait_next_millis(self, last_timestamp):
        """等待到下一毫秒"""
        timestamp = self._next_timestamp()
        while timestamp <= last_timestamp:
            timestamp = self._next_timestamp()
        return timestamp
    
    def generate(self):
        """生成雪花 ID"""
        timestamp = self._next_timestamp()
        
        # 时间戳相同,序号 +1
        if timestamp == self.last_timestamp:
            self.sequence = (self.sequence + 1) & 0xFFF
            if self.sequence == 0:
                timestamp = self._wait_next_millis(self.last_timestamp)
        else:
            self.sequence = 0
        
        self.last_timestamp = timestamp
        
        # 生成 ID
        timestamp_offset = timestamp - self.start_timestamp
        worker_id_offset = self.worker_id << 12
        
        return (timestamp_offset << 22) | worker_id_offset | self.sequence

# 使用示例
id_gen = SnowflakeIDGenerator(redis, worker_id=1)

# 生成 ID
snowflake_id = id_gen.generate()  # 1704067200000001

五、统计计数器

5.1 多维度统计

class StatsCounter:
    def __init__(self, redis_client, prefix):
        self.redis = redis_client
        self.prefix = prefix
    
    def increment(self, metric, dimension, value=1):
        """增加统计"""
        key = f"{self.prefix}:{metric}:{dimension}"
        return self.redis.incrby(key, value)
    
    def get(self, metric, dimension):
        """获取统计"""
        key = f"{self.prefix}:{metric}:{dimension}"
        value = self.redis.get(key)
        return int(value) if value else 0
    
    def get_all(self, metric):
        """获取所有维度的统计"""
        pattern = f"{self.prefix}:{metric}:*"
        keys = self.redis.keys(pattern)
        
        result = {}
        for key in keys:
            dimension = key.decode().split(':')[-1]
            value = self.redis.get(key)
            result[dimension] = int(value) if value else 0
        
        return result

# 使用示例
stats = StatsCounter(redis, "article")

# 文章阅读量
stats.increment("views", "1001")

# 文章点赞数
stats.increment("likes", "1001")

# 获取文章 1001 的所有统计
article_stats = stats.get_all("1001")
# {'views': 1000, 'likes': 100}

5.2 时间段统计

class TimeStatsCounter:
    def __init__(self, redis_client, prefix):
        self.redis = redis_client
        self.prefix = prefix
    
    def _get_period_key(self, metric, period):
        """生成时间段 Key"""
        now = datetime.now()
        
        if period == 'hour':
            time_str = now.strftime('%Y%m%d%H')
        elif period == 'day':
            time_str = now.strftime('%Y%m%d')
        elif period == 'week':
            time_str = now.strftime('%Y%W')
        elif period == 'month':
            time_str = now.strftime('%Y%m')
        else:
            time_str = now.strftime('%Y%m%d')
        
        return f"{self.prefix}:{metric}:{period}:{time_str}"
    
    def increment(self, metric, period='day', value=1):
        """增加统计"""
        key = self._get_period_key(metric, period)
        return self.redis.incrby(key, value)
    
    def get(self, metric, period='day'):
        """获取统计"""
        key = self._get_period_key(metric, period)
        value = self.redis.get(key)
        return int(value) if value else 0

# 使用示例
time_stats = TimeStatsCounter(redis, "site")

# 今日访问量
time_stats.increment("visits", period='day')

# 本月访问量
time_stats.increment("visits", period='month')

# 获取今日统计
today_visits = time_stats.get("visits", period='day')

六、最佳实践

6.1 性能优化

# 批量操作
def batch_increment(counters):
    """批量增加计数"""
    pipe = redis.pipeline()
    
    for key, value in counters.items():
        pipe.incrby(key, value)
    
    pipe.execute()

# 使用示例
batch_increment({
    "article:1001:views": 1,
    "article:1002:views": 1,
    "article:1003:views": 1
})

6.2 内存优化

# 设置过期时间
def increment_with_expiry(key, amount=1, ttl=86400):
    """增加计数并设置过期时间"""
    pipe = redis.pipeline()
    pipe.incrby(key, amount)
    pipe.expire(key, ttl)
    pipe.execute()

# 使用示例
increment_with_expiry("temp:counter", ttl=3600)

6.3 监控指标

# 监控计数器大小
def get_counter_size(pattern):
    """获取计数器数量"""
    cursor = 0
    count = 0
    
    while True:
        cursor, keys = redis.scan(cursor, match=pattern, count=100)
        count += len(keys)
        
        if cursor == 0:
            break
    
    return count

# 使用示例
counter_count = get_counter_size("article:*:views")

总结

Redis 计数器核心要点:

场景实现方式特点
基础计数INCR/INCRBY原子操作
限流固定窗口/滑动窗口/令牌桶多种策略
去重HyperLogLog/Set/Bitmap不同精度
分布式 IDINCR/雪花算法唯一性保证
统计多维度/时间段灵活统计

最佳实践

  1. 使用原子操作保证准确性
  2. 设置合理的过期时间
  3. 批量操作减少网络往返
  4. 选择合适的去重方案
  5. 监控计数器大小
  6. 定期清理过期数据

掌握计数器实现,构建高性能统计系统!

参考资料


分享这篇文章到:

上一篇文章
GORM 实战指南 - Go 语言 ORM 最佳实践
下一篇文章
香港大埔宏福苑五级火灾事故:悲剧与反思