fix: prevent concurrent write to websocket connection

- refactor client handling and message broadcasting
- replace Client struct
- implement SendMessage method in Client struct for safer message sending
- update client map to use *models.Client instead of bool
- adjust BroadcastMessage and RunPomodoroTimer functions to use new client type

🤖
This commit is contained in:
Sebastian Mark 2024-10-20 11:06:37 +02:00
parent ffc6913344
commit 4471c86a0c
7 changed files with 53 additions and 17 deletions

View file

@ -8,7 +8,7 @@ import (
) )
// BroadcastMessage sends a message to all connected WebSocket clients. // BroadcastMessage sends a message to all connected WebSocket clients.
func BroadcastMessage(clients map[*websocket.Conn]bool, message models.BroadcastMessage) { func BroadcastMessage(clients map[*websocket.Conn]*models.Client, message models.BroadcastMessage) {
// Marshal the message into JSON format // Marshal the message into JSON format
jsonMessage, err := json.Marshal(message) jsonMessage, err := json.Marshal(message)
if err != nil { if err != nil {
@ -17,12 +17,11 @@ func BroadcastMessage(clients map[*websocket.Conn]bool, message models.Broadcast
} }
// Iterate over all connected clients and broadcast the message // Iterate over all connected clients and broadcast the message
for client := range clients { for _, client := range clients {
err := client.WriteMessage(websocket.TextMessage, jsonMessage) err := client.SendMessage(websocket.TextMessage, jsonMessage)
if err != nil { if err != nil {
log.Printf("Error broadcasting to client: %v", err) log.Printf("Error broadcasting to client: %v", err)
client.Close() // The client is responsible for closing itself on error
delete(clients, client) // Remove the client if an error occurs
} }
} }
} }

View file

@ -24,7 +24,7 @@ var pomodoroResumeChannel = make(chan bool, 1)
var mu sync.Mutex // to synchronize access to shared state var mu sync.Mutex // to synchronize access to shared state
// RunPomodoroTimer iterates the Pomodoro work/break sessions. // RunPomodoroTimer iterates the Pomodoro work/break sessions.
func RunPomodoroTimer(clients map[*websocket.Conn]bool) { func RunPomodoroTimer(clients map[*websocket.Conn]*models.Client) {
mu.Lock() mu.Lock()
pomodoroRunning = true pomodoroRunning = true
pomodoroPaused = false pomodoroPaused = false
@ -51,7 +51,7 @@ func RunPomodoroTimer(clients map[*websocket.Conn]bool) {
} }
// ResetPomodoro resets the running Pomodoro timer. // ResetPomodoro resets the running Pomodoro timer.
func ResetPomodoro(clients map[*websocket.Conn]bool) { func ResetPomodoro(clients map[*websocket.Conn]*models.Client) {
mu.Lock() mu.Lock()
pomodoroRunning = false // Reset the running state pomodoroRunning = false // Reset the running state
pomodoroPaused = false // Reset the paused state pomodoroPaused = false // Reset the paused state

View file

@ -8,7 +8,7 @@ import (
) )
// startTimer runs the countdown and broadcasts every second. // startTimer runs the countdown and broadcasts every second.
func startTimer(clients map[*websocket.Conn]bool, remainingSeconds int, mode string, session int) bool { func startTimer(clients map[*websocket.Conn]*models.Client, remainingSeconds int, mode string, session int) bool {
for remainingSeconds > 0 { for remainingSeconds > 0 {
select { select {
case <-pomodoroResetChannel: case <-pomodoroResetChannel:

View file

@ -6,10 +6,22 @@ import (
"git.smsvc.net/pomodoro/GoTomato/pkg/models" "git.smsvc.net/pomodoro/GoTomato/pkg/models"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"log" "log"
"sync"
) )
// Clients is a map of connected WebSocket clients, where each client is represented by the Client struct
var Clients = make(map[*websocket.Conn]*models.Client)
var mu sync.Mutex // Mutex to protect access to the Clients map
// handleClientCommands listens for commands from WebSocket clients and dispatches to the timer. // handleClientCommands listens for commands from WebSocket clients and dispatches to the timer.
func handleClientCommands(ws *websocket.Conn) { func handleClientCommands(ws *websocket.Conn) {
// Create a new Client and add it to the Clients map
mu.Lock()
Clients[ws] = &models.Client{
Conn: ws,
}
mu.Unlock()
for { for {
_, message, err := ws.ReadMessage() _, message, err := ws.ReadMessage()
if err != nil { if err != nil {

View file

@ -1,14 +1,12 @@
package websocket package websocket
import ( import (
"git.smsvc.net/pomodoro/GoTomato/pkg/models"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"log" "log"
"net/http" "net/http"
) )
// Map to track connected clients
var Clients = make(map[*websocket.Conn]bool)
// Upgrader to upgrade HTTP requests to WebSocket connections // Upgrader to upgrade HTTP requests to WebSocket connections
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
@ -25,7 +23,9 @@ func HandleConnections(w http.ResponseWriter, r *http.Request) {
defer ws.Close() defer ws.Close()
// Register the new client // Register the new client
Clients[ws] = true Clients[ws] = &models.Client{
Conn: ws, // Store the WebSocket connection
}
// Listen for commands from the connected client // Listen for commands from the connected client
handleClientCommands(ws) handleClientCommands(ws)

View file

@ -7,8 +7,3 @@ type BroadcastMessage struct {
MaxSession int `json:"max_session"` // Total number of sessions MaxSession int `json:"max_session"` // Total number of sessions
TimeLeft int `json:"time_left"` // Remaining time in seconds TimeLeft int `json:"time_left"` // Remaining time in seconds
} }
// ClientCommand represents a command from the client (start/stop).
type ClientCommand struct {
Command string `json:"command"`
}

30
pkg/models/client.go Normal file
View file

@ -0,0 +1,30 @@
package models
import (
"github.com/gorilla/websocket"
"log"
"sync"
)
// ClientCommand represents a command from the client (start/stop).
type ClientCommand struct {
Command string `json:"command"`
}
type Client struct {
Conn *websocket.Conn
Mutex sync.Mutex
}
// It automatically locks and unlocks the mutex to ensure that only one goroutine can write at a time.
func (c *Client) SendMessage(messageType int, data []byte) error {
c.Mutex.Lock()
defer c.Mutex.Unlock()
err := c.Conn.WriteMessage(messageType, data)
if err != nil {
log.Printf("Error writing to WebSocket: %v", err)
c.Conn.Close() // Close the connection on error
}
return err
}