diff --git a/cmd/server/main.go b/cmd/server/main.go index 78c5c99..d8eb8f6 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -54,7 +54,6 @@ func Start() { r := http.NewServeMux() r.HandleFunc("/", websocket.HandleConnection) go websocket.SendPermanentBroadCastMessage() - go websocket.RemoveStaleClients() helper.Logger.Info("GoTomato started", "version", metadata.GoTomatoVersion) helper.Logger.Info("Websocket listening", "address", listen) diff --git a/internal/websocket/broadcast.go b/internal/websocket/broadcast.go index 32a6477..def8309 100644 --- a/internal/websocket/broadcast.go +++ b/internal/websocket/broadcast.go @@ -11,27 +11,22 @@ import ( // Sends continous messages to all connected WebSocket clients func SendPermanentBroadCastMessage() { - ticker := time.NewTicker(BROADCAST_INTERVAL * time.Second) - defer ticker.Stop() - - for range ticker.C { + tick := time.NewTicker(time.Second) + for { // Marshal the message into JSON format jsonMessage, err := json.Marshal(shared.State) if err != nil { helper.Logger.Error("Error marshalling message:", "msg", err) return } - // Iterate over all connected clients and broadcast the message - mu.Lock() for _, client := range Clients { - // Send message to client - client.Conn.SetWriteDeadline(time.Now().Add(SEND_TIMEOUT * time.Second)) - err := client.Conn.WriteMessage(websocket.TextMessage, jsonMessage) + err := client.SendMessage(websocket.TextMessage, jsonMessage) if err != nil { - helper.Logger.Error("Error broadcasting to client:", "msg", err, "host", client.RealIP, "clients", len(Clients)) + helper.Logger.Error("Error broadcasting to client:", "msg", err) + // The client is responsible for closing itself on error } } - mu.Unlock() + <-tick.C } } diff --git a/internal/websocket/client_commands.go b/internal/websocket/client_commands.go index 7712d08..7149faa 100644 --- a/internal/websocket/client_commands.go +++ b/internal/websocket/client_commands.go @@ -21,10 +21,7 @@ func handleClientCommands(c models.WebsocketClient) { _, message, err := ws.ReadMessage() if err != nil { - // remove client on error/disconnect - mu.Lock() delete(Clients, ws.LocalAddr()) - mu.Unlock() helper.Logger.Info("Client disconnected:", "msg", err, "host", c.RealIP, "clients", len(Clients)) break } diff --git a/internal/websocket/handle_connection.go b/internal/websocket/handle_connection.go index ca55229..9c4b43f 100644 --- a/internal/websocket/handle_connection.go +++ b/internal/websocket/handle_connection.go @@ -2,13 +2,18 @@ package websocket import ( "github.com/gorilla/websocket" + "net" "net/http" - "time" + "sync" "git.smsvc.net/pomodoro/GoTomato/internal/helper" "git.smsvc.net/pomodoro/GoTomato/pkg/models" ) +// Clients is a map of connected WebSocket clients, where each client is represented by the WebsocketClient struct +var Clients = make(map[net.Addr]*models.WebsocketClient) +var mu sync.Mutex // Mutex to protect access to the Clients map + // Upgrade HTTP requests to WebSocket connections var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, @@ -26,15 +31,9 @@ func HandleConnection(w http.ResponseWriter, r *http.Request) { // Register the new client client := models.WebsocketClient{ - Conn: ws, - LastPong: time.Now(), - RealIP: r.RemoteAddr, + Conn: ws, + RealIP: r.RemoteAddr, } - client.Conn.SetPongHandler(func(s string) error { - client.LastPong = time.Now() - return nil - }) - mu.Lock() Clients[ws.LocalAddr()] = &client mu.Unlock() diff --git a/internal/websocket/staleClients.go b/internal/websocket/staleClients.go deleted file mode 100644 index b636c49..0000000 --- a/internal/websocket/staleClients.go +++ /dev/null @@ -1,34 +0,0 @@ -package websocket - -import ( - "time" - - "git.smsvc.net/pomodoro/GoTomato/internal/helper" - "git.smsvc.net/pomodoro/GoTomato/pkg/models" - "github.com/gorilla/websocket" -) - -// Check and remove stale clients -func RemoveStaleClients() { - ticker := time.NewTicker(STALE_CHECK_INTERVALL * time.Second) - defer ticker.Stop() - - for range ticker.C { - mu.Lock() - for _, client := range Clients { - client.Conn.SetWriteDeadline(time.Now().Add(SEND_TIMEOUT * time.Second)) - client.Conn.WriteMessage(websocket.PingMessage, nil) - - if isStale(client) { - helper.Logger.Info("Removing stale client", "host", client.RealIP, "lastPong", client.LastPong.Format(time.RFC3339)) - client.Conn.Close() - delete(Clients, client.Conn.LocalAddr()) - } - } - mu.Unlock() - } -} - -func isStale(client *models.WebsocketClient) bool { - return time.Since(client.LastPong) > (STALE_CLIENT_TIMEOUT * time.Second) -} diff --git a/internal/websocket/vars.go b/internal/websocket/vars.go deleted file mode 100644 index ad02180..0000000 --- a/internal/websocket/vars.go +++ /dev/null @@ -1,19 +0,0 @@ -package websocket - -import ( - "net" - "sync" - - "git.smsvc.net/pomodoro/GoTomato/pkg/models" -) - -const BROADCAST_INTERVAL = 1 -const SEND_TIMEOUT = 10 -const STALE_CLIENT_TIMEOUT = 90 -const STALE_CHECK_INTERVALL = 30 - -// Clients is a map of connected WebSocket clients, where each client is represented by the WebsocketClient struct -var Clients = make(map[net.Addr]*models.WebsocketClient) - -// Mutex to protect access to the Clients map -var mu sync.Mutex diff --git a/pkg/models/client.go b/pkg/models/client.go index 9d5dbdb..eef9663 100644 --- a/pkg/models/client.go +++ b/pkg/models/client.go @@ -1,9 +1,9 @@ package models import ( - "time" - "github.com/gorilla/websocket" + + "git.smsvc.net/pomodoro/GoTomato/internal/helper" ) // Represents a command from the client (start/stop) @@ -15,7 +15,17 @@ type ClientCommand struct { // Represents a single client type WebsocketClient struct { - Conn *websocket.Conn - LastPong time.Time - RealIP string + Conn *websocket.Conn + RealIP string +} + +// Sends a message to the websocket. +// Automatically locks and unlocks the client mutex, to ensure that only one goroutine can write at a time. +func (c *WebsocketClient) SendMessage(messageType int, data []byte) error { + err := c.Conn.WriteMessage(messageType, data) + if err != nil { + helper.Logger.Error("Error writing to WebSocket:", "msg", err) + c.Conn.Close() // Close the connection on error + } + return err }