zero_trust_rps/common/client/connection/conn.rs

1use std::{num::NonZeroU64, ops::Div, time::Duration};
2
3use futures::future::join3;
4use tokio::{
5    sync::mpsc::channel,
6    time::{sleep, sleep_until, Instant},
7};
8
9use crate::{
10    common::{
11        client::{
12            channel::{AsyncChannelReceiver, AsyncChannelSender, SendError},
13            do_moves::{do_move, MoveError},
14            state::{ClientState, ClientStateView},
15            update::handle_new_room_state,
16        },
17        connection::{Reader, WriteMessageError, Writer},
18        message::ClientMessage,
19        rps::simple_move::SimpleUserMove,
20    },
21    log_result,
22};
23
24use super::{
25    read::{read_server_messages, ReadMessagesError, SimplifiedServerMessage},
26    write::send_messages,
27};
28
29#[derive(thiserror::Error, Debug)]
30#[allow(clippy::enum_variant_names)]
31pub enum RunClientError {
32    #[error("{}", .0)]
33    WriteError(#[from] WriteMessageError),
34    #[error("{}", .0)]
35    ReadError(#[from] ReadMessagesError),
36    #[error("{}", .0)]
37    ManageStateError(#[from] ManageStateError),
38    #[error("{}", .0)]
39    TokioError(#[from] tokio::task::JoinError),
40}
41
42#[inline]
43pub async fn run_client(
44    timout: Option<NonZeroU64>,
45    writer: impl Writer + 'static,
46    reader: impl Reader + 'static,
47    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
48    states: impl AsyncChannelSender<ClientStateView> + 'static,
49    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
50) -> Result<(), RunClientError> {
51    let (smsg_send, smsg_recv) = channel::<SimplifiedServerMessage>(1);
52    let (cmsg_send, cmsg_recv) = channel::<ClientMessage>(1);
53
54    let (a, b, c) = join3(
55        tokio::spawn(log_result!(manage_state(
56            timout, smsg_recv, states, repeat, umoves, cmsg_send
57        ))),
58        tokio::spawn(log_result!(send_messages(writer, cmsg_recv))),
59        tokio::spawn(log_result!(read_server_messages(reader, smsg_send))),
60    )
61    .await;
62
63    let _: () = a??;
64    let _: () = b??;
65    let _: () = c??;
66
67    Ok(())
68}
69
70#[derive(thiserror::Error, Debug)]
71#[allow(clippy::enum_variant_names)]
72pub enum ManageStateError {
73    #[error("Failed to send to channel: {}", .0)]
74    SendError(#[from] SendError),
75    #[error("{}", .0)]
76    MoveError(#[from] MoveError),
77    #[error("{}", .0)]
78    UpdateError(String),
79}
80
81impl From<String> for ManageStateError {
82    fn from(value: String) -> Self {
83        ManageStateError::UpdateError(value)
84    }
85}
86
87async fn manage_state(
88    timout: Option<NonZeroU64>,
89    smesgs: impl AsyncChannelReceiver<SimplifiedServerMessage> + 'static,
90    output: impl AsyncChannelSender<ClientStateView> + 'static,
91    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
92    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
93    sender: impl AsyncChannelSender<ClientMessage> + 'static + Clone,
94) -> Result<(), ManageStateError> {
95    let mut smesgs = smesgs;
96    let mut umoves = umoves;
97
98    let timeout = timout.map(NonZeroU64::get).map(Duration::from_secs);
99
100    let mut expected_pong: Option<u8> = None;
101    let mut state: ClientState = Default::default();
102
103    output.send(Box::new(state.clone().into())).await?; // send state asap!
104
105    loop {
106        let (ping_in_fut, timeout_fut) = if let Some(timeout) = timeout {
107            if state.timed_out {
108                (sleep(Duration::MAX), sleep(Duration::MAX))
109            } else {
110                let timing_out_at = state
111                    .last_server_message
112                    .checked_add(timeout)
113                    .ok_or_else(|| format!("{timeout:?} is too long"))?;
114                if expected_pong.is_some() {
115                    (sleep(Duration::MAX), sleep_until(timing_out_at))
116                } else {
117                    let send_ping_at = state
118                        .last_server_message
119                        .checked_add(timeout.div(2))
120                        .ok_or_else(|| format!("{timeout:?} is too long"))?;
121                    (sleep_until(send_ping_at), sleep_until(timing_out_at))
122                }
123            }
124        } else {
125            (sleep(Duration::MAX), sleep(Duration::MAX))
126        };
127        let _: () = tokio::select! {
128            _ = ping_in_fut => {
129                let c: u8 = rand::random();
130                sender.send(ClientMessage::Ping { c }).await?;
131                expected_pong = Some(c);
132            },
133            _ = timeout_fut => {
134                log::warn!("timed out");
135                state.timed_out = true;
136                output.send(Box::new(state.clone().into())).await?;
137            },
138            Some(mesg) = smesgs.receive() => {
139                log::trace!("manage_state got {mesg:?}");
140                update_state(&mut state, &mut expected_pong, &mesg)?;
141                log::trace!("updated state");
142
143                output.send(Box::new(state.clone().into())).await?;
144                log::trace!("send state");
145                repeat.send(mesg).await?;
146                log::trace!("relayed server msg")
147            },
148            Some(umove) = umoves.receive() => {
149                log::trace!("User wants to move: {umove:?}");
150
151                do_move(&mut state, umove, sender.clone(), repeat.clone()).await?;
152            },
153            else => {
154                log::warn!("Returning from manage_state");
155                return Ok(());
156            }
157        };
158    }
159}
160
161fn update_state(
162    state: &mut ClientState,
163    expected_pong: &mut Option<u8>,
164    server_msg: &SimplifiedServerMessage,
165) -> Result<(), String> {
166    state.last_server_message = Instant::now();
167    state.timed_out = false;
168    match server_msg {
169        SimplifiedServerMessage::NewRoomState(room_state) => {
170            handle_new_room_state(state, room_state)
171        }
172        SimplifiedServerMessage::Pong(pong) => match expected_pong.take() {
173            Some(p) if p == *pong => Ok(()),
174            Some(expected) => Err(format!("Got invalid pong {pong}, but expected {expected}")),
175            None => Err("Didn't expect pong".into()),
176        },
177        SimplifiedServerMessage::Error(error) => {
178            log::error!("got error: {error:?}");
179            state.last_error = Some(error.into());
180            Ok(())
181        }
182    }
183}