zero_trust_rps/common/client/connection/
conn.rs1use std::{num::NonZeroU64, ops::Div, time::Duration};
2
3use futures::future::join3;
4use tokio::{
5 sync::mpsc::channel,
6 time::{sleep, sleep_until, Instant},
7};
8
9use crate::{
10 common::{
11 client::{
12 channel::{AsyncChannelReceiver, AsyncChannelSender, SendError},
13 do_moves::{do_move, MoveError},
14 state::{ClientState, ClientStateView},
15 update::handle_new_room_state,
16 },
17 connection::{Reader, WriteMessageError, Writer},
18 message::ClientMessage,
19 rps::simple_move::SimpleUserMove,
20 },
21 log_result,
22};
23
24use super::{
25 read::{read_server_messages, ReadMessagesError, SimplifiedServerMessage},
26 write::send_messages,
27};
28
29#[derive(thiserror::Error, Debug)]
30#[allow(clippy::enum_variant_names)]
31pub enum RunClientError {
32 #[error("{}", .0)]
33 WriteError(#[from] WriteMessageError),
34 #[error("{}", .0)]
35 ReadError(#[from] ReadMessagesError),
36 #[error("{}", .0)]
37 ManageStateError(#[from] ManageStateError),
38 #[error("{}", .0)]
39 TokioError(#[from] tokio::task::JoinError),
40}
41
42#[inline]
43pub async fn run_client(
44 timout: Option<NonZeroU64>,
45 writer: impl Writer + 'static,
46 reader: impl Reader + 'static,
47 repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
48 states: impl AsyncChannelSender<ClientStateView> + 'static,
49 umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
50) -> Result<(), RunClientError> {
51 let (smsg_send, smsg_recv) = channel::<SimplifiedServerMessage>(1);
52 let (cmsg_send, cmsg_recv) = channel::<ClientMessage>(1);
53
54 let (a, b, c) = join3(
55 tokio::spawn(log_result!(manage_state(
56 timout, smsg_recv, states, repeat, umoves, cmsg_send
57 ))),
58 tokio::spawn(log_result!(send_messages(writer, cmsg_recv))),
59 tokio::spawn(log_result!(read_server_messages(reader, smsg_send))),
60 )
61 .await;
62
63 let _: () = a??;
64 let _: () = b??;
65 let _: () = c??;
66
67 Ok(())
68}
69
70#[derive(thiserror::Error, Debug)]
71#[allow(clippy::enum_variant_names)]
72pub enum ManageStateError {
73 #[error("Failed to send to channel: {}", .0)]
74 SendError(#[from] SendError),
75 #[error("{}", .0)]
76 MoveError(#[from] MoveError),
77 #[error("{}", .0)]
78 UpdateError(String),
79}
80
81impl From<String> for ManageStateError {
82 fn from(value: String) -> Self {
83 ManageStateError::UpdateError(value)
84 }
85}
86
87async fn manage_state(
88 timout: Option<NonZeroU64>,
89 smesgs: impl AsyncChannelReceiver<SimplifiedServerMessage> + 'static,
90 output: impl AsyncChannelSender<ClientStateView> + 'static,
91 repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
92 umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
93 sender: impl AsyncChannelSender<ClientMessage> + 'static + Clone,
94) -> Result<(), ManageStateError> {
95 let mut smesgs = smesgs;
96 let mut umoves = umoves;
97
98 let timeout = timout.map(NonZeroU64::get).map(Duration::from_secs);
99
100 let mut expected_pong: Option<u8> = None;
101 let mut state: ClientState = Default::default();
102
103 output.send(Box::new(state.clone().into())).await?; loop {
106 let (ping_in_fut, timeout_fut) = if let Some(timeout) = timeout {
107 if state.timed_out {
108 (sleep(Duration::MAX), sleep(Duration::MAX))
109 } else {
110 let timing_out_at = state
111 .last_server_message
112 .checked_add(timeout)
113 .ok_or_else(|| format!("{timeout:?} is too long"))?;
114 if expected_pong.is_some() {
115 (sleep(Duration::MAX), sleep_until(timing_out_at))
116 } else {
117 let send_ping_at = state
118 .last_server_message
119 .checked_add(timeout.div(2))
120 .ok_or_else(|| format!("{timeout:?} is too long"))?;
121 (sleep_until(send_ping_at), sleep_until(timing_out_at))
122 }
123 }
124 } else {
125 (sleep(Duration::MAX), sleep(Duration::MAX))
126 };
127 let _: () = tokio::select! {
128 _ = ping_in_fut => {
129 let c: u8 = rand::random();
130 sender.send(ClientMessage::Ping { c }).await?;
131 expected_pong = Some(c);
132 },
133 _ = timeout_fut => {
134 log::warn!("timed out");
135 state.timed_out = true;
136 output.send(Box::new(state.clone().into())).await?;
137 },
138 Some(mesg) = smesgs.receive() => {
139 log::trace!("manage_state got {mesg:?}");
140 update_state(&mut state, &mut expected_pong, &mesg)?;
141 log::trace!("updated state");
142
143 output.send(Box::new(state.clone().into())).await?;
144 log::trace!("send state");
145 repeat.send(mesg).await?;
146 log::trace!("relayed server msg")
147 },
148 Some(umove) = umoves.receive() => {
149 log::trace!("User wants to move: {umove:?}");
150
151 do_move(&mut state, umove, sender.clone(), repeat.clone()).await?;
152 },
153 else => {
154 log::warn!("Returning from manage_state");
155 return Ok(());
156 }
157 };
158 }
159}
160
161fn update_state(
162 state: &mut ClientState,
163 expected_pong: &mut Option<u8>,
164 server_msg: &SimplifiedServerMessage,
165) -> Result<(), String> {
166 state.last_server_message = Instant::now();
167 state.timed_out = false;
168 match server_msg {
169 SimplifiedServerMessage::NewRoomState(room_state) => {
170 handle_new_room_state(state, room_state)
171 }
172 SimplifiedServerMessage::Pong(pong) => match expected_pong.take() {
173 Some(p) if p == *pong => Ok(()),
174 Some(expected) => Err(format!("Got invalid pong {pong}, but expected {expected}")),
175 None => Err("Didn't expect pong".into()),
176 },
177 SimplifiedServerMessage::Error(error) => {
178 log::error!("got error: {error:?}");
179 state.last_error = Some(error.into());
180 Ok(())
181 }
182 }
183}