Skip to content
清晨的一缕阳光
返回

Go Context 上下文实战

Go Context 上下文实战

context 包是 Go 语言处理并发任务生命周期的核心工具。它在 goroutine 之间传递取消信号、超时信息和请求范围的元数据。本文将深入 context 的实现原理,分享实战中的最佳实践。

一、Context 基础回顾

1.1 核心接口

type Context interface {
    // 返回截止时间(若有)
    Deadline() (deadline time.Time, ok bool)
    
    // 返回取消信号通道
    Done() <-chan struct{}
    
    // 返回取消原因
    Err() error
    
    // 根据 key 查找值
    Value(key any) any
}

1.2 基本使用

// 创建根上下文
ctx := context.Background()

// 创建可取消上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// 创建超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// 在 goroutine 中使用
go func() {
    select {
    case <-ctx.Done():
        fmt.Println("收到取消信号")
        return
    case result := <-doWork():
        fmt.Println("完成:", result)
    }
}()

二、Context 实现原理

2.1 四种核心类型

context 包通过四种结构体实现所有功能:

// 1. 空上下文(根节点)
type emptyCtx int

// 2. 可取消上下文
type cancelCtx struct {
    Context
    mu       sync.Mutex
    done     chan struct{}
    children map[canceler]struct{}
    err      error
}

// 3. 超时上下文
type timerCtx struct {
    cancelCtx
    timer    *time.Timer
    deadline time.Time
}

// 4. 值上下文
type valueCtx struct {
    Context
    key, val any
}

2.2 层级关系

context.Background() (emptyCtx)
    └── WithCancel() (cancelCtx)
        ├── WithTimeout() (timerCtx)
        │   └── WithValue() (valueCtx)
        └── WithValue() (valueCtx)
            └── WithValue() (valueCtx)

特点:
- 子上下文继承父上下文的所有特性
- 父上下文取消时,所有子上下文递归取消
- 子上下文取消不影响父上下文

2.3 cancelCtx 核心实现

// cancelCtx 结构
type cancelCtx struct {
    Context          // 嵌入父上下文
    
    mu       sync.Mutex              // 保护以下字段
    done     chan struct{}           // 取消信号通道
    children map[canceler]struct{}   // 子上下文集合
    err      error                   // 取消原因
}

// 取消操作
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
    c.mu.Lock()
    defer c.mu.Unlock()
    
    // 1. 检查是否已取消
    if c.err != nil {
        return
    }
    
    // 2. 设置取消原因
    c.err = err
    
    // 3. 关闭 done 通道
    if c.done == nil {
        c.done = closedchan
    } else {
        close(c.done)
    }
    
    // 4. 递归取消所有子上下文
    for child := range c.children {
        child.cancel(false, err)
    }
    c.children = nil
    
    // 5. 从父上下文移除自己
    if removeFromParent {
        removeFromParent(c.Context, c)
    }
}

2.4 取消信号传播

// 创建层级上下文
ctx1, cancel1 := context.WithCancel(context.Background())
ctx2, cancel2 := context.WithCancel(ctx1)
ctx3, cancel3 := context.WithCancel(ctx2)

// 取消父上下文
cancel1()

// 结果:
// - ctx1 被取消
// - ctx2 自动被取消(子上下文)
// - ctx3 自动被取消(孙子上下文)
// - ctx2 和 ctx3 的 Done 通道都会关闭

传播流程

cancel1()

1. 关闭 ctx1.done 通道

2. 遍历 ctx1.children
    ├─ ctx2.cancel(false, err)
    │   ├─ 关闭 ctx2.done
    │   └─ 遍历 ctx2.children
    │       └─ ctx3.cancel(false, err)
    │           └─ 关闭 ctx3.done
    └─ 清理 children

3. 所有监听 Done 的 goroutine 收到信号

三、WithCancel 详解

3.1 创建可取消上下文

func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
    if parent == nil {
        panic("cannot create context from nil parent")
    }
    
    // 创建 cancelCtx
    c := newCancelCtx(parent)
    
    // 传播取消信号
    propagateCancel(parent, c)
    
    return c, func() { c.cancel(true, Canceled) }
}

// 传播取消信号
func propagateCancel(parent Context, child canceler) {
    // 1. 检查父上下文是否已取消
    if parent.Done() == nil {
        return  // 父上下文永不会被取消
    }
    
    // 2. 检查父上下文是否是 cancelCtx
    if p, ok := parentCancelCtx(parent); ok {
        if p.err != nil {
            // 父已取消,立即取消子
            child.cancel(false, p.err)
        } else {
            // 加入父的 children 集合
            p.children[child] = struct{}{}
        }
        return
    }
    
    // 3. 启动 goroutine 监听父上下文
    go func() {
        select {
        case <-parent.Done():
            child.cancel(false, parent.Err())
        case <-child.Done():
        }
    }()
}

3.2 使用场景

// 场景 1:优雅退出
func worker(ctx context.Context) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("收到退出信号")
            return
        default:
            doWork()
        }
    }
}

// 启动 worker
ctx, cancel := context.WithCancel(context.Background())
go worker(ctx)

// 需要退出时
cancel()
// 场景 2:取消耗时操作
func processWithCancel(ctx context.Context, data string) error {
    resultCh := make(chan string, 1)
    
    go func() {
        result, err := heavyComputation(data)
        if err != nil {
            resultCh <- ""
        } else {
            resultCh <- result
        }
    }()
    
    select {
    case <-ctx.Done():
        return ctx.Err()
    case result := <-resultCh:
        if result == "" {
            return errors.New("计算失败")
        }
        return nil
    }
}
// 场景 3:批量任务取消
func batchProcess(ctx context.Context, items []string) error {
    errCh := make(chan error, len(items))
    
    for _, item := range items {
        go func(it string) {
            err := processItem(ctx, it)
            errCh <- err
        }(item)
    }
    
    // 等待所有任务完成或被取消
    for i := 0; i < len(items); i++ {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case err := <-errCh:
            if err != nil {
                // 第一个错误,取消所有任务
                return err
            }
        }
    }
    
    return nil
}

四、WithTimeout 详解

4.1 创建超时上下文

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
    return WithDeadline(parent, time.Now().Add(timeout))
}

func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
    // 创建 timerCtx
    c := newCancelCtx(parent)
    t := &timerCtx{
        cancelCtx: c,
        deadline:  deadline,
    }
    
    // 传播取消信号
    propagateCancel(parent, t)
    
    // 计算超时时间
    d := time.Until(deadline)
    if d <= 0 {
        // 已经超时,立即取消
        t.cancel(true, DeadlineExceeded)
        return t, func() { t.cancel(true, Canceled) }
    }
    
    // 启动定时器
    t.timer = time.AfterFunc(d, func() {
        t.cancel(true, DeadlineExceeded)
    })
    
    return t, func() { t.cancel(true, Canceled) }
}

4.2 timerCtx 实现

type timerCtx struct {
    cancelCtx          // 嵌入 cancelCtx
    timer    *time.Timer
    deadline time.Time
}

// 取消操作
func (c *timerCtx) cancel(removeFromParent bool, err error) {
    // 1. 先调用 cancelCtx 的取消
    c.cancelCtx.cancel(removeFromParent, err)
    
    // 2. 停止定时器
    if c.timer != nil {
        c.timer.Stop()
        c.timer = nil
    }
}

4.3 使用场景

// 场景 1:HTTP 请求超时
func fetchWithTimeout(url string, timeout time.Duration) ([]byte, error) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()
    
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return nil, err
    }
    
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    
    return io.ReadAll(resp.Body)
}
// 场景 2:数据库查询超时
func queryWithTimeout(ctx context.Context, db *sql.DB, query string) (*sql.Rows, error) {
    // 设置查询超时
    ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
    defer cancel()
    
    rows, err := db.QueryContext(ctx, query)
    if err != nil {
        return nil, err
    }
    
    return rows, nil
}
// 场景 3:RPC 调用超时
func rpcCall(ctx context.Context, client MyClient, req *Request) (*Response, error) {
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()
    
    resp, err := client.Call(ctx, req)
    if err != nil {
        if errors.Is(err, context.DeadlineExceeded) {
            return nil, errors.New("RPC 调用超时")
        }
        return nil, err
    }
    
    return resp, nil
}

4.4 级联超时

// 父上下文超时时间 > 子上下文
ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel1()

// 子上下文超时时间更短
ctx2, cancel2 := context.WithTimeout(ctx1, 5*time.Second)
defer cancel2()

// 结果:
// - 5 秒后 ctx2 超时
// - ctx2 的子上下文也会被取消
// - ctx1 不受影响(还有 5 秒)

五、WithValue 详解

5.1 创建值上下文

func WithValue(parent Context, key, val any) Context {
    if key == nil {
        panic("nil key")
    }
    
    // 创建 valueCtx
    return &valueCtx{
        Context: parent,
        key:     key,
        val:     val,
    }
}

// valueCtx 实现
type valueCtx struct {
    Context
    key, val any
}

// Value 方法(递归向上查找)
func (c *valueCtx) Value(key any) any {
    if key == c.key {
        return c.val
    }
    return c.Context.Value(key)  // 递归查找父上下文
}

5.2 键类型最佳实践

// ❌ 不推荐:使用 string 作为键(可能冲突)
const userIDKey = "user_id"
ctx := context.WithValue(ctx, userIDKey, 123)

// ✅ 推荐:使用自定义类型作为键
type contextKey string

const UserIDKey contextKey = "user_id"

func WithUserID(ctx context.Context, userID string) context.Context {
    return context.WithValue(ctx, UserIDKey, userID)
}

func UserIDFromContext(ctx context.Context) string {
    if userID, ok := ctx.Value(UserIDKey).(string); ok {
        return userID
    }
    return ""
}

5.3 使用场景

// 场景 1:传递请求 ID
type contextKey string

const RequestIDKey contextKey = "request_id"

func middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        requestID := r.Header.Get("X-Request-ID")
        if requestID == "" {
            requestID = generateID()
        }
        
        // 将 requestID 放入 context
        ctx := context.WithValue(r.Context(), RequestIDKey, requestID)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func handler(w http.ResponseWriter, r *http.Request) {
    requestID := r.Context().Value(RequestIDKey).(string)
    log.Printf("[%s] 处理请求", requestID)
}
// 场景 2:传递用户信息
type User struct {
    ID   string
    Name string
}

type contextKey string

const UserKey contextKey = "user"

func authMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        token := r.Header.Get("Authorization")
        user, err := validateToken(token)
        if err != nil {
            http.Error(w, "Unauthorized", 401)
            return
        }
        
        ctx := context.WithValue(r.Context(), UserKey, user)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func getUserFromContext(ctx context.Context) *User {
    if user, ok := ctx.Value(UserKey).(*User); ok {
        return user
    }
    return nil
}
// 场景 3:传递数据库事务
func handleTransaction(ctx context.Context, db *sql.DB) error {
    // 开启事务
    tx, err := db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }
    
    // 将事务放入 context
    ctx = context.WithValue(ctx, txKey, tx)
    
    // 执行多个操作
    if err := doSomething(ctx); err != nil {
        tx.Rollback()
        return err
    }
    
    if err := doAnotherThing(ctx); err != nil {
        tx.Rollback()
        return err
    }
    
    return tx.Commit()
}

func getTxFromContext(ctx context.Context) *sql.Tx {
    if tx, ok := ctx.Value(txKey).(*sql.Tx); ok {
        return tx
    }
    return nil
}

5.4 注意事项

// ❌ 不要传递大量数据
ctx := context.WithValue(ctx, "data", largeData)  // 不推荐

// ✅ 只传递请求范围的元数据
ctx := context.WithValue(ctx, RequestIDKey, requestID)

// ❌ 不要用 context 传递函数参数
func process(ctx context.Context, data string) {
    // 不推荐:从 context 获取参数
    data := ctx.Value(DataKey).(string)
}

// ✅ 参数应该显式传递
func process(ctx context.Context, data string) {
    // 推荐:显式参数
}

六、实战模式

6.1 优雅关闭

type Server struct {
    ctx    context.Context
    cancel context.CancelFunc
}

func NewServer() *Server {
    ctx, cancel := context.WithCancel(context.Background())
    return &Server{ctx: ctx, cancel: cancel}
}

func (s *Server) Start() {
    // 启动 worker
    for i := 0; i < 10; i++ {
        go s.worker(i)
    }
}

func (s *Server) worker(id int) {
    for {
        select {
        case <-s.ctx.Done():
            log.Printf("Worker %d 收到退出信号", id)
            return
        default:
            // 处理任务
        }
    }
}

func (s *Server) Stop() {
    s.cancel()  // 通知所有 worker 退出
}

// 使用
server := NewServer()
server.Start()

// 监听退出信号
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh

// 优雅关闭
server.Stop()

6.2 超时重试

func doWithRetry(ctx context.Context, fn func(context.Context) error, maxRetries int) error {
    var lastErr error
    
    for i := 0; i < maxRetries; i++ {
        // 创建带超时的子上下文
        retryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
        
        // 执行操作
        err := fn(retryCtx)
        cancel()
        
        if err == nil {
            return nil  // 成功
        }
        
        lastErr = err
        
        // 检查是否可重试
        if errors.Is(err, context.Canceled) {
            return err  // 被取消,不重试
        }
        
        // 等待一段时间后重试
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-time.After(time.Second * time.Duration(i+1)):
        }
    }
    
    return lastErr
}

// 使用
err := doWithRetry(ctx, func(ctx context.Context) error {
    return fetchURL(ctx, "https://api.example.com/data")
}, 3)

6.3 并行任务控制

// 控制并发数量
func processWithLimit(ctx context.Context, items []string, limit int) error {
    sem := make(chan struct{}, limit)
    errCh := make(chan error, len(items))
    
    for _, item := range items {
        sem <- struct{}{}  // 获取令牌
        
        go func(it string) {
            defer func() { <-sem }()  // 释放令牌
            
            select {
            case <-ctx.Done():
                errCh <- ctx.Err()
            default:
                err := processItem(ctx, it)
                errCh <- err
            }
        }(item)
    }
    
    // 等待所有任务完成
    for i := 0; i < len(items); i++ {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case err := <-errCh:
            if err != nil {
                return err
            }
        }
    }
    
    return nil
}

6.4 管道模式

// 阶段 1:生成数据
func generator(ctx context.Context, n int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for i := 0; i < n; i++ {
            select {
            case <-ctx.Done():
                return
            case out <- i:
            }
        }
    }()
    return out
}

// 阶段 2:处理数据
func processor(ctx context.Context, in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for n := range in {
            select {
            case <-ctx.Done():
                return
            case out <- n * 2:
            }
        }
    }()
    return out
}

// 阶段 3:输出结果
func printer(ctx context.Context, in <-chan int) {
    for n := range in {
        select {
        case <-ctx.Done():
            return
        default:
            fmt.Println(n)
        }
    }
}

// 使用
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch1 := generator(ctx, 10)
ch2 := processor(ctx, ch1)
printer(ctx, ch2)

6.5 扇出扇入

// 扇出:一个输入,多个 worker 处理
func fanOut(ctx context.Context, in <-chan int, n int) []<-chan int {
    outs := make([]<-chan int, n)
    for i := 0; i < n; i++ {
        out := make(chan int)
        go func(out chan<- int) {
            defer close(out)
            for n := range in {
                select {
                case <-ctx.Done():
                    return
                case out <- n:
                }
            }
        }(out)
        outs[i] = out
    }
    return outs
}

// 扇入:多个输入,合并到一个输出
func fanIn(ctx context.Context, ins ...<-chan int) <-chan int {
    out := make(chan int)
    var wg sync.WaitGroup
    
    for _, in := range ins {
        wg.Add(1)
        go func(in <-chan int) {
            defer wg.Done()
            for n := range in {
                select {
                case <-ctx.Done():
                    return
                case out <- n:
                }
            }
        }(in)
    }
    
    go func() {
        wg.Wait()
        close(out)
    }()
    
    return out
}

七、常见问题

7.1 忘记调用 cancel

// ❌ 泄漏:未调用 cancel
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
// cancel() 未调用,timer 和 goroutine 泄漏

// ✅ 修复:使用 defer
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()

7.2 传递 nil context

// ❌ 错误:传递 nil
func process(ctx context.Context) {
    if ctx == nil {
        // 可能 panic
    }
}

// ✅ 正确:使用 Background
process(context.Background())

7.3 在结构体中存储 context

// ❌ 不推荐
type Server struct {
    ctx context.Context  // 不应该存储
}

// ✅ 推荐:作为参数传递
type Server struct{}

func (s *Server) Start(ctx context.Context) {
    // 使用 ctx
}

7.4 滥用 WithValue

// ❌ 不推荐:传递业务数据
ctx := context.WithValue(ctx, "user", user)
ctx = context.WithValue(ctx, "data", data)

// ✅ 推荐:只传递元数据
ctx := context.WithValue(ctx, RequestIDKey, requestID)
ctx = context.WithValue(ctx, UserKey, userID)

// 业务数据应该作为参数传递
func process(ctx context.Context, user *User, data string) {
    // ...
}

八、最佳实践总结

8.1 使用原则

原则说明
第一个参数ctx 应该是函数的第一个参数
不要存储不要在结构体中存储 context
传递 nil即使不使用,也传 context.Background()
及时 cancel创建后确保调用 cancel()
明确键类型使用自定义类型作为 Value 的键

8.2 选择指南

场景选择
根上下文context.Background()
不确定context.TODO()
需要取消WithCancel(parent)
需要超时WithTimeout(parent, d)
需要截止时间WithDeadline(parent, t)
传递元数据WithValue(parent, key, val)

8.3 性能考虑

// Context 链过长影响性能
ctx := context.Background()
ctx = context.WithValue(ctx, k1, v1)
ctx = context.WithValue(ctx, k2, v2)
ctx = context.WithValue(ctx, k3, v3)
// ... 继续嵌套

// Value 查找是 O(n) 复杂度
// 建议:只传递必要的元数据(3-5 个以内)

总结

Context 是 Go 并发编程的核心工具:

类型用途关键方法
emptyCtx根上下文
cancelCtx可取消cancel()
timerCtx超时控制WithTimeout()
valueCtx值传递WithValue()

核心机制

掌握 Context,能帮助你:

参考资料


分享这篇文章到:

上一篇文章
Java 日期时间 API 详解
下一篇文章
事务 ACID 特性详解