zero_trust_rps/common/client/connection/
conn.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use std::{num::NonZeroU64, ops::Div, time::Duration};

use futures::future::join3;
use tokio::{
    sync::mpsc::channel,
    time::{sleep, sleep_until, Instant},
};

use crate::{
    common::{
        client::{
            channel::{AsyncChannelReceiver, AsyncChannelSender, SendError},
            do_moves::{do_move, MoveError},
            state::{ClientState, ClientStateView},
            update::handle_new_room_state,
        },
        connection::{Reader, WriteMessageError, Writer},
        message::ClientMessage,
        rps::simple_move::SimpleUserMove,
    },
    log_result,
};

use super::{
    read::{read_server_messages, ReadMessagesError, SimplifiedServerMessage},
    write::send_messages,
};

#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
pub enum RunClientError {
    #[error("{}", .0)]
    WriteError(#[from] WriteMessageError),
    #[error("{}", .0)]
    ReadError(#[from] ReadMessagesError),
    #[error("{}", .0)]
    ManageStateError(#[from] ManageStateError),
    #[error("{}", .0)]
    TokioError(#[from] tokio::task::JoinError),
}

/// Runs one client session and returns once all of its worker tasks have finished.
///
/// Three tasks are spawned and awaited together:
/// - `manage_state` owns the client state, publishes views of it on `states`,
///   relays server messages to `repeat` and hands user moves from `umoves` to
///   `do_move`;
/// - `send_messages` writes the client messages produced by `manage_state` to
///   `writer`;
/// - `read_server_messages` reads from `reader` and feeds `manage_state`.
///
/// `timout`, when set, is the ping/timeout interval in seconds and is forwarded
/// to `manage_state`.
///
/// # Errors
/// After all three tasks have ended, the results are checked in the order
/// state manager, writer, reader, and the first failure is returned. This
/// includes a task that could not be joined ([`RunClientError::TokioError`]).
#[inline]
pub async fn run_client(
    timout: Option<NonZeroU64>,
    writer: impl Writer + 'static,
    reader: impl Reader + 'static,
    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
    states: impl AsyncChannelSender<ClientStateView> + 'static,
    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
) -> Result<(), RunClientError> {
    // reader task -> state manager
    let (server_tx, server_rx) = channel::<SimplifiedServerMessage>(1);
    // state manager -> writer task
    let (client_tx, client_rx) = channel::<ClientMessage>(1);

    let state_task = tokio::spawn(log_result!(manage_state(
        timout, server_rx, states, repeat, umoves, client_tx
    )));
    let write_task = tokio::spawn(log_result!(send_messages(writer, client_rx)));
    let read_task = tokio::spawn(log_result!(read_server_messages(reader, server_tx)));

    let (state_done, write_done, read_done) = join3(state_task, write_task, read_task).await;

    // Outer `?`: joining the task failed. Inner `?`: the task itself failed.
    let _: () = state_done??;
    let _: () = write_done??;
    let _: () = read_done??;

    Ok(())
}

#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
pub enum ManageStateError {
    #[error("Failed to send to channel: {}", .0)]
    SendError(#[from] SendError),
    #[error("{}", .0)]
    MoveError(#[from] MoveError),
    #[error("{}", .0)]
    UpdateError(String),
}

impl From<String> for ManageStateError {
    fn from(value: String) -> Self {
        ManageStateError::UpdateError(value)
    }
}

async fn manage_state(
    timout: Option<NonZeroU64>,
    smesgs: impl AsyncChannelReceiver<SimplifiedServerMessage> + 'static,
    output: impl AsyncChannelSender<ClientStateView> + 'static,
    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
    sender: impl AsyncChannelSender<ClientMessage> + 'static + Clone,
) -> Result<(), ManageStateError> {
    let mut smesgs = smesgs;
    let mut umoves = umoves;

    let timeout = timout.map(NonZeroU64::get).map(Duration::from_secs);

    let mut expected_pong: Option<u8> = None;
    let mut state: ClientState = Default::default();

    output.send(Box::new(state.clone().into())).await?; // send state asap!

    loop {
        let (ping_in_fut, timeout_fut) = if let Some(timeout) = timeout {
            if state.timed_out {
                (sleep(Duration::MAX), sleep(Duration::MAX))
            } else {
                let timing_out_at = state
                    .last_server_message
                    .checked_add(timeout)
                    .ok_or_else(|| format!("{timeout:?} is too long"))?;
                if expected_pong.is_some() {
                    (sleep(Duration::MAX), sleep_until(timing_out_at))
                } else {
                    let send_ping_at = state
                        .last_server_message
                        .checked_add(timeout.div(2))
                        .ok_or_else(|| format!("{timeout:?} is too long"))?;
                    (sleep_until(send_ping_at), sleep_until(timing_out_at))
                }
            }
        } else {
            (sleep(Duration::MAX), sleep(Duration::MAX))
        };
        let _: () = tokio::select! {
            _ = ping_in_fut => {
                let c: u8 = rand::random();
                sender.send(ClientMessage::Ping { c }).await?;
                expected_pong = Some(c);
            },
            _ = timeout_fut => {
                log::warn!("timed out");
                state.timed_out = true;
                output.send(Box::new(state.clone().into())).await?;
            },
            Some(mesg) = smesgs.receive() => {
                log::trace!("manage_state got {mesg:?}");
                update_state(&mut state, &mut expected_pong, &mesg)?;
                log::trace!("updated state");

                output.send(Box::new(state.clone().into())).await?;
                log::trace!("send state");
                repeat.send(mesg).await?;
                log::trace!("relayed server msg")
            },
            Some(umove) = umoves.receive() => {
                log::trace!("User wants to move: {umove:?}");

                do_move(&mut state, umove, sender.clone(), repeat.clone()).await?;
            },
            else => {
                log::warn!("Returning from manage_state");
                return Ok(());
            }
        };
    }
}

/// Applies a single server message to `state`.
///
/// Any message, whatever its kind, refreshes `state.last_server_message` and
/// clears `state.timed_out`. Each kind is then handled as follows:
/// - `NewRoomState` is delegated to `handle_new_room_state`.
/// - `Pong` must echo the value stored in `expected_pong`, which is cleared in
///   every case. An unexpected or mismatched pong is an error.
/// - `Error` is logged and stored in `state.last_error`.
///
/// # Errors
/// Returns a human-readable message for a bad pong, or whatever
/// `handle_new_room_state` reports.
fn update_state(
    state: &mut ClientState,
    expected_pong: &mut Option<u8>,
    server_msg: &SimplifiedServerMessage,
) -> Result<(), String> {
    // Hearing anything at all from the server proves the connection is alive.
    state.last_server_message = Instant::now();
    state.timed_out = false;

    match server_msg {
        SimplifiedServerMessage::NewRoomState(room_state) => {
            handle_new_room_state(state, room_state)
        }
        SimplifiedServerMessage::Pong(pong) => {
            let pending = expected_pong.take();
            match pending {
                None => Err("Didn't expect pong".into()),
                Some(expected) if expected != *pong => {
                    Err(format!("Got invalid pong {pong}, but expected {expected}"))
                }
                Some(_) => Ok(()),
            }
        }
        SimplifiedServerMessage::Error(error) => {
            log::error!("got error: {error:?}");
            state.last_error = Some(error.into());
            Ok(())
        }
    }
}