diff --git a/core/discov/internal/registry.go b/core/discov/internal/registry.go index 9fcb5aabbde3..e1dce491d077 100644 --- a/core/discov/internal/registry.go +++ b/core/discov/internal/registry.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "slices" "sort" "strings" "sync" @@ -59,6 +60,17 @@ func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exa return c.monitor(key, l, exactMatch) } +// Unmonitor cancel monitoring of given endpoints and keys, and remove the listener. +func (r *Registry) Unmonitor(endpoints []string, key string, l UpdateListener) { + c, exists := r.getCluster(endpoints) + // if not exists, return. + if !exists { + return + } + + c.unmonitor(key, l) +} + func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) { clusterKey := getClusterKey(endpoints) r.lock.RLock() @@ -273,6 +285,14 @@ func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error { return nil } +func (c *cluster) unmonitor(key string, l UpdateListener) { + c.lock.Lock() + defer c.lock.Unlock() + c.listeners[key] = slices.DeleteFunc(c.listeners[key], func(listener UpdateListener) bool { + return l == listener + }) +} + func (c *cluster) newClient() (EtcdClient, error) { cli, err := NewClient(c.endpoints) if err != nil { diff --git a/core/discov/internal/registry_test.go b/core/discov/internal/registry_test.go index bb9fd629b3bd..3c78a2b2d226 100644 --- a/core/discov/internal/registry_test.go +++ b/core/discov/internal/registry_test.go @@ -292,6 +292,31 @@ func TestRegistry_Monitor(t *testing.T) { assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false)) } +func TestRegistry_Unmonitor(t *testing.T) { + svr, err := mockserver.StartMockServers(1) + assert.NoError(t, err) + svr.StartAt(0) + + endpoints := []string{svr.Servers[0].Address} + GetRegistry().lock.Lock() + GetRegistry().clusters = map[string]*cluster{ + getClusterKey(endpoints): { + listeners: map[string][]UpdateListener{}, + values: map[string]map[string]string{ + "foo": { + "bar": "baz", + }, + }, + }, + } + GetRegistry().lock.Unlock() + l := new(mockListener) + assert.Error(t, GetRegistry().Monitor(endpoints, "foo", l, false)) + assert.Equal(t, 1, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"])) + GetRegistry().Unmonitor(endpoints, "foo", l) + assert.Equal(t, 0, len(GetRegistry().clusters[getClusterKey(endpoints)].listeners["foo"])) +} + type mockListener struct { } diff --git a/core/discov/subscriber.go b/core/discov/subscriber.go index 08f89a601ff7..cbf1c69ddf12 100644 --- a/core/discov/subscriber.go +++ b/core/discov/subscriber.go @@ -4,6 +4,7 @@ import ( "sync" "sync/atomic" + "github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/discov/internal" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/syncx" @@ -16,6 +17,7 @@ type ( // A Subscriber is used to subscribe the given key on an etcd cluster. Subscriber struct { endpoints []string + key string exclusive bool exactMatch bool items *container @@ -52,6 +54,11 @@ func (s *Subscriber) Values() []string { return s.items.getValues() } +// Close s. +func (s *Subscriber) Close() { + internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.items) +} + // Exclusive means that key value can only be 1:1, // which means later added value will remove the keys associated with the same value previously. func Exclusive() SubOption { @@ -83,7 +90,7 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo type container struct { exclusive bool - values map[string][]string + values map[string]*collection.Set mapping map[string]string snapshot atomic.Value dirty *syncx.AtomicBool @@ -94,7 +101,7 @@ type container struct { func newContainer(exclusive bool) *container { return &container{ exclusive: exclusive, - values: make(map[string][]string), + values: make(map[string]*collection.Set), mapping: make(map[string]string), dirty: syncx.ForAtomicBool(true), } @@ -116,15 +123,21 @@ func (c *container) addKv(key, value string) ([]string, bool) { defer c.lock.Unlock() c.dirty.Set(true) - keys := c.values[value] + if c.values[value] == nil { + c.values[value] = collection.NewSet() + } + keys := c.values[value].KeysStr() previous := append([]string(nil), keys...) early := len(keys) > 0 if c.exclusive && early { for _, each := range keys { c.doRemoveKey(each) } + if c.values[value] == nil { + c.values[value] = collection.NewSet() + } } - c.values[value] = append(c.values[value], key) + c.values[value].AddStr(key) c.mapping[key] = value if early { @@ -147,18 +160,12 @@ func (c *container) doRemoveKey(key string) { } delete(c.mapping, key) - keys := c.values[server] - remain := keys[:0] - - for _, k := range keys { - if k != key { - remain = append(remain, k) - } + if c.values[server] == nil { + return } + c.values[server].Remove(key) - if len(remain) > 0 { - c.values[server] = remain - } else { + if c.values[server].Count() == 0 { delete(c.values, server) } } diff --git a/zrpc/resolver/internal/discovbuilder.go b/zrpc/resolver/internal/discovbuilder.go index ed377d138fa1..5a91ee73e40c 100644 --- a/zrpc/resolver/internal/discovbuilder.go +++ b/zrpc/resolver/internal/discovbuilder.go @@ -38,7 +38,7 @@ func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ sub.AddListener(update) update() - return &nopResolver{cc: cc}, nil + return &nopResolver{cc: cc, closeFunc: func() { sub.Close() }}, nil } func (b *discovBuilder) Scheme() string { diff --git a/zrpc/resolver/internal/resolver.go b/zrpc/resolver/internal/resolver.go index 7868eca8cecf..e04d65d875c0 100644 --- a/zrpc/resolver/internal/resolver.go +++ b/zrpc/resolver/internal/resolver.go @@ -37,10 +37,14 @@ func register() { } type nopResolver struct { - cc resolver.ClientConn + cc resolver.ClientConn + closeFunc func() } func (r *nopResolver) Close() { + if r.closeFunc != nil { + r.closeFunc() + } } func (r *nopResolver) ResolveNow(_ resolver.ResolveNowOptions) { diff --git a/zrpc/resolver/internal/resolver_test.go b/zrpc/resolver/internal/resolver_test.go index 7dd10ee79d10..387997262bcb 100644 --- a/zrpc/resolver/internal/resolver_test.go +++ b/zrpc/resolver/internal/resolver_test.go @@ -18,6 +18,20 @@ func TestNopResolver(t *testing.T) { }) } +func TestNopResolver_Close(t *testing.T) { + var isChanged bool + r := nopResolver{} + r.Close() + assert.False(t, isChanged) + r = nopResolver{ + closeFunc: func() { + isChanged = true + }, + } + r.Close() + assert.True(t, isChanged) +} + type mockedClientConn struct { state resolver.State err error