From 4471c86a0cf340178639ef97e331a14d11cb8df2 Mon Sep 17 00:00:00 2001 From: Sebastian Mark Date: Sun, 20 Oct 2024 11:06:37 +0200 Subject: [PATCH] fix: prevent concurrent write to websocket connection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - refactor client handling and message broadcasting - replace Client struct - implement SendMessage method in Client struct for safer message sending - update client map to use *models.Client instead of bool - adjust BroadcastMessage and RunPomodoroTimer functions to use new client type 🤖 --- internal/broadcast/broadcast.go | 9 ++++--- internal/pomodoro/pomodoro.go | 4 ++-- internal/pomodoro/timer.go | 2 +- internal/websocket/client_commands.go | 12 ++++++++++ internal/websocket/handle_connections.go | 8 +++---- pkg/models/{types.go => broadcast.go} | 5 ---- pkg/models/client.go | 30 ++++++++++++++++++++++++ 7 files changed, 53 insertions(+), 17 deletions(-) rename pkg/models/{types.go => broadcast.go} (75%) create mode 100644 pkg/models/client.go diff --git a/internal/broadcast/broadcast.go b/internal/broadcast/broadcast.go index deb3010..229ae67 100644 --- a/internal/broadcast/broadcast.go +++ b/internal/broadcast/broadcast.go @@ -8,7 +8,7 @@ import ( ) // BroadcastMessage sends a message to all connected WebSocket clients. -func BroadcastMessage(clients map[*websocket.Conn]bool, message models.BroadcastMessage) { +func BroadcastMessage(clients map[*websocket.Conn]*models.Client, message models.BroadcastMessage) { // Marshal the message into JSON format jsonMessage, err := json.Marshal(message) if err != nil { @@ -17,12 +17,11 @@ func BroadcastMessage(clients map[*websocket.Conn]bool, message models.Broadcast } // Iterate over all connected clients and broadcast the message - for client := range clients { - err := client.WriteMessage(websocket.TextMessage, jsonMessage) + for _, client := range clients { + err := client.SendMessage(websocket.TextMessage, jsonMessage) if err != nil { log.Printf("Error broadcasting to client: %v", err) - client.Close() - delete(clients, client) // Remove the client if an error occurs + // The client is responsible for closing itself on error } } } diff --git a/internal/pomodoro/pomodoro.go b/internal/pomodoro/pomodoro.go index 6297d45..0cf8514 100644 --- a/internal/pomodoro/pomodoro.go +++ b/internal/pomodoro/pomodoro.go @@ -24,7 +24,7 @@ var pomodoroResumeChannel = make(chan bool, 1) var mu sync.Mutex // to synchronize access to shared state // RunPomodoroTimer iterates the Pomodoro work/break sessions. -func RunPomodoroTimer(clients map[*websocket.Conn]bool) { +func RunPomodoroTimer(clients map[*websocket.Conn]*models.Client) { mu.Lock() pomodoroRunning = true pomodoroPaused = false @@ -51,7 +51,7 @@ func RunPomodoroTimer(clients map[*websocket.Conn]bool) { } // ResetPomodoro resets the running Pomodoro timer. -func ResetPomodoro(clients map[*websocket.Conn]bool) { +func ResetPomodoro(clients map[*websocket.Conn]*models.Client) { mu.Lock() pomodoroRunning = false // Reset the running state pomodoroPaused = false // Reset the paused state diff --git a/internal/pomodoro/timer.go b/internal/pomodoro/timer.go index 6a8e39b..b71f770 100644 --- a/internal/pomodoro/timer.go +++ b/internal/pomodoro/timer.go @@ -8,7 +8,7 @@ import ( ) // startTimer runs the countdown and broadcasts every second. -func startTimer(clients map[*websocket.Conn]bool, remainingSeconds int, mode string, session int) bool { +func startTimer(clients map[*websocket.Conn]*models.Client, remainingSeconds int, mode string, session int) bool { for remainingSeconds > 0 { select { case <-pomodoroResetChannel: diff --git a/internal/websocket/client_commands.go b/internal/websocket/client_commands.go index 50fab15..91ee750 100644 --- a/internal/websocket/client_commands.go +++ b/internal/websocket/client_commands.go @@ -6,10 +6,22 @@ import ( "git.smsvc.net/pomodoro/GoTomato/pkg/models" "github.com/gorilla/websocket" "log" + "sync" ) +// Clients is a map of connected WebSocket clients, where each client is represented by the Client struct +var Clients = make(map[*websocket.Conn]*models.Client) +var mu sync.Mutex // Mutex to protect access to the Clients map + // handleClientCommands listens for commands from WebSocket clients and dispatches to the timer. func handleClientCommands(ws *websocket.Conn) { + // Create a new Client and add it to the Clients map + mu.Lock() + Clients[ws] = &models.Client{ + Conn: ws, + } + mu.Unlock() + for { _, message, err := ws.ReadMessage() if err != nil { diff --git a/internal/websocket/handle_connections.go b/internal/websocket/handle_connections.go index a7d6cfc..0b8f63e 100644 --- a/internal/websocket/handle_connections.go +++ b/internal/websocket/handle_connections.go @@ -1,14 +1,12 @@ package websocket import ( + "git.smsvc.net/pomodoro/GoTomato/pkg/models" "github.com/gorilla/websocket" "log" "net/http" ) -// Map to track connected clients -var Clients = make(map[*websocket.Conn]bool) - // Upgrader to upgrade HTTP requests to WebSocket connections var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, @@ -25,7 +23,9 @@ func HandleConnections(w http.ResponseWriter, r *http.Request) { defer ws.Close() // Register the new client - Clients[ws] = true + Clients[ws] = &models.Client{ + Conn: ws, // Store the WebSocket connection + } // Listen for commands from the connected client handleClientCommands(ws) diff --git a/pkg/models/types.go b/pkg/models/broadcast.go similarity index 75% rename from pkg/models/types.go rename to pkg/models/broadcast.go index ea04e21..78463b9 100644 --- a/pkg/models/types.go +++ b/pkg/models/broadcast.go @@ -7,8 +7,3 @@ type BroadcastMessage struct { MaxSession int `json:"max_session"` // Total number of sessions TimeLeft int `json:"time_left"` // Remaining time in seconds } - -// ClientCommand represents a command from the client (start/stop). -type ClientCommand struct { - Command string `json:"command"` -} diff --git a/pkg/models/client.go b/pkg/models/client.go new file mode 100644 index 0000000..8d5c3e9 --- /dev/null +++ b/pkg/models/client.go @@ -0,0 +1,30 @@ +package models + +import ( + "github.com/gorilla/websocket" + "log" + "sync" +) + +// ClientCommand represents a command from the client (start/stop). +type ClientCommand struct { + Command string `json:"command"` +} + +type Client struct { + Conn *websocket.Conn + Mutex sync.Mutex +} + +// It automatically locks and unlocks the mutex to ensure that only one goroutine can write at a time. +func (c *Client) SendMessage(messageType int, data []byte) error { + c.Mutex.Lock() + defer c.Mutex.Unlock() + + err := c.Conn.WriteMessage(messageType, data) + if err != nil { + log.Printf("Error writing to WebSocket: %v", err) + c.Conn.Close() // Close the connection on error + } + return err +}