diff --git a/internal/websocket/broadcast.go b/internal/websocket/broadcast.go index 2165787..aadfab6 100644 --- a/internal/websocket/broadcast.go +++ b/internal/websocket/broadcast.go @@ -22,11 +22,12 @@ func SendPermanentBroadCastMessage() { // Iterate over all connected clients and broadcast the message for _, client := range Clients { - if time.Now().Sub(client.LastPong) > (10 * time.Second) { - // Remove unresponsive client + // Remove unresponsive client + if client.IsStale() { client.Conn.Close() continue } + // Send message to client err := client.SendMessage(websocket.TextMessage, jsonMessage) if err != nil { helper.Logger.Info("Error broadcasting to client:", "msg", err, "host", client.RealIP, "clients", len(Clients)) diff --git a/internal/websocket/handle_connection.go b/internal/websocket/handle_connection.go index 2a0f414..360081e 100644 --- a/internal/websocket/handle_connection.go +++ b/internal/websocket/handle_connection.go @@ -36,6 +36,12 @@ func HandleConnection(w http.ResponseWriter, r *http.Request) { LastPong: time.Now(), RealIP: r.RemoteAddr, } + + client.Conn.SetPongHandler(func(appData string) error { + client.LastPong = time.Now() + return nil + }) + mu.Lock() Clients[ws.LocalAddr()] = &client mu.Unlock() diff --git a/pkg/models/client.go b/pkg/models/client.go index 00979c0..632ea34 100644 --- a/pkg/models/client.go +++ b/pkg/models/client.go @@ -24,12 +24,22 @@ type WebsocketClient struct { // Sends a message to the websocket. func (c *WebsocketClient) SendMessage(messageType int, data []byte) error { - c.Conn.SetPongHandler(func(s string) error { - c.LastPong = time.Now() - return nil - }) + c.Conn.SetWriteDeadline(time.Now().Add(TIMEOUT * time.Second)) + err := c.Conn.WriteMessage(websocket.PingMessage, nil) + if err != nil { + return err + } c.Conn.SetWriteDeadline(time.Now().Add(TIMEOUT * time.Second)) - c.Conn.WriteMessage(websocket.PingMessage, nil) - return c.Conn.WriteMessage(messageType, data) + err = c.Conn.WriteMessage(messageType, data) + if err != nil { + return err + } + + return nil +} + +// Check if websockets last Pong is recent +func (c *WebsocketClient) IsStale() bool { + return time.Since(c.LastPong) > (TIMEOUT * time.Second) }