Go Context 上下文实战
context 包是 Go 语言处理并发任务生命周期的核心工具。它在 goroutine 之间传递取消信号、超时信息和请求范围的元数据。本文将深入 context 的实现原理,分享实战中的最佳实践。
一、Context 基础回顾
1.1 核心接口
type Context interface {
// 返回截止时间(若有)
Deadline() (deadline time.Time, ok bool)
// 返回取消信号通道
Done() <-chan struct{}
// 返回取消原因
Err() error
// 根据 key 查找值
Value(key any) any
}
1.2 基本使用
// 创建根上下文
ctx := context.Background()
// 创建可取消上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 创建超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 在 goroutine 中使用
go func() {
select {
case <-ctx.Done():
fmt.Println("收到取消信号")
return
case result := <-doWork():
fmt.Println("完成:", result)
}
}()
二、Context 实现原理
2.1 四种核心类型
context 包通过四种结构体实现所有功能:
// 1. 空上下文(根节点)
type emptyCtx int
// 2. 可取消上下文
type cancelCtx struct {
Context
mu sync.Mutex
done chan struct{}
children map[canceler]struct{}
err error
}
// 3. 超时上下文
type timerCtx struct {
cancelCtx
timer *time.Timer
deadline time.Time
}
// 4. 值上下文
type valueCtx struct {
Context
key, val any
}
2.2 层级关系
context.Background() (emptyCtx)
└── WithCancel() (cancelCtx)
├── WithTimeout() (timerCtx)
│ └── WithValue() (valueCtx)
└── WithValue() (valueCtx)
└── WithValue() (valueCtx)
特点:
- 子上下文继承父上下文的所有特性
- 父上下文取消时,所有子上下文递归取消
- 子上下文取消不影响父上下文
2.3 cancelCtx 核心实现
// cancelCtx 结构
type cancelCtx struct {
Context // 嵌入父上下文
mu sync.Mutex // 保护以下字段
done chan struct{} // 取消信号通道
children map[canceler]struct{} // 子上下文集合
err error // 取消原因
}
// 取消操作
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
c.mu.Lock()
defer c.mu.Unlock()
// 1. 检查是否已取消
if c.err != nil {
return
}
// 2. 设置取消原因
c.err = err
// 3. 关闭 done 通道
if c.done == nil {
c.done = closedchan
} else {
close(c.done)
}
// 4. 递归取消所有子上下文
for child := range c.children {
child.cancel(false, err)
}
c.children = nil
// 5. 从父上下文移除自己
if removeFromParent {
removeFromParent(c.Context, c)
}
}
2.4 取消信号传播
// 创建层级上下文
ctx1, cancel1 := context.WithCancel(context.Background())
ctx2, cancel2 := context.WithCancel(ctx1)
ctx3, cancel3 := context.WithCancel(ctx2)
// 取消父上下文
cancel1()
// 结果:
// - ctx1 被取消
// - ctx2 自动被取消(子上下文)
// - ctx3 自动被取消(孙子上下文)
// - ctx2 和 ctx3 的 Done 通道都会关闭
传播流程:
cancel1()
↓
1. 关闭 ctx1.done 通道
↓
2. 遍历 ctx1.children
├─ ctx2.cancel(false, err)
│ ├─ 关闭 ctx2.done
│ └─ 遍历 ctx2.children
│ └─ ctx3.cancel(false, err)
│ └─ 关闭 ctx3.done
└─ 清理 children
↓
3. 所有监听 Done 的 goroutine 收到信号
三、WithCancel 详解
3.1 创建可取消上下文
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
if parent == nil {
panic("cannot create context from nil parent")
}
// 创建 cancelCtx
c := newCancelCtx(parent)
// 传播取消信号
propagateCancel(parent, c)
return c, func() { c.cancel(true, Canceled) }
}
// 传播取消信号
func propagateCancel(parent Context, child canceler) {
// 1. 检查父上下文是否已取消
if parent.Done() == nil {
return // 父上下文永不会被取消
}
// 2. 检查父上下文是否是 cancelCtx
if p, ok := parentCancelCtx(parent); ok {
if p.err != nil {
// 父已取消,立即取消子
child.cancel(false, p.err)
} else {
// 加入父的 children 集合
p.children[child] = struct{}{}
}
return
}
// 3. 启动 goroutine 监听父上下文
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
case <-child.Done():
}
}()
}
3.2 使用场景
// 场景 1:优雅退出
func worker(ctx context.Context) {
for {
select {
case <-ctx.Done():
fmt.Println("收到退出信号")
return
default:
doWork()
}
}
}
// 启动 worker
ctx, cancel := context.WithCancel(context.Background())
go worker(ctx)
// 需要退出时
cancel()
// 场景 2:取消耗时操作
func processWithCancel(ctx context.Context, data string) error {
resultCh := make(chan string, 1)
go func() {
result, err := heavyComputation(data)
if err != nil {
resultCh <- ""
} else {
resultCh <- result
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case result := <-resultCh:
if result == "" {
return errors.New("计算失败")
}
return nil
}
}
// 场景 3:批量任务取消
func batchProcess(ctx context.Context, items []string) error {
errCh := make(chan error, len(items))
for _, item := range items {
go func(it string) {
err := processItem(ctx, it)
errCh <- err
}(item)
}
// 等待所有任务完成或被取消
for i := 0; i < len(items); i++ {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
if err != nil {
// 第一个错误,取消所有任务
return err
}
}
}
return nil
}
四、WithTimeout 详解
4.1 创建超时上下文
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
// 创建 timerCtx
c := newCancelCtx(parent)
t := &timerCtx{
cancelCtx: c,
deadline: deadline,
}
// 传播取消信号
propagateCancel(parent, t)
// 计算超时时间
d := time.Until(deadline)
if d <= 0 {
// 已经超时,立即取消
t.cancel(true, DeadlineExceeded)
return t, func() { t.cancel(true, Canceled) }
}
// 启动定时器
t.timer = time.AfterFunc(d, func() {
t.cancel(true, DeadlineExceeded)
})
return t, func() { t.cancel(true, Canceled) }
}
4.2 timerCtx 实现
type timerCtx struct {
cancelCtx // 嵌入 cancelCtx
timer *time.Timer
deadline time.Time
}
// 取消操作
func (c *timerCtx) cancel(removeFromParent bool, err error) {
// 1. 先调用 cancelCtx 的取消
c.cancelCtx.cancel(removeFromParent, err)
// 2. 停止定时器
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
}
4.3 使用场景
// 场景 1:HTTP 请求超时
func fetchWithTimeout(url string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// 场景 2:数据库查询超时
func queryWithTimeout(ctx context.Context, db *sql.DB, query string) (*sql.Rows, error) {
// 设置查询超时
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
return rows, nil
}
// 场景 3:RPC 调用超时
func rpcCall(ctx context.Context, client MyClient, req *Request) (*Response, error) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
resp, err := client.Call(ctx, req)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, errors.New("RPC 调用超时")
}
return nil, err
}
return resp, nil
}
4.4 级联超时
// 父上下文超时时间 > 子上下文
ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel1()
// 子上下文超时时间更短
ctx2, cancel2 := context.WithTimeout(ctx1, 5*time.Second)
defer cancel2()
// 结果:
// - 5 秒后 ctx2 超时
// - ctx2 的子上下文也会被取消
// - ctx1 不受影响(还有 5 秒)
五、WithValue 详解
5.1 创建值上下文
func WithValue(parent Context, key, val any) Context {
if key == nil {
panic("nil key")
}
// 创建 valueCtx
return &valueCtx{
Context: parent,
key: key,
val: val,
}
}
// valueCtx 实现
type valueCtx struct {
Context
key, val any
}
// Value 方法(递归向上查找)
func (c *valueCtx) Value(key any) any {
if key == c.key {
return c.val
}
return c.Context.Value(key) // 递归查找父上下文
}
5.2 键类型最佳实践
// ❌ 不推荐:使用 string 作为键(可能冲突)
const userIDKey = "user_id"
ctx := context.WithValue(ctx, userIDKey, 123)
// ✅ 推荐:使用自定义类型作为键
type contextKey string
const UserIDKey contextKey = "user_id"
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, UserIDKey, userID)
}
func UserIDFromContext(ctx context.Context) string {
if userID, ok := ctx.Value(UserIDKey).(string); ok {
return userID
}
return ""
}
5.3 使用场景
// 场景 1:传递请求 ID
type contextKey string
const RequestIDKey contextKey = "request_id"
func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = generateID()
}
// 将 requestID 放入 context
ctx := context.WithValue(r.Context(), RequestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func handler(w http.ResponseWriter, r *http.Request) {
requestID := r.Context().Value(RequestIDKey).(string)
log.Printf("[%s] 处理请求", requestID)
}
// 场景 2:传递用户信息
type User struct {
ID string
Name string
}
type contextKey string
const UserKey contextKey = "user"
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
user, err := validateToken(token)
if err != nil {
http.Error(w, "Unauthorized", 401)
return
}
ctx := context.WithValue(r.Context(), UserKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func getUserFromContext(ctx context.Context) *User {
if user, ok := ctx.Value(UserKey).(*User); ok {
return user
}
return nil
}
// 场景 3:传递数据库事务
func handleTransaction(ctx context.Context, db *sql.DB) error {
// 开启事务
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
// 将事务放入 context
ctx = context.WithValue(ctx, txKey, tx)
// 执行多个操作
if err := doSomething(ctx); err != nil {
tx.Rollback()
return err
}
if err := doAnotherThing(ctx); err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
func getTxFromContext(ctx context.Context) *sql.Tx {
if tx, ok := ctx.Value(txKey).(*sql.Tx); ok {
return tx
}
return nil
}
5.4 注意事项
// ❌ 不要传递大量数据
ctx := context.WithValue(ctx, "data", largeData) // 不推荐
// ✅ 只传递请求范围的元数据
ctx := context.WithValue(ctx, RequestIDKey, requestID)
// ❌ 不要用 context 传递函数参数
func process(ctx context.Context, data string) {
// 不推荐:从 context 获取参数
data := ctx.Value(DataKey).(string)
}
// ✅ 参数应该显式传递
func process(ctx context.Context, data string) {
// 推荐:显式参数
}
六、实战模式
6.1 优雅关闭
type Server struct {
ctx context.Context
cancel context.CancelFunc
}
func NewServer() *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{ctx: ctx, cancel: cancel}
}
func (s *Server) Start() {
// 启动 worker
for i := 0; i < 10; i++ {
go s.worker(i)
}
}
func (s *Server) worker(id int) {
for {
select {
case <-s.ctx.Done():
log.Printf("Worker %d 收到退出信号", id)
return
default:
// 处理任务
}
}
}
func (s *Server) Stop() {
s.cancel() // 通知所有 worker 退出
}
// 使用
server := NewServer()
server.Start()
// 监听退出信号
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
// 优雅关闭
server.Stop()
6.2 超时重试
func doWithRetry(ctx context.Context, fn func(context.Context) error, maxRetries int) error {
var lastErr error
for i := 0; i < maxRetries; i++ {
// 创建带超时的子上下文
retryCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
// 执行操作
err := fn(retryCtx)
cancel()
if err == nil {
return nil // 成功
}
lastErr = err
// 检查是否可重试
if errors.Is(err, context.Canceled) {
return err // 被取消,不重试
}
// 等待一段时间后重试
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second * time.Duration(i+1)):
}
}
return lastErr
}
// 使用
err := doWithRetry(ctx, func(ctx context.Context) error {
return fetchURL(ctx, "https://api.example.com/data")
}, 3)
6.3 并行任务控制
// 控制并发数量
func processWithLimit(ctx context.Context, items []string, limit int) error {
sem := make(chan struct{}, limit)
errCh := make(chan error, len(items))
for _, item := range items {
sem <- struct{}{} // 获取令牌
go func(it string) {
defer func() { <-sem }() // 释放令牌
select {
case <-ctx.Done():
errCh <- ctx.Err()
default:
err := processItem(ctx, it)
errCh <- err
}
}(item)
}
// 等待所有任务完成
for i := 0; i < len(items); i++ {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
if err != nil {
return err
}
}
}
return nil
}
6.4 管道模式
// 阶段 1:生成数据
func generator(ctx context.Context, n int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for i := 0; i < n; i++ {
select {
case <-ctx.Done():
return
case out <- i:
}
}
}()
return out
}
// 阶段 2:处理数据
func processor(ctx context.Context, in <-chan int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for n := range in {
select {
case <-ctx.Done():
return
case out <- n * 2:
}
}
}()
return out
}
// 阶段 3:输出结果
func printer(ctx context.Context, in <-chan int) {
for n := range in {
select {
case <-ctx.Done():
return
default:
fmt.Println(n)
}
}
}
// 使用
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch1 := generator(ctx, 10)
ch2 := processor(ctx, ch1)
printer(ctx, ch2)
6.5 扇出扇入
// 扇出:一个输入,多个 worker 处理
func fanOut(ctx context.Context, in <-chan int, n int) []<-chan int {
outs := make([]<-chan int, n)
for i := 0; i < n; i++ {
out := make(chan int)
go func(out chan<- int) {
defer close(out)
for n := range in {
select {
case <-ctx.Done():
return
case out <- n:
}
}
}(out)
outs[i] = out
}
return outs
}
// 扇入:多个输入,合并到一个输出
func fanIn(ctx context.Context, ins ...<-chan int) <-chan int {
out := make(chan int)
var wg sync.WaitGroup
for _, in := range ins {
wg.Add(1)
go func(in <-chan int) {
defer wg.Done()
for n := range in {
select {
case <-ctx.Done():
return
case out <- n:
}
}
}(in)
}
go func() {
wg.Wait()
close(out)
}()
return out
}
七、常见问题
7.1 忘记调用 cancel
// ❌ 泄漏:未调用 cancel
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
// cancel() 未调用,timer 和 goroutine 泄漏
// ✅ 修复:使用 defer
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
7.2 传递 nil context
// ❌ 错误:传递 nil
func process(ctx context.Context) {
if ctx == nil {
// 可能 panic
}
}
// ✅ 正确:使用 Background
process(context.Background())
7.3 在结构体中存储 context
// ❌ 不推荐
type Server struct {
ctx context.Context // 不应该存储
}
// ✅ 推荐:作为参数传递
type Server struct{}
func (s *Server) Start(ctx context.Context) {
// 使用 ctx
}
7.4 滥用 WithValue
// ❌ 不推荐:传递业务数据
ctx := context.WithValue(ctx, "user", user)
ctx = context.WithValue(ctx, "data", data)
// ✅ 推荐:只传递元数据
ctx := context.WithValue(ctx, RequestIDKey, requestID)
ctx = context.WithValue(ctx, UserKey, userID)
// 业务数据应该作为参数传递
func process(ctx context.Context, user *User, data string) {
// ...
}
八、最佳实践总结
8.1 使用原则
| 原则 | 说明 |
|---|---|
| 第一个参数 | ctx 应该是函数的第一个参数 |
| 不要存储 | 不要在结构体中存储 context |
| 传递 nil | 即使不使用,也传 context.Background() |
| 及时 cancel | 创建后确保调用 cancel() |
| 明确键类型 | 使用自定义类型作为 Value 的键 |
8.2 选择指南
| 场景 | 选择 |
|---|---|
| 根上下文 | context.Background() |
| 不确定 | context.TODO() |
| 需要取消 | WithCancel(parent) |
| 需要超时 | WithTimeout(parent, d) |
| 需要截止时间 | WithDeadline(parent, t) |
| 传递元数据 | WithValue(parent, key, val) |
8.3 性能考虑
// Context 链过长影响性能
ctx := context.Background()
ctx = context.WithValue(ctx, k1, v1)
ctx = context.WithValue(ctx, k2, v2)
ctx = context.WithValue(ctx, k3, v3)
// ... 继续嵌套
// Value 查找是 O(n) 复杂度
// 建议:只传递必要的元数据(3-5 个以内)
总结
Context 是 Go 并发编程的核心工具:
| 类型 | 用途 | 关键方法 |
|---|---|---|
| emptyCtx | 根上下文 | 无 |
| cancelCtx | 可取消 | cancel() |
| timerCtx | 超时控制 | WithTimeout() |
| valueCtx | 值传递 | WithValue() |
核心机制:
- 层级派生,父死子亡
- 取消信号递归传播
- Done 通道关闭通知
- Value 递归向上查找
掌握 Context,能帮助你:
- 优雅地控制 goroutine 生命周期
- 实现超时和取消逻辑
- 传递请求范围的元数据
- 避免 goroutine 泄漏