diff --git a/cmd/server/main.go b/cmd/server/main.go index d8eb8f6..78c5c99 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -54,6 +54,7 @@ func Start() { r := http.NewServeMux() r.HandleFunc("/", websocket.HandleConnection) go websocket.SendPermanentBroadCastMessage() + go websocket.RemoveStaleClients() helper.Logger.Info("GoTomato started", "version", metadata.GoTomatoVersion) helper.Logger.Info("Websocket listening", "address", listen) diff --git a/internal/websocket/broadcast.go b/internal/websocket/broadcast.go index def8309..dce2f82 100644 --- a/internal/websocket/broadcast.go +++ b/internal/websocket/broadcast.go @@ -12,6 +12,7 @@ import ( // Sends continous messages to all connected WebSocket clients func SendPermanentBroadCastMessage() { tick := time.NewTicker(time.Second) + for { // Marshal the message into JSON format jsonMessage, err := json.Marshal(shared.State) @@ -19,14 +20,16 @@ func SendPermanentBroadCastMessage() { helper.Logger.Error("Error marshalling message:", "msg", err) return } + // Iterate over all connected clients and broadcast the message for _, client := range Clients { + // Send message to client err := client.SendMessage(websocket.TextMessage, jsonMessage) if err != nil { - helper.Logger.Error("Error broadcasting to client:", "msg", err) - // The client is responsible for closing itself on error + helper.Logger.Error("Error broadcasting to client:", "msg", err, "host", client.RealIP, "clients", len(Clients)) } } + <-tick.C } } diff --git a/internal/websocket/client_commands.go b/internal/websocket/client_commands.go index 7149faa..6f85108 100644 --- a/internal/websocket/client_commands.go +++ b/internal/websocket/client_commands.go @@ -21,6 +21,7 @@ func handleClientCommands(c models.WebsocketClient) { _, message, err := ws.ReadMessage() if err != nil { + // remove client on error/disconnect delete(Clients, ws.LocalAddr()) helper.Logger.Info("Client disconnected:", "msg", err, "host", c.RealIP, "clients", len(Clients)) break diff --git a/internal/websocket/handle_connection.go b/internal/websocket/handle_connection.go index 9c4b43f..ca55229 100644 --- a/internal/websocket/handle_connection.go +++ b/internal/websocket/handle_connection.go @@ -2,18 +2,13 @@ package websocket import ( "github.com/gorilla/websocket" - "net" "net/http" - "sync" + "time" "git.smsvc.net/pomodoro/GoTomato/internal/helper" "git.smsvc.net/pomodoro/GoTomato/pkg/models" ) -// Clients is a map of connected WebSocket clients, where each client is represented by the WebsocketClient struct -var Clients = make(map[net.Addr]*models.WebsocketClient) -var mu sync.Mutex // Mutex to protect access to the Clients map - // Upgrade HTTP requests to WebSocket connections var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, @@ -31,9 +26,15 @@ func HandleConnection(w http.ResponseWriter, r *http.Request) { // Register the new client client := models.WebsocketClient{ - Conn: ws, - RealIP: r.RemoteAddr, + Conn: ws, + LastPong: time.Now(), + RealIP: r.RemoteAddr, } + client.Conn.SetPongHandler(func(s string) error { + client.LastPong = time.Now() + return nil + }) + mu.Lock() Clients[ws.LocalAddr()] = &client mu.Unlock() diff --git a/internal/websocket/staleClients.go b/internal/websocket/staleClients.go new file mode 100644 index 0000000..b636c49 --- /dev/null +++ b/internal/websocket/staleClients.go @@ -0,0 +1,34 @@ +package websocket + +import ( + "time" + + "git.smsvc.net/pomodoro/GoTomato/internal/helper" + "git.smsvc.net/pomodoro/GoTomato/pkg/models" + "github.com/gorilla/websocket" +) + +// Check and remove stale clients +func RemoveStaleClients() { + ticker := time.NewTicker(STALE_CHECK_INTERVALL * time.Second) + defer ticker.Stop() + + for range ticker.C { + mu.Lock() + for _, client := range Clients { + client.Conn.SetWriteDeadline(time.Now().Add(SEND_TIMEOUT * time.Second)) + client.Conn.WriteMessage(websocket.PingMessage, nil) + + if isStale(client) { + helper.Logger.Info("Removing stale client", "host", client.RealIP, "lastPong", client.LastPong.Format(time.RFC3339)) + client.Conn.Close() + delete(Clients, client.Conn.LocalAddr()) + } + } + mu.Unlock() + } +} + +func isStale(client *models.WebsocketClient) bool { + return time.Since(client.LastPong) > (STALE_CLIENT_TIMEOUT * time.Second) +} diff --git a/internal/websocket/vars.go b/internal/websocket/vars.go new file mode 100644 index 0000000..5aedb53 --- /dev/null +++ b/internal/websocket/vars.go @@ -0,0 +1,18 @@ +package websocket + +import ( + "net" + "sync" + + "git.smsvc.net/pomodoro/GoTomato/pkg/models" +) + +const SEND_TIMEOUT = 10 +const STALE_CLIENT_TIMEOUT = 90 +const STALE_CHECK_INTERVALL = 30 + +// Clients is a map of connected WebSocket clients, where each client is represented by the WebsocketClient struct +var Clients = make(map[net.Addr]*models.WebsocketClient) + +// Mutex to protect access to the Clients map +var mu sync.Mutex diff --git a/pkg/models/client.go b/pkg/models/client.go index eef9663..84979ee 100644 --- a/pkg/models/client.go +++ b/pkg/models/client.go @@ -1,9 +1,9 @@ package models import ( - "github.com/gorilla/websocket" + "time" - "git.smsvc.net/pomodoro/GoTomato/internal/helper" + "github.com/gorilla/websocket" ) // Represents a command from the client (start/stop) @@ -15,17 +15,13 @@ type ClientCommand struct { // Represents a single client type WebsocketClient struct { - Conn *websocket.Conn - RealIP string + Conn *websocket.Conn + LastPong time.Time + RealIP string } // Sends a message to the websocket. -// Automatically locks and unlocks the client mutex, to ensure that only one goroutine can write at a time. func (c *WebsocketClient) SendMessage(messageType int, data []byte) error { - err := c.Conn.WriteMessage(messageType, data) - if err != nil { - helper.Logger.Error("Error writing to WebSocket:", "msg", err) - c.Conn.Close() // Close the connection on error - } - return err + c.Conn.SetWriteDeadline(time.Now().Add(TIMEOUT * time.Second)) + return c.Conn.WriteMessage(messageType, data) }