Golang: Pattern: Close Once

 3rd November 2022 at 11:00am

场景:一个类提供了 Close() 方法供外部调用,用于释放资源;资源只能释放一次。问如何实现这样的 Close() 函数?

使用 channel

Sarama.consumerGroup 提供了一个不错的例子:

type consumerGroup struct {
	// 去除无关代码

	closed    chan none
	closeOnce sync.Once
}

// Close implements ConsumerGroup.
func (c *consumerGroup) Close() (err error) {
	c.closeOnce.Do(func() {
		close(c.closed)

		// leave group
		if e := c.leave(); e != nil {
			err = e
		}

		// drain errors
		go func() {
			close(c.errors)
		}()
		for e := range c.errors {
			err = e
		}

		if e := c.client.Close(); e != nil {
			err = e
		}
	})
	return
}

func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler ConsumerGroupHandler) error {
	// Ensure group is not closed
	select {
	case <-c.closed:
		return ErrClosedConsumerGroup
	default:
	}

	// 去除无关代码
}

重点在于 closedcloseOnce 两个变量。closed 是一个 chan struct{}nonestruct{})。Close 的几个实现点:

  • 关闭了 closed channel。这使得后面调用 Consume() 时,会在开头的 select 语句处退出。这里利用 channel 巧妙地实现一个并发安全的关闭判断。
  • 使用了 closeOnce,使关闭逻辑只做一次。
  • 使用返回变量 err 来返回错误;如果有多个错误,后面的错误覆盖前面的。
  • 没有使用 sync.Mutex 加锁;有些其他场景可能要加锁。

使用 atomic

github.com/go-redis/redis 库有一个例子:

type ConnPool struct {
	// ...

	_closed uint32 // atomic
}

func (p *ConnPool) closed() bool {
	return atomic.LoadUint32(&p._closed) == 1
}

func (p *ConnPool) Close() error {
	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
		return ErrClosed
	}

	// ...
}

func (p *ConnPool) Get() (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}

	// ...
}

但使用这种方式要多加注意atomic 只能保证 p._close 变量的原子读写,但是不能保证 Close 过程阻塞其他读操作。试想下面的调用顺序:

func (p *ConnPool) Close() error {
	// 执行顺序 2 开始
	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
		return ErrClosed
	}

	p.conn.Close()
	// 执行顺序 2 结束
}

func (p *ConnPool) Get() (*Conn, error) {
	// 执行顺序 1 开始
	if p.closed() {
		return nil, ErrClosed
	}
	// 执行顺序 1 结束

	// 执行顺序 3 开始
	p.conn.Get(...)
	// 执行顺序 3 结束
}

此时执行到第 3 块代码时,p.conn 实际上已经被关闭了,这里调用其 Get 就可能出错。

使用 mutex

使用锁的版本是我自己写的,比较挫:

alImpl struct {
	// ...
	mu     sync.RWMutex
	closed bool
}

func (c *SomeClass) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	// ...
}

func (c *SomeClass) Get() error {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.closed {
		return nil, fmt.Errorf("is closed")
	}

	// ...
}