Commit a5ddd5ad authored by boojack

fix(server): close SSE clients during shutdown

Close long-lived SSE streams before shutting down the HTTP server so graceful shutdown does not stall until the deadline. Also wait for background runners to finish before closing the store, making the shutdown ordering explicit.
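For context, the mechanism this relies on can be sketched as follows: an SSE handler blocks reading its client's event channel, so closing that channel from the hub is what lets the in-flight request return and `http.Server.Shutdown` finish before the deadline. The sketch below assumes a plain `net/http` handler and a bare `<-chan []byte`; it is illustrative only, not the repository's actual handler.

```go
package sketch

import (
    "fmt"
    "net/http"
)

// writeSSE drains an already-subscribed client's event channel into the
// response. Minimal sketch: the exit condition is the point. When the hub
// closes the channel, the handler returns, the request finishes, and
// http.Server.Shutdown can complete gracefully.
func writeSSE(w http.ResponseWriter, r *http.Request, events <-chan []byte) {
    w.Header().Set("Content-Type", "text/event-stream")
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "streaming unsupported", http.StatusInternalServerError)
        return
    }
    for {
        select {
        case msg, open := <-events:
            if !open {
                // Hub closed the channel (e.g. SSEHub.Close during shutdown).
                return
            }
            fmt.Fprintf(w, "data: %s\n\n", msg)
            flusher.Flush()
        case <-r.Context().Done():
            // Client disconnected or the request context was canceled.
            return
        }
    }
}
```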
parent a7fd1dac
@@ -56,6 +56,7 @@ type SSEClient struct {
type SSEHub struct {
    mu      sync.RWMutex
    clients map[*SSEClient]struct{}
    closed  bool
}

// NewSSEHub creates a new SSE hub.
@@ -75,7 +76,11 @@ func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
        role:   role,
    }
    h.mu.Lock()
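    // An already-closed hub hands back a closed channel so the caller's read loop exits immediately.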
    if h.closed {
        close(c.events)
    } else {
        h.clients[c] = struct{}{}
    }
    h.mu.Unlock()
    return c
}
@@ -90,6 +95,20 @@ func (h *SSEHub) Unsubscribe(c *SSEClient) {
    h.mu.Unlock()
}

// Close disconnects all subscribed SSE clients.
func (h *SSEHub) Close() {
    h.mu.Lock()
    defer h.mu.Unlock()
    if h.closed {
        return
    }
    h.closed = true
    for c := range h.clients {
        delete(h.clients, c)
        close(c.events)
    }
}

// Broadcast sends an event to all connected clients.
// Slow clients that have a full buffer will have the event dropped
// to avoid blocking the broadcaster.
......
@@ -47,6 +47,28 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
    assert.False(t, ok, "channel should be closed after Unsubscribe")
}

func TestSSEHub_Close(t *testing.T) {
    hub := NewSSEHub()
    c1 := hub.Subscribe(1, store.RoleUser)
    c2 := hub.Subscribe(2, store.RoleAdmin)
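    // Close is expected to be idempotent; the second call should be a no-op.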
    hub.Close()
    hub.Close()

    for _, ch := range []chan []byte{c1.events, c2.events} {
        _, ok := <-ch
        assert.False(t, ok, "channel should be closed after hub close")
    }

    late := hub.Subscribe(3, store.RoleUser)
    _, ok := <-late.events
    assert.False(t, ok, "late subscriber should be closed immediately")
    hub.Broadcast(&SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"})
    hub.Unsubscribe(c1)
    hub.Unsubscribe(late)
}

func TestSSEHub_Broadcast(t *testing.T) {
    hub := NewSSEHub()
    client := hub.Subscribe(1, store.RoleUser)
......
@@ -2,9 +2,11 @@ package test

import (
    "context"
    "io"
    "net/http"
    "net/http/httptest"
    "testing"
    "time"

    "github.com/labstack/echo/v5"
    "github.com/stretchr/testify/require"
@@ -75,4 +77,34 @@ func TestSSEHandler_Authentication(t *testing.T) {
        e.ServeHTTP(rec, req)
        require.Equal(t, http.StatusUnauthorized, rec.Code)
    })

    t.Run("hub close disconnects stream", func(t *testing.T) {
        server := httptest.NewServer(e)
        defer server.Close()

        req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/api/v1/sse", nil)
        require.NoError(t, err)
        req.Header.Set("Authorization", "Bearer "+token)

        resp, err := server.Client().Do(req)
        require.NoError(t, err)
        defer resp.Body.Close()
        require.Equal(t, http.StatusOK, resp.StatusCode)
        require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
        ts.Service.SSEHub.Close()

        done := make(chan error, 1)
        go func() {
            _, err := io.ReadAll(resp.Body)
            done <- err
        }()

        select {
        case err := <-done:
            require.NoError(t, err)
        case <-time.After(time.Second):
            t.Fatal("SSE stream did not close after hub close")
        }
    })
}
@@ -6,7 +6,7 @@ import (
    "log/slog"
    "net"
    "net/http"
    "runtime"
    "sync"
    "time"

    "github.com/google/uuid"
@@ -25,6 +25,8 @@ import (
    "github.com/usememos/memos/store"
)
const shutdownTimeout = 10 * time.Second

type Server struct {
    Secret  string
    Profile *profile.Profile
@@ -32,7 +34,10 @@ type Server struct {
    echoServer *echo.Echo
    httpServer *http.Server
    runnerCancelFuncs []context.CancelFunc
    sseHub *apiv1.SSEHub
    backgroundRunnerCancels []context.CancelFunc
    backgroundRunnerWG      sync.WaitGroup
}

func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
@@ -67,6 +72,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
    rootGroup := echoServer.Group("")
    apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
    s.sseHub = apiV1Service.SSEHub

    // Register HTTP file server routes BEFORE gRPC-Gateway to ensure proper range request handling for Safari.
    // This uses native HTTP serving (http.ServeContent) instead of gRPC for video/audio files.
@@ -109,30 +115,21 @@ func (s *Server) Start(ctx context.Context) error {
            slog.Error("failed to start echo server", "error", err)
        }
    }()

    s.StartBackgroundRunners(ctx)
    s.startBackgroundRunners(ctx)
    return nil
}

func (s *Server) Shutdown(ctx context.Context) {
    ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
    ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
    defer cancel()

    slog.Info("server shutting down")

    // Cancel all background runners
    for _, cancelFunc := range s.runnerCancelFuncs {
        if cancelFunc != nil {
            cancelFunc()
        }
    }

    // Shutdown HTTP server.
    if s.httpServer != nil {
        if err := s.httpServer.Shutdown(ctx); err != nil {
            slog.Error("failed to shutdown server", slog.String("error", err.Error()))
        }
    }

    s.stopBackgroundRunners()
    s.closeLongLivedConnections()
    s.shutdownHTTPServer(ctx)
    s.waitBackgroundRunners(ctx)

    // Close database connection.
    if err := s.Store.Close(); err != nil {
@@ -142,26 +139,73 @@ func (s *Server) Shutdown(ctx context.Context) {
    slog.Info("memos stopped properly")
}

func (s *Server) StartBackgroundRunners(ctx context.Context) {
func (s *Server) startBackgroundRunners(ctx context.Context) {
    // Create a separate context for each background runner
    // This allows us to control cancellation for each runner independently
    s3Context, s3Cancel := context.WithCancel(ctx)
    // Store the cancel function so we can properly shut down runners
    s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
    s.backgroundRunnerCancels = append(s.backgroundRunnerCancels, s3Cancel)

    // Create and start S3 presign runner
    s3presignRunner := s3presign.NewRunner(s.Store)
    s3presignRunner.RunOnce(ctx)

    // Start continuous S3 presign runner
    s.backgroundRunnerWG.Add(1)
    go func() {
        defer s.backgroundRunnerWG.Done()
        s3presignRunner.Run(s3Context)
        slog.Info("s3presign runner stopped")
    }()

    // Log the number of goroutines running
    slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
    slog.Info("background runners started")
}

func (s *Server) stopBackgroundRunners() {
    for _, cancelFunc := range s.backgroundRunnerCancels {
        if cancelFunc != nil {
            cancelFunc()
        }
    }
}

func (s *Server) waitBackgroundRunners(ctx context.Context) {
    done := make(chan struct{})
    go func() {
        s.backgroundRunnerWG.Wait()
        close(done)
    }()

    select {
    case <-done:
    case <-ctx.Done():
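        // The context expired; re-check whether the runners happened to finish
        // at the same moment before logging a shutdown error.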
        select {
        case <-done:
            return
        default:
        }
        slog.Error("failed to stop background runners", slog.String("error", ctx.Err().Error()))
    }
}

func (s *Server) closeLongLivedConnections() {
    // Long-lived SSE requests do not finish on their own during http.Server.Shutdown.
    if s.sseHub != nil {
        s.sseHub.Close()
    }
}

func (s *Server) shutdownHTTPServer(ctx context.Context) {
    if s.httpServer == nil {
        return
    }

    if err := s.httpServer.Shutdown(ctx); err != nil {
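        // Graceful shutdown failed or timed out; force-close remaining connections.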
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
if closeErr := s.httpServer.Close(); closeErr != nil && closeErr != http.ErrServerClosed {
slog.Error("failed to close server", slog.String("error", closeErr.Error()))
}
}
}

func (s *Server) getOrUpsertInstanceBasicSetting(ctx context.Context) (*storepb.InstanceBasicSetting, error) {
......