diff --git a/client/v3/lease.go b/client/v3/lease.go index 4e7d1caf831..4877ee94962 100644 --- a/client/v3/lease.go +++ b/client/v3/lease.go @@ -263,6 +263,12 @@ func (l *lessor) Leases(ctx context.Context) (*LeaseLeasesResponse, error) { return nil, ContextError(ctx, err) } +// To identify the context passed to `KeepAlive`, a key/value pair is +// attached to the context. The key is a `keepAliveCtxKey` object, and +// the value is the pointer to the context object itself, ensuring +// uniqueness as each context has a unique memory address. +type keepAliveCtxKey struct{} + func (l *lessor) KeepAlive(ctx context.Context, id LeaseID) (<-chan *LeaseKeepAliveResponse, error) { ch := make(chan *LeaseKeepAliveResponse, LeaseResponseChSize) @@ -277,6 +283,10 @@ func (l *lessor) KeepAlive(ctx context.Context, id LeaseID) (<-chan *LeaseKeepAl default: } ka, ok := l.keepAlives[id] + + if ctx.Done() != nil { + ctx = context.WithValue(ctx, keepAliveCtxKey{}, &ctx) + } if !ok { // create fresh keep alive ka = &keepAlive{ @@ -347,7 +357,7 @@ func (l *lessor) keepAliveCtxCloser(ctx context.Context, id LeaseID, donec <-cha // close channel and remove context if still associated with keep alive for i, c := range ka.ctxs { - if c == ctx { + if c.Value(keepAliveCtxKey{}) == ctx.Value(keepAliveCtxKey{}) { close(ka.chs[i]) ka.ctxs = append(ka.ctxs[:i], ka.ctxs[i+1:]...) ka.chs = append(ka.chs[:i], ka.chs[i+1:]...) diff --git a/tests/integration/clientv3/lease/lease_test.go b/tests/integration/clientv3/lease/lease_test.go index ef7065eb592..cd783952dcd 100644 --- a/tests/integration/clientv3/lease/lease_test.go +++ b/tests/integration/clientv3/lease/lease_test.go @@ -133,7 +133,14 @@ func TestLeaseKeepAlive(t *testing.T) { t.Errorf("failed to create lease %v", err) } - rc, kerr := lapi.KeepAlive(context.Background(), resp.ID) + type uncomparableCtx struct { + context.Context + _ func() + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rc, kerr := lapi.KeepAlive(uncomparableCtx{Context: ctx}, resp.ID) if kerr != nil { t.Errorf("failed to keepalive lease %v", kerr) } @@ -151,6 +158,26 @@ func TestLeaseKeepAlive(t *testing.T) { t.Errorf("ID = %x, want %x", kresp.ID, resp.ID) } + ctx2, cancel2 := context.WithCancel(context.Background()) + rc2, kerr2 := lapi.KeepAlive(uncomparableCtx{Context: ctx2}, resp.ID) + if kerr2 != nil { + t.Errorf("failed to keepalive lease %v", kerr2) + } + + cancel2() + + _, ok = <-rc2 + if ok { + t.Errorf("chan is not closed, want cancel stop keepalive") + } + + select { + case <-rc: + // cancel2() should not affect first keepalive + t.Errorf("chan is closed, want keepalive continue") + default: + } + lapi.Close() _, ok = <-rc