diff --git a/AGENTS.md b/AGENTS.md index eb83b0350..dd553c8dc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,6 +47,7 @@ When adding or modifying features, prefer extending existing packages before cre - `pkg/manager/config/` - Auto-reloads configuration files and provides interfaces to query them. - `pkg/manager/elect/` - Manages TiProxy owner elections (for example, metrics reader and VIP modules need an owner). - `pkg/manager/id/` - Generates global IDs. +- `pkg/manager/backendcluster/` - Manages cluster-scoped backend runtimes and shared resources such as PD or etcd clients. - `pkg/manager/infosync/` - Queries the topology of TiDB and Prometheus from PD and updates TiProxy information to PD. - `pkg/manager/logger/` - Manages the logger service. - `pkg/manager/memory/` - Records heap and goroutine profiles when memory usage is high. diff --git a/pkg/balance/observer/backend_fetcher.go b/pkg/balance/observer/backend_fetcher.go index c29f7ab5e..543fa02af 100644 --- a/pkg/balance/observer/backend_fetcher.go +++ b/pkg/balance/observer/backend_fetcher.go @@ -25,6 +25,10 @@ type BackendFetcher interface { // TopologyFetcher is an interface to fetch the tidb topology from ETCD. type TopologyFetcher interface { GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error) + // HasBackendClusters reports whether dynamic PD-backed clusters are configured at all. + // PDFetcher uses it to preserve the legacy behavior that static backend.instances still work + // when TiProxy starts without any PD cluster and clusters are added later through the API. + HasBackendClusters() bool } // PDFetcher fetches backend list from PD. @@ -32,25 +36,35 @@ type PDFetcher struct { tpFetcher TopologyFetcher logger *zap.Logger config *config.HealthCheck + static *StaticFetcher } -func NewPDFetcher(tpFetcher TopologyFetcher, logger *zap.Logger, config *config.HealthCheck) *PDFetcher { +func NewPDFetcher(tpFetcher TopologyFetcher, staticAddrs []string, logger *zap.Logger, config *config.HealthCheck) *PDFetcher { config.Check() return &PDFetcher{ tpFetcher: tpFetcher, logger: logger, config: config, + static: NewStaticFetcher(staticAddrs), } } func (pf *PDFetcher) GetBackendList(ctx context.Context) (map[string]*BackendInfo, error) { + // Keep backward compatibility with the legacy static-namespace flow: before any backend cluster + // is configured, backend.instances must still be routable even though namespace now always sees + // a non-nil topology fetcher from the cluster manager. + if !pf.tpFetcher.HasBackendClusters() { + return pf.static.GetBackendList(ctx) + } backends := pf.fetchBackendList(ctx) infos := make(map[string]*BackendInfo, len(backends)) - for addr, backend := range backends { - infos[addr] = &BackendInfo{ - Labels: backend.Labels, - IP: backend.IP, - StatusPort: backend.StatusPort, + for key, backend := range backends { + infos[key] = &BackendInfo{ + Addr: backend.Addr, + ClusterName: backend.ClusterName, + Labels: backend.Labels, + IP: backend.IP, + StatusPort: backend.StatusPort, } } return infos, nil @@ -98,7 +112,7 @@ func (sf *StaticFetcher) GetBackendList(context.Context) (map[string]*BackendInf func backendListToMap(addrs []string) map[string]*BackendInfo { backends := make(map[string]*BackendInfo, len(addrs)) for _, addr := range addrs { - backends[addr] = &BackendInfo{} + backends[addr] = &BackendInfo{Addr: addr} } return backends } diff --git a/pkg/balance/observer/backend_fetcher_test.go b/pkg/balance/observer/backend_fetcher_test.go index 5ce882d12..c56cb9ffe 100644 --- a/pkg/balance/observer/backend_fetcher_test.go +++ b/pkg/balance/observer/backend_fetcher_test.go @@ -26,6 +26,7 @@ func TestPDFetcher(t *testing.T) { { infos: map[string]*infosync.TiDBTopologyInfo{ "1.1.1.1:4000": { + Addr: "1.1.1.1:4000", Labels: map[string]string{"k1": "v1"}, IP: "1.1.1.1", StatusPort: 10080, @@ -34,6 +35,7 @@ func TestPDFetcher(t *testing.T) { check: func(m map[string]*BackendInfo) { require.Len(t, m, 1) require.NotNil(t, m["1.1.1.1:4000"]) + require.Equal(t, "1.1.1.1:4000", m["1.1.1.1:4000"].Addr) require.Equal(t, "1.1.1.1", m["1.1.1.1:4000"].IP) require.Equal(t, uint(10080), m["1.1.1.1:4000"].StatusPort) require.Equal(t, map[string]string{"k1": "v1"}, m["1.1.1.1:4000"].Labels) @@ -42,10 +44,12 @@ func TestPDFetcher(t *testing.T) { { infos: map[string]*infosync.TiDBTopologyInfo{ "1.1.1.1:4000": { + Addr: "1.1.1.1:4000", IP: "1.1.1.1", StatusPort: 10080, }, "2.2.2.2:4000": { + Addr: "2.2.2.2:4000", IP: "2.2.2.2", StatusPort: 10080, }, @@ -53,13 +57,30 @@ func TestPDFetcher(t *testing.T) { check: func(m map[string]*BackendInfo) { require.Len(t, m, 2) require.NotNil(t, m["1.1.1.1:4000"]) + require.Equal(t, "1.1.1.1:4000", m["1.1.1.1:4000"].Addr) require.Equal(t, "1.1.1.1", m["1.1.1.1:4000"].IP) require.Equal(t, uint(10080), m["1.1.1.1:4000"].StatusPort) require.NotNil(t, m["2.2.2.2:4000"]) + require.Equal(t, "2.2.2.2:4000", m["2.2.2.2:4000"].Addr) require.Equal(t, "2.2.2.2", m["2.2.2.2:4000"].IP) require.Equal(t, uint(10080), m["2.2.2.2:4000"].StatusPort) }, }, + { + infos: map[string]*infosync.TiDBTopologyInfo{ + "cluster-a/shared.tidb:4000": { + Addr: "shared.tidb:4000", + IP: "10.0.0.1", + StatusPort: 10080, + }, + }, + check: func(m map[string]*BackendInfo) { + require.Len(t, m, 1) + require.NotNil(t, m["cluster-a/shared.tidb:4000"]) + require.Equal(t, "shared.tidb:4000", m["cluster-a/shared.tidb:4000"].Addr) + require.Equal(t, "10.0.0.1", m["cluster-a/shared.tidb:4000"].IP) + }, + }, { ctx: func() context.Context { ctx, cancel := context.WithCancel(context.Background()) @@ -74,9 +95,10 @@ func TestPDFetcher(t *testing.T) { tpFetcher := newMockTpFetcher(t) lg, _ := logger.CreateLoggerForTest(t) - pf := NewPDFetcher(tpFetcher, lg, newHealthCheckConfigForTest()) + pf := NewPDFetcher(tpFetcher, nil, lg, newHealthCheckConfigForTest()) for _, test := range tests { tpFetcher.infos = test.infos + tpFetcher.hasClusters = true if test.ctx == nil { test.ctx = context.Background() } @@ -85,3 +107,27 @@ func TestPDFetcher(t *testing.T) { require.NoError(t, err) } } + +func TestPDFetcherFallbackToStaticWithoutBackendClusters(t *testing.T) { + tpFetcher := newMockTpFetcher(t) + lg, _ := logger.CreateLoggerForTest(t) + fetcher := NewPDFetcher(tpFetcher, []string{"127.0.0.1:4000"}, lg, newHealthCheckConfigForTest()) + + backends, err := fetcher.GetBackendList(context.Background()) + require.NoError(t, err) + require.Len(t, backends, 1) + require.Contains(t, backends, "127.0.0.1:4000") + + tpFetcher.hasClusters = true + tpFetcher.infos = map[string]*infosync.TiDBTopologyInfo{ + "cluster-a/10.0.0.1:4000": { + Addr: "10.0.0.1:4000", + ClusterName: "cluster-a", + }, + } + backends, err = fetcher.GetBackendList(context.Background()) + require.NoError(t, err) + require.Len(t, backends, 1) + require.Equal(t, "10.0.0.1:4000", backends["cluster-a/10.0.0.1:4000"].Addr) + require.Equal(t, "cluster-a", backends["cluster-a/10.0.0.1:4000"].ClusterName) +} diff --git a/pkg/balance/observer/backend_health.go b/pkg/balance/observer/backend_health.go index 2fec40bbb..45e1755d7 100644 --- a/pkg/balance/observer/backend_health.go +++ b/pkg/balance/observer/backend_health.go @@ -76,13 +76,17 @@ func (bh *BackendHealth) String() string { // BackendInfo stores the status info of each backend. type BackendInfo struct { - Labels map[string]string - IP string - StatusPort uint + Addr string + ClusterName string + Labels map[string]string + IP string + StatusPort uint } func (bi BackendInfo) Equals(other BackendInfo) bool { - return bi.IP == other.IP && + return bi.Addr == other.Addr && + bi.ClusterName == other.ClusterName && + bi.IP == other.IP && bi.StatusPort == other.StatusPort && maps.Equal(bi.Labels, other.Labels) } diff --git a/pkg/balance/observer/backend_health_test.go b/pkg/balance/observer/backend_health_test.go index a6fa0ae05..1b90b1391 100644 --- a/pkg/balance/observer/backend_health_test.go +++ b/pkg/balance/observer/backend_health_test.go @@ -15,6 +15,7 @@ func TestBackendHealthToString(t *testing.T) { {}, { BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, Labels: map[string]string{"k1": "v1", "k2": "v2"}, @@ -45,6 +46,7 @@ func TestBackendHealthEquals(t *testing.T) { { a: BackendHealth{ BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, Labels: map[string]string{"k1": "v1", "k2": "v2"}, @@ -52,6 +54,7 @@ func TestBackendHealthEquals(t *testing.T) { }, b: BackendHealth{ BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, }, @@ -61,6 +64,7 @@ func TestBackendHealthEquals(t *testing.T) { { a: BackendHealth{ BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, Labels: map[string]string{"k1": "v1", "k2": "v2"}, @@ -68,6 +72,7 @@ func TestBackendHealthEquals(t *testing.T) { }, b: BackendHealth{ BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, Labels: map[string]string{"k1": "v1", "k2": "v2"}, @@ -78,6 +83,7 @@ func TestBackendHealthEquals(t *testing.T) { { a: BackendHealth{ BackendInfo: BackendInfo{ + Addr: "127.0.0.1:4000", IP: "127.0.0.1", StatusPort: 1, Labels: map[string]string{"k1": "v1", "k2": "v2"}, diff --git a/pkg/balance/observer/backend_observer_test.go b/pkg/balance/observer/backend_observer_test.go index 8c3b9b221..ebb7fda22 100644 --- a/pkg/balance/observer/backend_observer_test.go +++ b/pkg/balance/observer/backend_observer_test.go @@ -279,6 +279,7 @@ func (ts *observerTestSuite) addBackend() (string, BackendInfo) { ts.backendIdx++ addr := fmt.Sprintf("%d", ts.backendIdx) info := &BackendInfo{ + Addr: addr, IP: "127.0.0.1", StatusPort: uint(ts.backendIdx), } diff --git a/pkg/balance/observer/health_check.go b/pkg/balance/observer/health_check.go index 5e7578fa0..410374351 100644 --- a/pkg/balance/observer/health_check.go +++ b/pkg/balance/observer/health_check.go @@ -21,7 +21,7 @@ import ( // HealthCheck is used to check the backends of one backend. One can pass a customized health check function to the observer. type HealthCheck interface { - Check(ctx context.Context, addr string, info *BackendInfo, lastHealth *BackendHealth) *BackendHealth + Check(ctx context.Context, backendID string, info *BackendInfo, lastHealth *BackendHealth) *BackendHealth } const ( @@ -62,7 +62,7 @@ func NewDefaultHealthCheck(httpCli *http.Client, cfg *config.HealthCheck, logger } } -func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *BackendInfo, lastBh *BackendHealth) *BackendHealth { +func (dhc *DefaultHealthCheck) Check(ctx context.Context, _ string, info *BackendInfo, lastBh *BackendHealth) *BackendHealth { bh := &BackendHealth{ BackendInfo: *info, Healthy: true, @@ -80,7 +80,7 @@ func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *Bac if !bh.Healthy { return bh } - dhc.checkSqlPort(ctx, addr, bh) + dhc.checkSqlPort(ctx, info, bh) if !bh.Healthy { return bh } @@ -88,8 +88,14 @@ func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *Bac return bh } -func (dhc *DefaultHealthCheck) checkSqlPort(ctx context.Context, addr string, bh *BackendHealth) { +func (dhc *DefaultHealthCheck) checkSqlPort(ctx context.Context, info *BackendInfo, bh *BackendHealth) { // Also dial the SQL port just in case that the SQL port hangs. + if info == nil || info.Addr == "" { + bh.Healthy = false + bh.PingErr = errors.New("backend address is empty") + return + } + addr := info.Addr b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx) err := http.ConnectWithRetry(func() error { startTime := time.Now() diff --git a/pkg/balance/observer/health_check_test.go b/pkg/balance/observer/health_check_test.go index 98b2be67c..a3ec94d1f 100644 --- a/pkg/balance/observer/health_check_test.go +++ b/pkg/balance/observer/health_check_test.go @@ -143,6 +143,7 @@ func newBackendServer(t *testing.T) (*backendServer, *BackendInfo) { backend.setSqlResp(true) backend.startSQLServer() return backend, &BackendInfo{ + Addr: backend.sqlAddr, IP: backend.ip, StatusPort: backend.statusPort, } diff --git a/pkg/balance/observer/mock_test.go b/pkg/balance/observer/mock_test.go index c8dbfd7d2..19d3e0edb 100644 --- a/pkg/balance/observer/mock_test.go +++ b/pkg/balance/observer/mock_test.go @@ -19,9 +19,10 @@ import ( ) type mockTpFetcher struct { - t *testing.T - infos map[string]*infosync.TiDBTopologyInfo - err error + t *testing.T + infos map[string]*infosync.TiDBTopologyInfo + err error + hasClusters bool } func newMockTpFetcher(t *testing.T) *mockTpFetcher { @@ -34,6 +35,10 @@ func (ft *mockTpFetcher) GetTiDBTopology(ctx context.Context) (map[string]*infos return ft.infos, ft.err } +func (ft *mockTpFetcher) HasBackendClusters() bool { + return ft.hasClusters +} + type mockBackendFetcher struct { sync.Mutex backends map[string]*BackendInfo @@ -82,11 +87,11 @@ func newMockHealthCheck() *mockHealthCheck { } } -func (mhc *mockHealthCheck) Check(_ context.Context, addr string, info *BackendInfo, _ *BackendHealth) *BackendHealth { +func (mhc *mockHealthCheck) Check(_ context.Context, backendID string, info *BackendInfo, _ *BackendHealth) *BackendHealth { mhc.Lock() defer mhc.Unlock() - mhc.backends[addr].BackendInfo = *info - return mhc.backends[addr] + mhc.backends[backendID].BackendInfo = *info + return mhc.backends[backendID] } func (mhc *mockHealthCheck) setBackend(addr string, health *BackendHealth) { diff --git a/pkg/balance/router/group.go b/pkg/balance/router/group.go index 43341054b..12ffb0e23 100644 --- a/pkg/balance/router/group.go +++ b/pkg/balance/router/group.go @@ -158,17 +158,17 @@ func (g *Group) RefreshCidr() { } } -func (g *Group) AddBackend(addr string, backend *backendWrapper) { +func (g *Group) AddBackend(backendID string, backend *backendWrapper) { g.Lock() defer g.Unlock() - g.backends[addr] = backend + g.backends[backendID] = backend backend.group = g } -func (g *Group) RemoveBackend(addr string) { +func (g *Group) RemoveBackend(backendID string) { g.Lock() defer g.Unlock() - delete(g.backends, addr) + delete(g.backends, backendID) } func (g *Group) Empty() bool { @@ -200,7 +200,7 @@ func (g *Group) Route(excluded []BackendInst) (policy.BackendCtx, error) { // Exclude the backends that are already tried. found := false for _, e := range excluded { - if backend.Addr() == e.Addr() { + if backend.ID() == e.ID() { found = true break } @@ -273,7 +273,7 @@ func (g *Group) Balance(ctx context.Context) { func (g *Group) onCreateConn(backendInst BackendInst, conn RedirectableConn, succeed bool) { g.Lock() defer g.Unlock() - backend := g.ensureBackend(backendInst.Addr()) + backend := g.ensureBackend(backendInst.ID()) if succeed { connWrapper := &connWrapper{ RedirectableConn: conn, @@ -319,23 +319,18 @@ func (g *Group) RedirectConnections() error { return nil } -func (g *Group) ensureBackend(addr string) *backendWrapper { - backend, ok := g.backends[addr] +func (g *Group) ensureBackend(backendID string) *backendWrapper { + backend, ok := g.backends[backendID] if ok { return backend } // The backend should always exist if it will be needed. Add a warning and add it back. - g.lg.Warn("backend is not found in the router", zap.String("backend_addr", addr), zap.Stack("stack")) - ip, _, _ := net.SplitHostPort(addr) - backend = newBackendWrapper(addr, observer.BackendHealth{ - BackendInfo: observer.BackendInfo{ - IP: ip, - StatusPort: 10080, // impossible anyway - }, + g.lg.Warn("backend is not found in the router", zap.String("backend_id", backendID), zap.Stack("stack")) + backend = newBackendWrapper(backendID, observer.BackendHealth{ SupportRedirection: true, Healthy: false, }) - g.backends[addr] = backend + g.backends[backendID] = backend return backend } @@ -375,16 +370,16 @@ func (g *Group) onRedirectFinished(from, to string, conn RedirectableConn, succe } // OnConnClosed implements ConnEventReceiver.OnConnClosed interface. -func (g *Group) OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error { +func (g *Group) OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error { g.Lock() defer g.Unlock() - backend := g.ensureBackend(addr) + backend := g.ensureBackend(backendID) connWrapper := getConnWrapper(conn) // If this connection has not redirected yet, decrease the score of the target backend. - if redirectingAddr != "" { - redirectingBackend := g.ensureBackend(redirectingAddr) + if redirectingBackendID != "" { + redirectingBackend := g.ensureBackend(redirectingBackendID) redirectingBackend.connScore-- - metrics.PendingMigrateGuage.WithLabelValues(addr, redirectingAddr, connWrapper.Value.redirectReason).Dec() + metrics.PendingMigrateGuage.WithLabelValues(backendID, redirectingBackendID, connWrapper.Value.redirectReason).Dec() } else { backend.connScore-- } diff --git a/pkg/balance/router/mock_test.go b/pkg/balance/router/mock_test.go index d8eb98950..e2918a59c 100644 --- a/pkg/balance/router/mock_test.go +++ b/pkg/balance/router/mock_test.go @@ -69,13 +69,13 @@ func (conn *mockRedirectableConn) Redirect(inst BackendInst) bool { return true } -func (conn *mockRedirectableConn) GetRedirectingAddr() string { +func (conn *mockRedirectableConn) GetRedirectingBackendID() string { conn.Lock() defer conn.Unlock() if conn.to == nil { return "" } - return conn.to.Addr() + return conn.to.ID() } func (conn *mockRedirectableConn) ConnectionID() uint64 { @@ -86,14 +86,14 @@ func (conn *mockRedirectableConn) ConnInfo() []zap.Field { return nil } -func (conn *mockRedirectableConn) getAddr() (string, string) { +func (conn *mockRedirectableConn) getBackendIDs() (string, string) { conn.Lock() defer conn.Unlock() var to string if conn.to != nil && !reflect.ValueOf(conn.to).IsNil() { - to = conn.to.Addr() + to = conn.to.ID() } - return conn.from.Addr(), to + return conn.from.ID(), to } func (conn *mockRedirectableConn) redirectSucceed() { @@ -138,6 +138,7 @@ func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string mbo.healths[addr] = &observer.BackendHealth{ Healthy: true, BackendInfo: observer.BackendInfo{ + Addr: addr, Labels: labels, }, } @@ -182,8 +183,9 @@ func (mbo *mockBackendObserver) notify(err error) { func (mbo *mockBackendObserver) Close() { mbo.subscriberLock.Lock() defer mbo.subscriberLock.Unlock() - for _, subscriber := range mbo.subscribers { + for name, subscriber := range mbo.subscribers { close(subscriber) + delete(mbo.subscribers, name) } } diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go index ee844bf9e..bca70fbbe 100644 --- a/pkg/balance/router/router.go +++ b/pkg/balance/router/router.go @@ -23,7 +23,7 @@ var ( type ConnEventReceiver interface { OnRedirectSucceed(from, to string, conn RedirectableConn) error OnRedirectFail(from, to string, conn RedirectableConn) error - OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error + OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error } // Router routes client connections to backends. @@ -74,6 +74,7 @@ type RedirectableConn interface { // BackendInst defines a backend that a connection is redirecting to. type BackendInst interface { + ID() string Addr() string Healthy() bool Local() bool @@ -86,6 +87,7 @@ type backendWrapper struct { sync.RWMutex observer.BackendHealth } + id string addr string // connScore is used for calculating backend scores and check if the backend can be removed from the list. // connScore = connList.Len() + incoming connections - outgoing connections. @@ -97,9 +99,10 @@ type backendWrapper struct { group *Group } -func newBackendWrapper(addr string, health observer.BackendHealth) *backendWrapper { +func newBackendWrapper(id string, health observer.BackendHealth) *backendWrapper { wrapper := &backendWrapper{ - addr: addr, + id: id, + addr: health.Addr, connList: glist.New[*connWrapper](), } wrapper.setHealth(health) @@ -123,6 +126,10 @@ func (b *backendWrapper) ConnScore() int { return b.connScore } +func (b *backendWrapper) ID() string { + return b.id +} + func (b *backendWrapper) Addr() string { return b.addr } @@ -176,6 +183,11 @@ func (b *backendWrapper) Keyspace() string { return labels[config.KeyspaceLabelName] } +func (b *backendWrapper) ClusterName() string { + b.mu.RLock() + defer b.mu.RUnlock() + return b.mu.BackendHealth.ClusterName +} func (b *backendWrapper) Cidr() []string { labels := b.getHealth().Labels if len(labels) == 0 { diff --git a/pkg/balance/router/router_score.go b/pkg/balance/router/router_score.go index d70afa4a9..fc52d3037 100644 --- a/pkg/balance/router/router_score.go +++ b/pkg/balance/router/router_score.go @@ -116,7 +116,11 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen return } // The router may remove this group concurrently, make sure the group can be accessed after it's removed. - backend, err = group.Route(excluded) + var backendCtx policy.BackendCtx + backendCtx, err = group.Route(excluded) + if err == nil && backendCtx != nil { + backend = backendCtx.(BackendInst) + } return }, onCreate: func(backend BackendInst, conn RedirectableConn, succeed bool) { @@ -181,11 +185,12 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt // `backends` contain all the backends, not only the updated ones. backends := healthResults.Backends() // If some backends are removed from the list, add them to `backends`. - for addr, backend := range router.backends { - if _, ok := backends[addr]; !ok { + for backendID, backend := range router.backends { + if _, ok := backends[backendID]; !ok { health := backend.getHealth() - router.logger.Debug("backend is removed from the list, add it back to router", zap.String("addr", addr), zap.Stringer("health", &health)) - backends[addr] = &observer.BackendHealth{ + router.logger.Debug("backend is removed from the list, add it back to router", + zap.String("backend_id", backendID), zap.String("addr", backend.Addr()), zap.Stringer("health", &health)) + backends[backendID] = &observer.BackendHealth{ BackendInfo: backend.GetBackendInfo(), SupportRedirection: backend.SupportRedirection(), Healthy: false, @@ -195,22 +200,25 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt } var serverVersion string supportRedirection := true - for addr, health := range backends { - backend, ok := router.backends[addr] + for backendID, health := range backends { + backend, ok := router.backends[backendID] if !ok && health.Healthy { - router.logger.Debug("add new backend to router", zap.String("addr", addr), zap.Stringer("health", health)) - router.backends[addr] = newBackendWrapper(addr, *health) + router.logger.Debug("add new backend to router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), zap.Stringer("health", health)) + router.backends[backendID] = newBackendWrapper(backendID, *health) serverVersion = health.ServerVersion } else if ok { if !health.Equals(backend.getHealth()) { - router.logger.Debug("update backend in router", zap.String("addr", addr), zap.Stringer("health", health)) + router.logger.Debug("update backend in router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), zap.Stringer("health", health)) } backend.setHealth(*health) if health.Healthy { serverVersion = health.ServerVersion } } else { - router.logger.Debug("unhealthy backend is not in router", zap.String("addr", addr), zap.Stringer("health", health)) + router.logger.Debug("unhealthy backend is not in router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), zap.Stringer("health", health)) } supportRedirection = health.SupportRedirection && supportRedirection } @@ -232,9 +240,9 @@ func (router *ScoreBasedRouter) updateGroups() { // If connList.Len() == 0, there won't be any outgoing connections. // And if also connScore == 0, there won't be any incoming connections. if !backend.Healthy() && backend.connList.Len() == 0 && backend.connScore <= 0 { - delete(router.backends, backend.addr) + delete(router.backends, backend.id) if backend.group != nil { - backend.group.RemoveBackend(backend.addr) + backend.group.RemoveBackend(backend.id) // remove empty groups if backend.group.Empty() { router.groups = slices.DeleteFunc(router.groups, func(g *Group) bool { @@ -280,7 +288,7 @@ func (router *ScoreBasedRouter) updateGroups() { } } if group != nil { - group.AddBackend(backend.addr, backend) + group.AddBackend(backend.id, backend) } } for _, group := range router.groups { @@ -295,9 +303,19 @@ func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { case <-ctx.Done(): ticker.Stop() return - case healthResults := <-router.healthCh: + case healthResults, ok := <-router.healthCh: + if !ok { + router.logger.Warn("health channel is closed, stop watching channel") + router.healthCh = nil + continue + } router.updateBackendHealth(healthResults) - case cfg := <-router.cfgCh: + case cfg, ok := <-router.cfgCh: + if !ok { + router.logger.Warn("config channel is closed, stop watching channel") + router.cfgCh = nil + continue + } router.setConfig(cfg) case <-ticker.C: router.rebalance(ctx) diff --git a/pkg/balance/router/router_score_test.go b/pkg/balance/router/router_score_test.go index 1891f5eab..177b8294a 100644 --- a/pkg/balance/router/router_score_test.go +++ b/pkg/balance/router/router_score_test.go @@ -71,6 +71,9 @@ func (tester *routerTester) addBackends(num int) { tester.backendID++ addr := strconv.Itoa(tester.backendID) tester.backends[addr] = &observer.BackendHealth{ + BackendInfo: observer.BackendInfo{ + Addr: addr, + }, Healthy: true, SupportRedirection: true, } @@ -114,6 +117,9 @@ func (tester *routerTester) updateBackendStatusByAddr(addr string, healthy bool) health.Healthy = healthy } else { tester.backends[addr] = &observer.BackendHealth{ + BackendInfo: observer.BackendInfo{ + Addr: addr, + }, SupportRedirection: true, Healthy: healthy, } @@ -159,11 +165,11 @@ func (tester *routerTester) closeConnections(num int, redirecting bool) { conns := make(map[uint64]*mockRedirectableConn, num) for id, conn := range tester.conns { if redirecting { - if len(conn.GetRedirectingAddr()) == 0 { + if len(conn.GetRedirectingBackendID()) == 0 { continue } } else { - if len(conn.GetRedirectingAddr()) > 0 { + if len(conn.GetRedirectingBackendID()) > 0 { continue } } @@ -173,7 +179,7 @@ func (tester *routerTester) closeConnections(num int, redirecting bool) { } } for _, conn := range conns { - err := tester.router.groups[0].OnConnClosed(conn.from.Addr(), conn.GetRedirectingAddr(), conn) + err := tester.router.groups[0].OnConnClosed(conn.from.ID(), conn.GetRedirectingBackendID(), conn) require.NoError(tester.t, err) delete(tester.conns, conn.connID) } @@ -191,7 +197,7 @@ func (tester *routerTester) rebalance(num int) { func (tester *routerTester) redirectFinish(num int, succeed bool) { i := 0 for _, conn := range tester.conns { - if len(conn.GetRedirectingAddr()) == 0 { + if len(conn.GetRedirectingBackendID()) == 0 { continue } @@ -199,11 +205,11 @@ func (tester *routerTester) redirectFinish(num int, succeed bool) { prevCount, err := readMigrateCounter(from.Addr(), to.Addr(), succeed) require.NoError(tester.t, err) if succeed { - err = tester.router.groups[0].OnRedirectSucceed(from.Addr(), to.Addr(), conn) + err = tester.router.groups[0].OnRedirectSucceed(from.ID(), to.ID(), conn) require.NoError(tester.t, err) conn.redirectSucceed() } else { - err = tester.router.groups[0].OnRedirectFail(from.Addr(), to.Addr(), conn) + err = tester.router.groups[0].OnRedirectFail(from.ID(), to.ID(), conn) require.NoError(tester.t, err) conn.redirectFail() } @@ -239,7 +245,7 @@ func (tester *routerTester) checkBalanced() { func (tester *routerTester) checkRedirectingNum(num int) { redirectingNum := 0 for _, conn := range tester.conns { - if len(conn.GetRedirectingAddr()) > 0 { + if len(conn.GetRedirectingBackendID()) > 0 { redirectingNum++ } } @@ -626,13 +632,13 @@ func TestConcurrency(t *testing.T) { require.NoError(t, err) selector.Finish(conn, true) conn.from = backend - } else if len(conn.GetRedirectingAddr()) > 0 { + } else if len(conn.GetRedirectingBackendID()) > 0 { // redirecting, 70% success, 20% fail, 10% close i := rand.Intn(10) - from, to := conn.getAddr() + from, to := conn.getBackendIDs() var err error if i < 1 { - err = router.groups[0].OnConnClosed(from, conn.GetRedirectingAddr(), conn) + err = router.groups[0].OnConnClosed(from, conn.GetRedirectingBackendID(), conn) conn = nil } else if i < 3 { conn.redirectFail() @@ -647,8 +653,8 @@ func TestConcurrency(t *testing.T) { i := rand.Intn(10) if i < 2 { // The balancer may happen to redirect it concurrently - that's exactly what may happen. - from, _ := conn.getAddr() - err := router.groups[0].OnConnClosed(from, conn.GetRedirectingAddr(), conn) + from, _ := conn.getBackendIDs() + err := router.groups[0].OnConnClosed(from, conn.GetRedirectingBackendID(), conn) require.NoError(t, err) conn = nil } @@ -730,10 +736,16 @@ func TestGetServerVersion(t *testing.T) { t.Cleanup(rt.Close) backends := map[string]*observer.BackendHealth{ "0": { + BackendInfo: observer.BackendInfo{ + Addr: "0", + }, Healthy: true, ServerVersion: "1.0", }, "1": { + BackendInfo: observer.BackendInfo{ + Addr: "1", + }, Healthy: true, ServerVersion: "2.0", }, @@ -794,8 +806,8 @@ func TestUpdateBackendHealth(t *testing.T) { // Test locality of some backends are changed. tester.updateBackendLocalityByAddr(tester.getBackendByIndex(0).Addr(), false) tester.updateBackendLocalityByAddr(tester.getBackendByIndex(1).Addr(), true) - require.Equal(t, false, tester.router.backends[tester.getBackendByIndex(0).Addr()].Local()) - require.Equal(t, true, tester.router.backends[tester.getBackendByIndex(1).Addr()].Local()) + require.Equal(t, false, tester.router.backends[tester.getBackendByIndex(0).ID()].Local()) + require.Equal(t, true, tester.router.backends[tester.getBackendByIndex(1).ID()].Local()) // Test some backends are not in the list anymore. tester.removeBackends(1) tester.checkBackendNum(2) @@ -840,6 +852,55 @@ func TestWatchConfig(t *testing.T) { }, 3*time.Second, 10*time.Millisecond) } +func TestChannelClosed(t *testing.T) { + tests := []struct { + name string + closeChannel func(cfgCh chan *config.Config, bo *mockBackendObserver) + }{ + { + name: "config", + closeChannel: func(cfgCh chan *config.Config, _ *mockBackendObserver) { + close(cfgCh) + }, + }, + { + name: "health", + closeChannel: func(_ chan *config.Config, bo *mockBackendObserver) { + bo.Close() + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + cfgCh := make(chan *config.Config) + cfg := &config.Config{} + cfgGetter := newMockConfigGetter(cfg) + p := &mockBalancePolicy{} + bpCreator := func(lg *zap.Logger) policy.BalancePolicy { + p.Init(cfg) + return p + } + bo := newMockBackendObserver() + router.Init(context.Background(), bo, bpCreator, cfgGetter, cfgCh) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackend("0", nil) + bo.notify(nil) + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return len(router.groups) == 1 + }, 3*time.Second, 10*time.Millisecond) + + tt.closeChannel(cfgCh, bo) + time.Sleep(100 * time.Millisecond) + }) + } +} + func TestControlSpeed(t *testing.T) { tests := []struct { balanceCount float64 @@ -978,10 +1039,16 @@ func TestSkipRedirection(t *testing.T) { tester := newRouterTester(t, nil) backends := map[string]*observer.BackendHealth{ "0": { + BackendInfo: observer.BackendInfo{ + Addr: "0", + }, Healthy: true, SupportRedirection: false, }, "1": { + BackendInfo: observer.BackendInfo{ + Addr: "1", + }, Healthy: true, SupportRedirection: true, }, @@ -1109,3 +1176,35 @@ func TestGroupBackends(t *testing.T) { }, 3*time.Second, 10*time.Millisecond, "test %d", i) } } + +func TestRouteBackendsWithSameAddrDifferentIDs(t *testing.T) { + tester := newRouterTester(t, nil) + tester.router.matchType = MatchAll + tester.backends["cluster-a/shared:4000"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared:4000", + ClusterName: "cluster-a", + }, + } + tester.backends["cluster-b/shared:4000"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared:4000", + ClusterName: "cluster-b", + }, + } + tester.notifyHealth() + + selector := tester.router.GetBackendSelector(ClientInfo{}) + first, err := selector.Next() + require.NoError(t, err) + second, err := selector.Next() + require.NoError(t, err) + + require.Equal(t, "shared:4000", first.Addr()) + require.Equal(t, "shared:4000", second.Addr()) + require.NotEqual(t, first.ID(), second.ID()) +} diff --git a/pkg/balance/router/router_static.go b/pkg/balance/router/router_static.go index b4230bd2f..9eddddd4d 100644 --- a/pkg/balance/router/router_static.go +++ b/pkg/balance/router/router_static.go @@ -26,7 +26,7 @@ func (r *StaticRouter) GetBackendSelector(_ ClientInfo) BackendSelector { for _, backend := range r.backends { found := false for _, e := range excluded { - if e.Addr() == backend.Addr() { + if e.ID() == backend.ID() { found = true break } @@ -74,7 +74,7 @@ func (r *StaticRouter) OnRedirectFail(from, to string, conn RedirectableConn) er return nil } -func (r *StaticRouter) OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error { +func (r *StaticRouter) OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error { r.cnt-- return nil } @@ -97,6 +97,10 @@ func (b *StaticBackend) Addr() string { return b.addr } +func (b *StaticBackend) ID() string { + return b.addr +} + func (b *StaticBackend) Healthy() bool { return b.healthy.Load() } diff --git a/pkg/manager/backendcluster/backend_id.go b/pkg/manager/backendcluster/backend_id.go new file mode 100644 index 000000000..b5177d59e --- /dev/null +++ b/pkg/manager/backendcluster/backend_id.go @@ -0,0 +1,12 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import "fmt" + +// backendID returns the opaque identity key for one backend in one backend cluster. +// It is only used as an in-memory map key and must not be parsed or used as a network address. +func backendID(clusterName, addr string) string { + return fmt.Sprintf("%s/%s", clusterName, addr) +} diff --git a/pkg/manager/backendcluster/manager.go b/pkg/manager/backendcluster/manager.go new file mode 100644 index 000000000..3b280fb74 --- /dev/null +++ b/pkg/manager/backendcluster/manager.go @@ -0,0 +1,312 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "crypto/tls" + "maps" + "slices" + "strings" + "sync" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/pingcap/tiproxy/pkg/manager/infosync" + "github.com/pingcap/tiproxy/pkg/util/etcd" + "github.com/pingcap/tiproxy/pkg/util/waitgroup" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// Cluster is the cluster-scoped container for one backend PD cluster. +type Cluster struct { + cfg config.BackendCluster + etcdCli *clientv3.Client + infoSyncer *infosync.InfoSyncer +} + +func (c *Cluster) Config() config.BackendCluster { + return c.cfg +} + +func (c *Cluster) EtcdClient() *clientv3.Client { + return c.etcdCli +} + +func (c *Cluster) GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + return c.infoSyncer.GetTiDBTopology(ctx) +} + +func (c *Cluster) GetPromInfo(ctx context.Context) (*infosync.PrometheusInfo, error) { + return c.infoSyncer.GetPromInfo(ctx) +} + +type Manager struct { + lg *zap.Logger + clusterTLS func() *tls.Config + + wg waitgroup.WaitGroup + cancel context.CancelFunc + + mu struct { + sync.RWMutex + clusters map[string]*Cluster + } +} + +func NewManager(lg *zap.Logger, clusterTLS func() *tls.Config) *Manager { + mgr := &Manager{ + lg: lg, + clusterTLS: clusterTLS, + } + mgr.mu.clusters = make(map[string]*Cluster) + return mgr +} + +func (m *Manager) Start(ctx context.Context, cfgGetter config.ConfigGetter, cfgCh <-chan *config.Config) error { + if err := m.syncClusters(ctx, cfgGetter.GetConfig()); err != nil { + return err + } + childCtx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wg.Run(func() { + m.watchConfig(childCtx, cfgCh) + }, m.lg) + return nil +} + +func (m *Manager) watchConfig(ctx context.Context, cfgCh <-chan *config.Config) { + if cfgCh == nil { + return + } + for { + select { + case <-ctx.Done(): + return + case cfg, ok := <-cfgCh: + if !ok { + m.lg.Warn("config channel is closed, stop watching backend clusters") + return + } + if cfg == nil { + continue + } + if err := m.syncClusters(ctx, cfg); err != nil { + m.lg.Error("sync backend clusters failed", zap.Error(err)) + } + } + } +} + +func (m *Manager) syncClusters(ctx context.Context, cfg *config.Config) error { + if cfg == nil { + return nil + } + desiredClusters := cfg.GetBackendClusters() + desiredMap := make(map[string]config.BackendCluster, len(desiredClusters)) + for _, cluster := range desiredClusters { + desiredMap[cluster.Name] = cluster + } + + m.mu.Lock() + oldClusters := m.mu.clusters + newClusters := make(map[string]*Cluster, len(desiredClusters)) + closeList := make([]*Cluster, 0, len(oldClusters)) + + for _, clusterCfg := range desiredClusters { + oldCluster, ok := oldClusters[clusterCfg.Name] + if ok && clusterReusable(oldCluster, clusterCfg) { + newClusters[clusterCfg.Name] = oldCluster + delete(oldClusters, clusterCfg.Name) + continue + } + + cluster, err := m.buildCluster(ctx, cfg, clusterCfg) + if err != nil { + if ok { + m.lg.Warn("failed to update backend cluster, keep the old one", + zap.String("cluster", clusterCfg.Name), zap.Error(err)) + newClusters[clusterCfg.Name] = oldCluster + delete(oldClusters, clusterCfg.Name) + continue + } + m.lg.Error("failed to add backend cluster", + zap.String("cluster", clusterCfg.Name), zap.Error(err)) + continue + } + newClusters[clusterCfg.Name] = cluster + if ok { + closeList = append(closeList, oldCluster) + delete(oldClusters, clusterCfg.Name) + m.lg.Info("updated backend cluster", + zap.String("cluster", clusterCfg.Name), zap.String("pd_addrs", clusterCfg.PDAddrs)) + } else { + m.lg.Info("added backend cluster", + zap.String("cluster", clusterCfg.Name), zap.String("pd_addrs", clusterCfg.PDAddrs)) + } + } + + for name, cluster := range oldClusters { + if _, ok := desiredMap[name]; ok { + continue + } + closeList = append(closeList, cluster) + m.lg.Info("removed backend cluster", + zap.String("cluster", name), zap.String("pd_addrs", cluster.cfg.PDAddrs)) + } + + m.mu.clusters = newClusters + m.mu.Unlock() + + for _, cluster := range closeList { + if err := m.closeCluster(cluster); err != nil { + m.lg.Warn("close backend cluster failed", + zap.String("cluster", cluster.cfg.Name), zap.Error(err)) + } + } + return nil +} + +func normalizeCluster(cluster config.BackendCluster) config.BackendCluster { + cluster.Name = strings.TrimSpace(cluster.Name) + cluster.PDAddrs = strings.TrimSpace(cluster.PDAddrs) + return cluster +} + +func clusterReusable(cluster *Cluster, cfg config.BackendCluster) bool { + if cluster == nil { + return false + } + left := normalizeCluster(cluster.cfg) + right := normalizeCluster(cfg) + return left.Name == right.Name && + left.PDAddrs == right.PDAddrs && + slices.Equal(left.NSServers, right.NSServers) +} + +func (m *Manager) buildCluster(ctx context.Context, cfg *config.Config, clusterCfg config.BackendCluster) (*Cluster, error) { + clusterCfg = normalizeCluster(clusterCfg) + etcdCli, err := etcd.InitEtcdClientWithAddrs( + m.lg.With(zap.String("cluster", clusterCfg.Name)).Named("etcd"), + clusterCfg.PDAddrs, + m.clusterTLS(), + ) + if err != nil { + return nil, err + } + + infoSyncer := infosync.NewInfoSyncer(m.lg.With(zap.String("cluster", clusterCfg.Name)).Named("infosync"), etcdCli) + if err := infoSyncer.Init(ctx, cfg); err != nil { + if closeErr := etcdCli.Close(); closeErr != nil { + m.lg.Warn("close cluster etcd client failed after infosync init error", + zap.String("cluster", clusterCfg.Name), zap.Error(closeErr)) + } + return nil, err + } + + return &Cluster{ + cfg: clusterCfg, + etcdCli: etcdCli, + infoSyncer: infoSyncer, + }, nil +} + +func (m *Manager) closeCluster(cluster *Cluster) error { + if cluster == nil { + return nil + } + errs := make([]error, 0, 2) + if cluster.infoSyncer != nil { + errs = append(errs, cluster.infoSyncer.Close()) + } + if cluster.etcdCli != nil { + errs = append(errs, cluster.etcdCli.Close()) + } + return errors.Collect(errors.New("close backend cluster"), errs...) +} + +func (m *Manager) Snapshot() map[string]*Cluster { + m.mu.RLock() + snapshot := make(map[string]*Cluster, len(m.mu.clusters)) + maps.Copy(snapshot, m.mu.clusters) + m.mu.RUnlock() + return snapshot +} + +func (m *Manager) HasBackendClusters() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.mu.clusters) > 0 +} + +// PrimaryCluster returns the only configured cluster when the cluster count is exactly one. +// It exists for features that are only well-defined in the single-cluster case, such as VIP, +// and for temporary transition points that still require a unique cluster. +func (m *Manager) PrimaryCluster() *Cluster { + m.mu.RLock() + defer m.mu.RUnlock() + if len(m.mu.clusters) != 1 { + return nil + } + for _, cluster := range m.mu.clusters { + return cluster + } + return nil +} + +func (m *Manager) GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + clusters := m.Snapshot() + merged := make(map[string]*infosync.TiDBTopologyInfo, 128) + errs := make([]error, 0, len(clusters)) + for clusterName, cluster := range clusters { + infos, err := cluster.GetTiDBTopology(ctx) + if err != nil { + errs = append(errs, err) + continue + } + for _, info := range infos { + cloned := *info + backendID := backendID(clusterName, cloned.Addr) + if oldInfo, ok := merged[backendID]; ok { + m.lg.Warn("duplicate backend in cluster, keep the first one", + zap.String("backend_id", backendID), + zap.String("addr", cloned.Addr), + zap.String("cluster", clusterName), + zap.String("first_cluster", oldInfo.ClusterName)) + continue + } + cloned.Labels = info.Labels + cloned.ClusterName = clusterName + merged[backendID] = &cloned + } + } + if len(merged) == 0 && len(errs) > 0 { + return nil, errors.Collect(errors.New("fetch from backend clusters"), errs...) + } + return merged, nil +} + +func (m *Manager) Close() error { + if m.cancel != nil { + m.cancel() + } + m.wg.Wait() + + m.mu.Lock() + clusters := m.mu.clusters + m.mu.clusters = make(map[string]*Cluster) + m.mu.Unlock() + + errs := make([]error, 0, len(clusters)) + for _, cluster := range clusters { + if err := m.closeCluster(cluster); err != nil { + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil + } + return errors.Collect(errors.New("close backend cluster manager"), errs...) +} diff --git a/pkg/manager/backendcluster/manager_test.go b/pkg/manager/backendcluster/manager_test.go new file mode 100644 index 000000000..b4b39076d --- /dev/null +++ b/pkg/manager/backendcluster/manager_test.go @@ -0,0 +1,273 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "crypto/tls" + "encoding/json" + "path" + "sync" + "testing" + "time" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/lib/util/logger" + "github.com/pingcap/tiproxy/pkg/manager/infosync" + "github.com/pingcap/tiproxy/pkg/util/etcd" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/server/v3/embed" + "go.uber.org/zap" +) + +const ( + testTiDBTopologyPath = "/topology/tidb" + testInfoSuffix = "info" + testTTLSuffix = "ttl" +) + +func nilClusterTLS() *tls.Config { + return nil +} + +func TestManagerFetchesAllClusters(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 2 { + return false + } + return topology[backendID("cluster-a", "10.0.0.1:4000")].ClusterName == "cluster-a" && + topology[backendID("cluster-b", "10.0.0.2:4000")].ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) +} + +func TestManagerDynamicClusterUpdate(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.PDAddrs = "" + cfg.Proxy.BackendClusters = nil + cfgGetter := newManagerTestConfigGetter(cfg) + cfgCh := make(chan *config.Config, 4) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + topology, err := mgr.GetTiDBTopology(context.Background()) + require.NoError(t, err) + require.Empty(t, topology) + + nextCfg := cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + info, ok := topology[backendID("cluster-a", "10.0.0.1:4000")] + return ok && info.ClusterName == "cluster-a" + }, 5*time.Second, 100*time.Millisecond) + + nextCfg = cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + info, ok := topology[backendID("cluster-b", "10.0.0.2:4000")] + return ok && info.ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) +} + +func TestManagerKeepsOldClusterWhenUpdateFails(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + _, ok := topology[backendID("cluster-a", "10.0.0.1:4000")] + return ok + }, 5*time.Second, 100*time.Millisecond) + + originalCluster := mgr.Snapshot()["cluster-a"] + require.NotNil(t, originalCluster) + + nextCfg := cfg.Clone() + nextCfg.Proxy.Addr = "invalid" + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterB.addr}, + } + require.NoError(t, mgr.syncClusters(context.Background(), nextCfg)) + + currentCluster := mgr.Snapshot()["cluster-a"] + require.Same(t, originalCluster, currentCluster) + + topology, err := mgr.GetTiDBTopology(context.Background()) + require.NoError(t, err) + require.Contains(t, topology, backendID("cluster-a", "10.0.0.1:4000")) + require.NotContains(t, topology, backendID("cluster-a", "10.0.0.2:4000")) +} +func TestManagerKeepsDuplicateBackendAddrsAcrossClusters(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "shared.tidb:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "shared.tidb:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 2 { + return false + } + infoA, okA := topology[backendID("cluster-a", "shared.tidb:4000")] + infoB, okB := topology[backendID("cluster-b", "shared.tidb:4000")] + return okA && okB && infoA.Addr == "shared.tidb:4000" && infoB.Addr == "shared.tidb:4000" + }, 5*time.Second, 100*time.Millisecond) +} + +type managerTestConfigGetter struct { + mu sync.RWMutex + cfg *config.Config +} + +func newManagerTestConfigGetter(cfg *config.Config) *managerTestConfigGetter { + return &managerTestConfigGetter{cfg: cfg} +} + +func (g *managerTestConfigGetter) GetConfig() *config.Config { + g.mu.RLock() + defer g.mu.RUnlock() + return g.cfg +} + +func (g *managerTestConfigGetter) setConfig(cfg *config.Config) { + g.mu.Lock() + g.cfg = cfg + g.mu.Unlock() +} + +type managerTestEtcdCluster struct { + etcd *embed.Etcd + client *clientv3.Client + kv clientv3.KV + addr string +} + +func newManagerTestEtcdCluster(t *testing.T) *managerTestEtcdCluster { + lg, _ := logger.CreateLoggerForTest(t) + etcdSrv, err := etcd.CreateEtcdServer("0.0.0.0:0", t.TempDir(), lg) + require.NoError(t, err) + addr := etcdSrv.Clients[0].Addr().String() + cli, err := etcd.InitEtcdClientWithAddrs(lg, addr, nil) + require.NoError(t, err) + return &managerTestEtcdCluster{ + etcd: etcdSrv, + client: cli, + kv: clientv3.NewKV(cli), + addr: addr, + } +} + +func (tec *managerTestEtcdCluster) close(t *testing.T) { + require.NoError(t, tec.client.Close()) + tec.etcd.Close() +} + +func (tec *managerTestEtcdCluster) putTopology(t *testing.T, sqlAddr string, info *infosync.TiDBTopologyInfo) { + data, err := json.Marshal(info) + require.NoError(t, err) + _, err = tec.kv.Put(context.Background(), path.Join(testTiDBTopologyPath, sqlAddr, testInfoSuffix), string(data)) + require.NoError(t, err) + _, err = tec.kv.Put(context.Background(), path.Join(testTiDBTopologyPath, sqlAddr, testTTLSuffix), "1") + require.NoError(t, err) +} + +func newManagerTestConfig() *config.Config { + cfg := config.NewConfig() + cfg.Proxy.Addr = "127.0.0.1:6000" + cfg.API.Addr = "127.0.0.1:3080" + cfg.Proxy.PDAddrs = "" + cfg.Proxy.BackendClusters = nil + return cfg +} + +func zapLoggerForTest(t *testing.T) *zap.Logger { + lg, _ := logger.CreateLoggerForTest(t) + return lg +} diff --git a/pkg/manager/infosync/info.go b/pkg/manager/infosync/info.go index 4d4e9f3d4..9944abbed 100644 --- a/pkg/manager/infosync/info.go +++ b/pkg/manager/infosync/info.go @@ -93,6 +93,8 @@ type TopologyInfo struct { type TiDBTopologyInfo struct { Version string `json:"version"` GitHash string `json:"git_hash"` + Addr string `json:"-"` + ClusterName string `json:"-"` IP string `json:"ip"` StatusPort uint `json:"status_port"` DeployPath string `json:"deploy_path"` @@ -295,6 +297,7 @@ func (is *InfoSyncer) GetTiDBTopology(ctx context.Context) (map[string]*TiDBTopo zap.String("value", hack.String(kv.Value)), zap.Error(err)) } else { infos[addr] = topology + topology.Addr = addr topology.Keyspace = keyspace } } diff --git a/pkg/manager/infosync/info_test.go b/pkg/manager/infosync/info_test.go index e6ac7d311..966e8fad7 100644 --- a/pkg/manager/infosync/info_test.go +++ b/pkg/manager/infosync/info_test.go @@ -128,6 +128,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 1) require.NotNil(ts.t, info["1.1.1.1:4000"]) + require.Equal(ts.t, "1.1.1.1:4000", info["1.1.1.1:4000"].Addr) require.Equal(ts.t, "1.1.1.1", info["1.1.1.1:4000"].IP) require.Equal(ts.t, uint(10080), info["1.1.1.1:4000"].StatusPort) }, @@ -144,6 +145,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 2) require.NotNil(ts.t, info["2.2.2.2:4000"]) + require.Equal(ts.t, "2.2.2.2:4000", info["2.2.2.2:4000"].Addr) require.Equal(ts.t, "2.2.2.2", info["2.2.2.2:4000"].IP) require.Equal(ts.t, uint(10080), info["2.2.2.2:4000"].StatusPort) }, @@ -170,6 +172,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 2) require.NotNil(ts.t, info["3.3.3.3:4000"]) + require.Equal(ts.t, "3.3.3.3:4000", info["3.3.3.3:4000"].Addr) require.Equal(ts.t, "3.3.3.3", info["3.3.3.3:4000"].IP) require.Equal(ts.t, uint(10080), info["3.3.3.3:4000"].StatusPort) require.Equal(ts.t, "test", info["3.3.3.3:4000"].Keyspace) diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go index f66c2301f..def3a5776 100644 --- a/pkg/manager/namespace/manager.go +++ b/pkg/manager/namespace/manager.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "maps" - "reflect" "sync" "github.com/pingcap/tiproxy/lib/config" @@ -54,14 +53,11 @@ func NewNamespaceManager() *namespaceManager { func (mgr *namespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace, error) { logger := mgr.logger.With(zap.String("namespace", cfg.Namespace)) - // init BackendFetcher - var fetcher observer.BackendFetcher healthCheckCfg := config.NewDefaultHealthCheckConfig() - if mgr.tpFetcher != nil && !reflect.ValueOf(mgr.tpFetcher).IsNil() { - fetcher = observer.NewPDFetcher(mgr.tpFetcher, logger.Named("be_fetcher"), healthCheckCfg) - } else { - fetcher = observer.NewStaticFetcher(cfg.Backend.Instances) - } + // Namespace always receives a topology fetcher from the cluster manager. PDFetcher preserves + // legacy static backend.instances compatibility by falling back internally before any backend + // cluster is configured. + fetcher := observer.NewPDFetcher(mgr.tpFetcher, cfg.Backend.Instances, logger.Named("be_fetcher"), healthCheckCfg) // init Router rt := router.NewScoreBasedRouter(logger.Named("router")) diff --git a/pkg/manager/namespace/manager_test.go b/pkg/manager/namespace/manager_test.go index fb2609975..fbe6f15fd 100644 --- a/pkg/manager/namespace/manager_test.go +++ b/pkg/manager/namespace/manager_test.go @@ -4,16 +4,28 @@ package namespace import ( + "context" "testing" "github.com/pingcap/tiproxy/pkg/balance/router" + "github.com/pingcap/tiproxy/pkg/manager/infosync" "github.com/stretchr/testify/require" "go.uber.org/zap" ) +type mockTopologyFetcher struct{} + +func (*mockTopologyFetcher) GetTiDBTopology(context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + return nil, nil +} + +func (*mockTopologyFetcher) HasBackendClusters() bool { + return false +} + func TestReady(t *testing.T) { nsMgr := NewNamespaceManager() - require.NoError(t, nsMgr.Init(zap.NewNop(), nil, nil, nil, nil, nil, nil)) + require.NoError(t, nsMgr.Init(zap.NewNop(), nil, &mockTopologyFetcher{}, nil, nil, nil, nil)) require.False(t, nsMgr.Ready()) rt := router.NewStaticRouter([]string{}) diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 9619d60cd..faeb33189 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -591,8 +591,8 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } rs := &redirectResult{ - from: mgr.ServerAddr(), - to: (*backendInst).Addr(), + from: mgr.curBackend.ID(), + to: (*backendInst).ID(), } defer func() { // The `mgr` won't be notified again before it calls `OnRedirectSucceed`, so simply `StorePointer` is also fine. @@ -639,12 +639,12 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } var cn net.Conn - cn, rs.err = net.DialTimeout("tcp", rs.to, mgr.config.DialTimeout) + cn, rs.err = net.DialTimeout("tcp", (*backendInst).Addr(), mgr.config.DialTimeout) if rs.err != nil { - mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err, SrcBackendNetwork) + mgr.handshakeHandler.OnHandshake(mgr, (*backendInst).Addr(), rs.err, SrcBackendNetwork) return } - newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))) + newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr((*backendInst).Addr(), cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))) if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) @@ -833,9 +833,7 @@ func (mgr *BackendConnManager) Close() error { handErr := mgr.handshakeHandler.OnConnClose(mgr, mgr.quitSource) var connErr error - var addr string if backendIO := mgr.backendIO.Swap(nil); backendIO != nil { - addr = (*backendIO).RemoteAddr().String() connErr = (*backendIO).Close() } @@ -846,13 +844,14 @@ func (mgr *BackendConnManager) Close() error { mgr.notifyRedirectResult(context.Background(), <-mgr.redirectResCh) } // The connection may have just received the redirecting signal. - if len(addr) > 0 { - var redirectingAddr string + if mgr.curBackend != nil { + var redirectingBackendID string if redirectingBackend := mgr.redirectInfo.Load(); redirectingBackend != nil { - redirectingAddr = (*redirectingBackend).Addr() + redirectingBackendID = (*redirectingBackend).ID() } - if err := eventReceiver.OnConnClosed(addr, redirectingAddr, mgr); err != nil { - mgr.logger.Error("close connection error", zap.String("backend_addr", addr), zap.NamedError("notify_err", err)) + if err := eventReceiver.OnConnClosed(mgr.curBackend.ID(), redirectingBackendID, mgr); err != nil { + mgr.logger.Error("close connection error", + zap.String("backend_id", mgr.curBackend.ID()), zap.String("backend_addr", mgr.curBackend.Addr()), zap.NamedError("notify_err", err)) } } } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 05b95b46e..2532bd83f 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -62,10 +62,10 @@ func (mer *mockEventReceiver) OnRedirectFail(from, to string, conn router.Redire return nil } -func (mer *mockEventReceiver) OnConnClosed(from, to string, conn router.RedirectableConn) error { +func (mer *mockEventReceiver) OnConnClosed(backendID, redirectingBackendID string, conn router.RedirectableConn) error { mer.eventCh <- event{ - from: from, - to: to, + from: backendID, + to: redirectingBackendID, eventName: eventClose, } return nil @@ -97,6 +97,10 @@ func (mbi *mockBackendInst) Addr() string { return mbi.addr } +func (mbi *mockBackendInst) ID() string { + return mbi.addr +} + func (mbi *mockBackendInst) Healthy() bool { return mbi.healthy.Load() } diff --git a/pkg/server/server.go b/pkg/server/server.go index ab80c8912..2e1cc836e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -11,10 +11,10 @@ import ( "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/balance/metricsreader" + "github.com/pingcap/tiproxy/pkg/manager/backendcluster" "github.com/pingcap/tiproxy/pkg/manager/cert" mgrcfg "github.com/pingcap/tiproxy/pkg/manager/config" "github.com/pingcap/tiproxy/pkg/manager/id" - "github.com/pingcap/tiproxy/pkg/manager/infosync" "github.com/pingcap/tiproxy/pkg/manager/logger" "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/manager/meter" @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tiproxy/pkg/sctx" "github.com/pingcap/tiproxy/pkg/server/api" mgrrp "github.com/pingcap/tiproxy/pkg/sqlreplay/manager" - "github.com/pingcap/tiproxy/pkg/util/etcd" "github.com/pingcap/tiproxy/pkg/util/http" "github.com/pingcap/tiproxy/pkg/util/versioninfo" "github.com/pingcap/tiproxy/pkg/util/waitgroup" @@ -43,8 +42,8 @@ type Server struct { metricsManager *metrics.MetricsManager loggerManager *logger.LoggerManager certManager *cert.CertManager + clusterManager *backendcluster.Manager vipManager vip.VIPManager - infoSyncer *infosync.InfoSyncer metricsReader metricsreader.MetricsReader replay mgrrp.JobManager meter *meter.Meter @@ -107,29 +106,26 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) return } - // setup etcd client - srv.etcdCli, err = etcd.InitEtcdClient(lg.Named("etcd"), cfg, srv.certManager) - if err != nil { + // setup backend cluster manager + srv.clusterManager = backendcluster.NewManager(lg.Named("backendcluster"), srv.certManager.ClusterTLS) + if err = srv.clusterManager.Start(ctx, srv.configManager, srv.configManager.WatchConfig()); err != nil { return } + var promFetcher metricsreader.PromInfoFetcher + if cluster := srv.clusterManager.PrimaryCluster(); cluster != nil { + srv.etcdCli = cluster.EtcdClient() + promFetcher = cluster + } // general cluster HTTP client { srv.httpCli = http.NewHTTPClient(srv.certManager.ClusterTLS) } - // setup info syncer - if cfg.Proxy.PDAddrs != "" { - srv.infoSyncer = infosync.NewInfoSyncer(lg.Named("infosync"), srv.etcdCli) - if err = srv.infoSyncer.Init(ctx, cfg); err != nil { - return - } - } - // setup metrics reader { healthCheckCfg := config.NewDefaultHealthCheckConfig() - srv.metricsReader = metricsreader.NewDefaultMetricsReader(lg.Named("mr"), srv.infoSyncer, srv.infoSyncer, srv.httpCli, srv.etcdCli, healthCheckCfg, srv.configManager) + srv.metricsReader = metricsreader.NewDefaultMetricsReader(lg.Named("mr"), promFetcher, srv.clusterManager, srv.httpCli, srv.etcdCli, healthCheckCfg, srv.configManager) if err = srv.metricsReader.Start(ctx); err != nil { return } @@ -157,7 +153,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) nscs = append(nscs, nsc) } - err = srv.namespaceManager.Init(lg.Named("nsmgr"), nscs, srv.infoSyncer, srv.infoSyncer, srv.httpCli, srv.configManager, srv.metricsReader) + err = srv.namespaceManager.Init(lg.Named("nsmgr"), nscs, srv.clusterManager, promFetcher, srv.httpCli, srv.configManager, srv.metricsReader) if err != nil { return } @@ -214,8 +210,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) return } if srv.vipManager != nil && !reflect.ValueOf(srv.vipManager).IsNil() { - if err = srv.vipManager.Start(ctx, srv.etcdCli); err != nil { - return + if srv.etcdCli != nil { + if err = srv.vipManager.Start(ctx, srv.etcdCli); err != nil { + return + } + } else { + lg.Info("VIP is disabled because backend cluster count is not 1") } } } @@ -283,9 +283,6 @@ func (s *Server) Close() error { if s.memManager != nil { s.memManager.Close() } - if s.infoSyncer != nil { - errs = append(errs, s.infoSyncer.Close()) - } if s.configManager != nil { errs = append(errs, s.configManager.Close()) } @@ -295,8 +292,8 @@ func (s *Server) Close() error { if s.loggerManager != nil { errs = append(errs, s.loggerManager.Close()) } - if s.etcdCli != nil { - errs = append(errs, s.etcdCli.Close()) + if s.clusterManager != nil { + errs = append(errs, s.clusterManager.Close()) } s.wg.Wait() return errors.Collect(ErrCloseServer, errs...) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 98f6bcd6a..e45a3f840 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -12,10 +12,14 @@ import ( "github.com/pingcap/tiproxy/lib/util/logger" "github.com/pingcap/tiproxy/pkg/sctx" "github.com/pingcap/tiproxy/pkg/util/etcd" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" ) func TestServer(t *testing.T) { + restore := resetPromRegistry() + defer restore() + dir := t.TempDir() lg, _ := logger.CreateLoggerForTest(t) etcdServer, err := etcd.CreateEtcdServer("0.0.0.0:0", dir, lg) @@ -34,3 +38,30 @@ func TestServer(t *testing.T) { require.NoError(t, server.Close()) etcdServer.Close() } + +func TestServerWithoutBackendCluster(t *testing.T) { + restore := resetPromRegistry() + defer restore() + + dir := t.TempDir() + configFile := dir + "/config.toml" + require.NoError(t, os.WriteFile(configFile, []byte("[proxy]\npd-addrs = \"\"\n"), 0o644)) + + server, err := NewServer(context.Background(), &sctx.Context{ + ConfigFile: configFile, + }) + require.NoError(t, err) + require.NoError(t, server.Close()) +} + +func resetPromRegistry() func() { + registry := prometheus.NewRegistry() + oldRegisterer := prometheus.DefaultRegisterer + oldGatherer := prometheus.DefaultGatherer + prometheus.DefaultRegisterer = registry + prometheus.DefaultGatherer = registry + return func() { + prometheus.DefaultRegisterer = oldRegisterer + prometheus.DefaultGatherer = oldGatherer + } +} diff --git a/pkg/util/etcd/etcd.go b/pkg/util/etcd/etcd.go index 1508e9141..58049d21d 100644 --- a/pkg/util/etcd/etcd.go +++ b/pkg/util/etcd/etcd.go @@ -5,6 +5,7 @@ package etcd import ( "context" + "crypto/tls" "fmt" "net/url" "strings" @@ -31,11 +32,16 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertMa // use tidb server addresses directly return nil, nil } - pdEndpoints := strings.Split(pdAddr, ",") + return InitEtcdClientWithAddrs(logger, pdAddr, certMgr.ClusterTLS()) +} + +// InitEtcdClientWithAddrs initializes an etcd client that connects to PD ETCD servers. +func InitEtcdClientWithAddrs(logger *zap.Logger, pdAddrs string, tlsConfig *tls.Config) (*clientv3.Client, error) { + pdEndpoints := strings.Split(pdAddrs, ",") logger.Info("connect ETCD servers", zap.Strings("addrs", pdEndpoints)) etcdClient, err := clientv3.New(clientv3.Config{ Endpoints: pdEndpoints, - TLS: certMgr.ClusterTLS(), + TLS: tlsConfig, Logger: logger.Named("etcdcli"), AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second,