From 4471c86a0cf340178639ef97e331a14d11cb8df2 Mon Sep 17 00:00:00 2001 From: Sebastian Mark Date: Sun, 20 Oct 2024 11:06:37 +0200 Subject: [PATCH 1/3] fix: prevent concurrent write to websocket connection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - refactor client handling and message broadcasting - replace Client struct - implement SendMessage method in Client struct for safer message sending - update client map to use *models.Client instead of bool - adjust BroadcastMessage and RunPomodoroTimer functions to use new client type 🤖 --- internal/broadcast/broadcast.go | 9 ++++--- internal/pomodoro/pomodoro.go | 4 ++-- internal/pomodoro/timer.go | 2 +- internal/websocket/client_commands.go | 12 ++++++++++ internal/websocket/handle_connections.go | 8 +++---- pkg/models/{types.go => broadcast.go} | 5 ---- pkg/models/client.go | 30 ++++++++++++++++++++++++ 7 files changed, 53 insertions(+), 17 deletions(-) rename pkg/models/{types.go => broadcast.go} (75%) create mode 100644 pkg/models/client.go diff --git a/internal/broadcast/broadcast.go b/internal/broadcast/broadcast.go index deb3010..229ae67 100644 --- a/internal/broadcast/broadcast.go +++ b/internal/broadcast/broadcast.go @@ -8,7 +8,7 @@ import ( ) // BroadcastMessage sends a message to all connected WebSocket clients. -func BroadcastMessage(clients map[*websocket.Conn]bool, message models.BroadcastMessage) { +func BroadcastMessage(clients map[*websocket.Conn]*models.Client, message models.BroadcastMessage) { // Marshal the message into JSON format jsonMessage, err := json.Marshal(message) if err != nil { @@ -17,12 +17,11 @@ func BroadcastMessage(clients map[*websocket.Conn]bool, message models.Broadcast } // Iterate over all connected clients and broadcast the message - for client := range clients { - err := client.WriteMessage(websocket.TextMessage, jsonMessage) + for _, client := range clients { + err := client.SendMessage(websocket.TextMessage, jsonMessage) if err != nil { log.Printf("Error broadcasting to client: %v", err) - client.Close() - delete(clients, client) // Remove the client if an error occurs + // The client is responsible for closing itself on error } } } diff --git a/internal/pomodoro/pomodoro.go b/internal/pomodoro/pomodoro.go index 6297d45..0cf8514 100644 --- a/internal/pomodoro/pomodoro.go +++ b/internal/pomodoro/pomodoro.go @@ -24,7 +24,7 @@ var pomodoroResumeChannel = make(chan bool, 1) var mu sync.Mutex // to synchronize access to shared state // RunPomodoroTimer iterates the Pomodoro work/break sessions. -func RunPomodoroTimer(clients map[*websocket.Conn]bool) { +func RunPomodoroTimer(clients map[*websocket.Conn]*models.Client) { mu.Lock() pomodoroRunning = true pomodoroPaused = false @@ -51,7 +51,7 @@ func RunPomodoroTimer(clients map[*websocket.Conn]bool) { } // ResetPomodoro resets the running Pomodoro timer. -func ResetPomodoro(clients map[*websocket.Conn]bool) { +func ResetPomodoro(clients map[*websocket.Conn]*models.Client) { mu.Lock() pomodoroRunning = false // Reset the running state pomodoroPaused = false // Reset the paused state diff --git a/internal/pomodoro/timer.go b/internal/pomodoro/timer.go index 6a8e39b..b71f770 100644 --- a/internal/pomodoro/timer.go +++ b/internal/pomodoro/timer.go @@ -8,7 +8,7 @@ import ( ) // startTimer runs the countdown and broadcasts every second. -func startTimer(clients map[*websocket.Conn]bool, remainingSeconds int, mode string, session int) bool { +func startTimer(clients map[*websocket.Conn]*models.Client, remainingSeconds int, mode string, session int) bool { for remainingSeconds > 0 { select { case <-pomodoroResetChannel: diff --git a/internal/websocket/client_commands.go b/internal/websocket/client_commands.go index 50fab15..91ee750 100644 --- a/internal/websocket/client_commands.go +++ b/internal/websocket/client_commands.go @@ -6,10 +6,22 @@ import ( "git.smsvc.net/pomodoro/GoTomato/pkg/models" "github.com/gorilla/websocket" "log" + "sync" ) +// Clients is a map of connected WebSocket clients, where each client is represented by the Client struct +var Clients = make(map[*websocket.Conn]*models.Client) +var mu sync.Mutex // Mutex to protect access to the Clients map + // handleClientCommands listens for commands from WebSocket clients and dispatches to the timer. func handleClientCommands(ws *websocket.Conn) { + // Create a new Client and add it to the Clients map + mu.Lock() + Clients[ws] = &models.Client{ + Conn: ws, + } + mu.Unlock() + for { _, message, err := ws.ReadMessage() if err != nil { diff --git a/internal/websocket/handle_connections.go b/internal/websocket/handle_connections.go index a7d6cfc..0b8f63e 100644 --- a/internal/websocket/handle_connections.go +++ b/internal/websocket/handle_connections.go @@ -1,14 +1,12 @@ package websocket import ( + "git.smsvc.net/pomodoro/GoTomato/pkg/models" "github.com/gorilla/websocket" "log" "net/http" ) -// Map to track connected clients -var Clients = make(map[*websocket.Conn]bool) - // Upgrader to upgrade HTTP requests to WebSocket connections var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, @@ -25,7 +23,9 @@ func HandleConnections(w http.ResponseWriter, r *http.Request) { defer ws.Close() // Register the new client - Clients[ws] = true + Clients[ws] = &models.Client{ + Conn: ws, // Store the WebSocket connection + } // Listen for commands from the connected client handleClientCommands(ws) diff --git a/pkg/models/types.go b/pkg/models/broadcast.go similarity index 75% rename from pkg/models/types.go rename to pkg/models/broadcast.go index ea04e21..78463b9 100644 --- a/pkg/models/types.go +++ b/pkg/models/broadcast.go @@ -7,8 +7,3 @@ type BroadcastMessage struct { MaxSession int `json:"max_session"` // Total number of sessions TimeLeft int `json:"time_left"` // Remaining time in seconds } - -// ClientCommand represents a command from the client (start/stop). -type ClientCommand struct { - Command string `json:"command"` -} diff --git a/pkg/models/client.go b/pkg/models/client.go new file mode 100644 index 0000000..8d5c3e9 --- /dev/null +++ b/pkg/models/client.go @@ -0,0 +1,30 @@ +package models + +import ( + "github.com/gorilla/websocket" + "log" + "sync" +) + +// ClientCommand represents a command from the client (start/stop). +type ClientCommand struct { + Command string `json:"command"` +} + +type Client struct { + Conn *websocket.Conn + Mutex sync.Mutex +} + +// It automatically locks and unlocks the mutex to ensure that only one goroutine can write at a time. +func (c *Client) SendMessage(messageType int, data []byte) error { + c.Mutex.Lock() + defer c.Mutex.Unlock() + + err := c.Conn.WriteMessage(messageType, data) + if err != nil { + log.Printf("Error writing to WebSocket: %v", err) + c.Conn.Close() // Close the connection on error + } + return err +} From b62e92b5a402384f64205576003dd3891a86d169 Mon Sep 17 00:00:00 2001 From: Sebastian Mark Date: Sun, 20 Oct 2024 11:26:23 +0200 Subject: [PATCH 2/3] fix: resolve race condition in pause/resume check - replace direct check of pomodoroPaused with IsPomodoroPaused function --- internal/pomodoro/timer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/pomodoro/timer.go b/internal/pomodoro/timer.go index b71f770..69627c3 100644 --- a/internal/pomodoro/timer.go +++ b/internal/pomodoro/timer.go @@ -23,7 +23,7 @@ func startTimer(clients map[*websocket.Conn]*models.Client, remainingSeconds int mu.Unlock() default: // Broadcast the current state to all clients - if !pomodoroPaused { + if !IsPomodoroPaused() { broadcast.BroadcastMessage(clients, models.BroadcastMessage{ Mode: mode, Session: session, From 3d5cb29c54b2a17455e20189695706c381a28dbb Mon Sep 17 00:00:00 2001 From: Sebastian Mark Date: Sun, 20 Oct 2024 11:27:28 +0200 Subject: [PATCH 3/3] chore: cleanup pause/resume functions --- internal/pomodoro/pomodoro.go | 7 +++++++ internal/pomodoro/timer.go | 8 ++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/pomodoro/pomodoro.go b/internal/pomodoro/pomodoro.go index 0cf8514..c8a3b03 100644 --- a/internal/pomodoro/pomodoro.go +++ b/internal/pomodoro/pomodoro.go @@ -70,10 +70,17 @@ func ResetPomodoro(clients map[*websocket.Conn]*models.Client) { } func PausePomodoro() { + mu.Lock() + pomodoroPaused = true + mu.Unlock() + pomodoroPauseChannel <- true } func ResumePomodoro() { + mu.Lock() + pomodoroPaused = false + mu.Unlock() pomodoroResumeChannel <- true } diff --git a/internal/pomodoro/timer.go b/internal/pomodoro/timer.go index 69627c3..96595b3 100644 --- a/internal/pomodoro/timer.go +++ b/internal/pomodoro/timer.go @@ -14,13 +14,9 @@ func startTimer(clients map[*websocket.Conn]*models.Client, remainingSeconds int case <-pomodoroResetChannel: return false case <-pomodoroPauseChannel: - mu.Lock() - pomodoroPaused = true - mu.Unlock() + // Nothing to set here, just waiting for the signal case <-pomodoroResumeChannel: - mu.Lock() - pomodoroPaused = false - mu.Unlock() + // Nothing to set here, just waiting for the signal default: // Broadcast the current state to all clients if !IsPomodoroPaused() {