1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
use crate::messages::{MessageData, MessageFromClient, MessageFromServer};
use actix::{Actor, ActorContext, AsyncContext, Handler, Recipient, SpawnHandle, StreamHandler};
use actix_web_actors::ws;
use jamsocket::ClientId;
use std::time::{Duration, Instant};

/// Represents a connection from a service to a client, which consists of a
/// message receiver, a client ID, and the state used by the heartbeat check.
pub struct ClientSocketConnection {
    /// Recipient that receives every [`MessageFromClient`] produced by this
    /// connection (forwarded messages and the final `Disconnect`).
    pub room: Recipient<MessageFromClient>,
    /// ID attached to every message sent to `room` on behalf of this client.
    pub client_id: ClientId,
    /// Time the most recent pong was received from the client.
    pub last_seen: Instant,
    /// Period at which the heartbeat check runs.
    pub heartbeat_interval: Duration,
    /// Maximum time allowed since `last_seen` before the heartbeat check
    /// closes the connection.
    pub heartbeat_timeout: Duration,
    /// Handle of the running heartbeat interval, set when the heartbeat is
    /// started and used to cancel it when the connection closes.
    pub interval_handle: Option<SpawnHandle>,
}

impl ClientSocketConnection {
    /// Starts the periodic heartbeat and stores its handle in
    /// `interval_handle`.
    ///
    /// Every `heartbeat_interval` the check runs: if more than
    /// `heartbeat_timeout` has elapsed since `last_seen`, the connection is
    /// closed; otherwise an empty ping is sent to the client.
    fn start_heartbeat_interval(&mut self, ctx: &mut <Self as Actor>::Context) {
        let handle = ctx.run_interval(self.heartbeat_interval, |act, ctx| {
            if act.last_seen.elapsed() > act.heartbeat_timeout {
                tracing::warn!(
                    client_id=?act.client_id,
                    "Stopping ClientSocketConnection because heartbeat not responded.",
                );
                act.close(ctx);
            } else {
                ctx.ping(b"");
            }
        });

        self.interval_handle = Some(handle);
    }

    /// Tears down the connection: cancels the heartbeat interval (if one was
    /// started), sends `Disconnect` for this client to `room`, and stops the
    /// actor context.
    ///
    /// A failure to deliver `Disconnect` is logged and does not prevent the
    /// context from being stopped.
    fn close(&self, ctx: &mut ws::WebsocketContext<Self>) {
        // `SpawnHandle` is `Copy`, so the handle can be read through `&self`.
        if let Some(handle) = self.interval_handle {
            ctx.cancel_future(handle);
        }

        let disconnect = MessageFromClient::Disconnect(self.client_id);
        if self.room.do_send(disconnect).is_err() {
            tracing::warn!("Could not send Disconnect message before closing room",);
        }

        ctx.stop();
    }
}

/// Runs `ClientSocketConnection` as an actor inside a websocket context.
impl Actor for ClientSocketConnection {
    type Context = ws::WebsocketContext<Self>;

    /// Called when the actor starts; kicks off the heartbeat interval.
    fn started(&mut self, ctx: &mut Self::Context) {
        Self::start_heartbeat_interval(self, ctx);
    }
}

/// Delivers messages coming from the server side to the client's websocket.
impl Handler<MessageFromServer> for ClientSocketConnection {
    type Result = ();

    /// Writes the message payload to the websocket: string data as a text
    /// frame, binary data as a binary frame.
    fn handle(&mut self, msg: MessageFromServer, ctx: &mut Self::Context) {
        let payload = msg.data;

        match payload {
            MessageData::Binary(bytes) => {
                ctx.binary(bytes);
            }
            MessageData::String(text) => {
                ctx.text(text);
            }
        }
    }
}

/// Handles frames arriving from the client's websocket.
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for ClientSocketConnection {
    /// Dispatches one inbound websocket item.
    ///
    /// * `Ping` is answered with a `Pong` carrying the same payload.
    /// * `Pong` refreshes `last_seen`, which the heartbeat check compares
    ///   against `heartbeat_timeout`.
    /// * `Text` / `Binary` are wrapped in [`MessageFromClient::Message`] and
    ///   sent to `room`; a failed send is logged and otherwise ignored.
    /// * `Close` is echoed back to the client to complete the closing
    ///   handshake, then the connection is torn down (which also sends
    ///   `Disconnect` to `room`).
    /// * Protocol errors and any other frame kinds are only logged.
    fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
        match msg {
            Ok(ws::Message::Ping(msg)) => ctx.pong(&msg),
            Ok(ws::Message::Pong(_)) => self.last_seen = Instant::now(),
            Ok(ws::Message::Text(text)) => {
                let message = MessageFromClient::Message {
                    from_client: self.client_id,
                    data: MessageData::String(text.to_string()),
                };
                if self.room.do_send(message).is_err() {
                    tracing::warn!("Error forwarding message to service",);
                }
            }
            Ok(ws::Message::Binary(data)) => {
                let message = MessageFromClient::Message {
                    from_client: self.client_id,
                    data: MessageData::Binary(data.to_vec()),
                };
                if self.room.do_send(message).is_err() {
                    tracing::warn!("Error forwarding binary message to service",);
                }
            }
            Ok(ws::Message::Close(reason)) => {
                tracing::info!(
                    client_id=?self.client_id,
                    "User has disconnected from room",
                );

                // Reply with a Close frame before stopping so the peer sees a
                // completed closing handshake rather than a dropped connection.
                ctx.close(reason);
                self.close(ctx);
            }
            // NOTE(review): a protocol error is only logged and the actor keeps
            // running; a dead connection is then closed by the heartbeat
            // timeout. Confirm this is intended rather than closing here.
            Err(error) => tracing::error!(?error, "Encountered error in StreamHandler"),
            _ => tracing::warn!(message=?msg, "Unhandled message in StreamHandler"),
        }
    }
}