import { Middleware, MiddlewareAPI } from "@reduxjs/toolkit";
import { WSMessage, WSMessageStatus } from "@veris-health/user-ms/lib/v1/";
import { getToken } from "../api/utils/localStorage";
import { AppDispatch } from "../store";
import {
  connectionEstablished,
  startConnecting,
  logout,
  setActiveConnections,
  terminateConnection,
  setSocketError,
} from "../features/shared/slices/authSlice";

/**
 * Redux middleware that owns the app's single WebSocket connection.
 *
 * - `startConnecting`: opens a socket to `VITE_SOCKET_LOCATION` followed by the
 *   stored `VERIS_ACCESS_TOKEN`. Any socket opened earlier is closed first.
 *   Socket events are mapped to auth-slice actions:
 *   open -> `connectionEstablished`, message -> `setActiveConnections`
 *   (plus `logout` when the server sends `WSMessageStatus.Logout`),
 *   close -> `terminateConnection`, error -> `setSocketError`.
 * - `logout`: tells the server about the logout (only when the socket is open),
 *   closes the socket and dispatches `terminateConnection`.
 *
 * Every action is forwarded to the next middleware, and its result is returned,
 * so `dispatch(...)` keeps its normal return value.
 */
export const socketMiddleware: Middleware = (store: MiddlewareAPI<AppDispatch>) => {
  // The currently active socket, or null when none is open or being opened.
  let socket: WebSocket | null = null;

  // Detach every handler before closing, so a discarded socket can no longer
  // dispatch actions (e.g. a late onclose -> terminateConnection()).
  const disposeSocket = (ws: WebSocket) => {
    ws.onopen = null;
    ws.onmessage = null;
    ws.onclose = null;
    ws.onerror = null;
    ws.close();
  };

  return (next) => (action) => {
    const { dispatch } = store;

    if (startConnecting.match(action)) {
      // Replace, don't leak, a previous connection.
      if (socket) {
        disposeSocket(socket);
        socket = null;
      }

      const token = getToken("VERIS_ACCESS_TOKEN");
      let ws: WebSocket;
      try {
        ws = new WebSocket(`${import.meta.env.VITE_SOCKET_LOCATION}${token}`);
      } catch {
        // The constructor throws synchronously on an invalid URL. Report it
        // instead of aborting the dispatch.
        dispatch(setSocketError("Websocket failed to connect."));
        return next(action);
      }
      socket = ws;

      ws.onopen = () => {
        dispatch(connectionEstablished());
      };

      ws.onmessage = (event: MessageEvent) => {
        let messageData: WSMessage;
        try {
          messageData = JSON.parse(event.data);
        } catch {
          // Ignore frames that are not valid JSON rather than throwing from the handler.
          return;
        }
        dispatch(setActiveConnections(Number(messageData.active_connections)));
        if (messageData.message === WSMessageStatus.Logout) {
          dispatch(logout());
        }
      };

      ws.onclose = () => {
        if (socket === ws) {
          socket = null;
        }
        dispatch(terminateConnection());
      };

      ws.onerror = () => {
        dispatch(setSocketError("Websocket failed to connect."));
      };
    }

    if (logout.match(action)) {
      if (socket) {
        // send() throws InvalidStateError while the socket is still CONNECTING.
        // That would abort this dispatch before logout reached the reducer.
        if (socket.readyState === WebSocket.OPEN) {
          socket.send(JSON.stringify({ message: WSMessageStatus.Logout }));
        }
        disposeSocket(socket);
        socket = null;
      }
      dispatch(terminateConnection());
    }

    return next(action);
  };
};
