// zero_trust_rps/common/client/connection/conn.rs
use std::{num::NonZeroU64, ops::Div, time::Duration};
use futures::future::join3;
use tokio::{
sync::mpsc::channel,
time::{sleep, sleep_until, Instant},
};
use crate::{
common::{
client::{
channel::{AsyncChannelReceiver, AsyncChannelSender, SendError},
do_moves::{do_move, MoveError},
state::{ClientState, ClientStateView},
update::handle_new_room_state,
},
connection::{Reader, WriteMessageError, Writer},
message::ClientMessage,
rps::simple_move::SimpleUserMove,
},
log_result,
};
use super::{
read::{read_server_messages, ReadMessagesError, SimplifiedServerMessage},
write::send_messages,
};
/// Error type returned by [`run_client`].
///
/// Each variant wraps one underlying error and displays exactly as that inner error does;
/// the `#[from]` attributes let `?` convert the inner errors automatically.
#[derive(thiserror::Error, Debug)]
// Every variant name ends in `Error`, which is what this clippy lint complains about.
#[allow(clippy::enum_variant_names)]
pub enum RunClientError {
/// Wraps a [`WriteMessageError`].
#[error("{}", .0)]
WriteError(#[from] WriteMessageError),
/// Wraps a [`ReadMessagesError`].
#[error("{}", .0)]
ReadError(#[from] ReadMessagesError),
/// Wraps a [`ManageStateError`].
#[error("{}", .0)]
ManageStateError(#[from] ManageStateError),
/// Wraps the `JoinError` obtained when awaiting one of the spawned Tokio tasks.
#[error("{}", .0)]
TokioError(#[from] tokio::task::JoinError),
}
/// Runs one client session and resolves once all of its background tasks have finished.
///
/// Two internal channels of capacity 1 are created (one for server messages, one for
/// client messages) and three tasks are spawned on the Tokio runtime:
///
/// * `manage_state`, which is given `timout`, `states`, `repeat`, `umoves`, the receiving
///   end of the server-message channel and the sending end of the client-message channel;
/// * `send_messages`, which is given `writer` and the receiving end of the client-message
///   channel;
/// * `read_server_messages`, which is given `reader` and the sending end of the
///   server-message channel.
///
/// # Errors
///
/// All three tasks are awaited before any result is inspected. The results are then
/// checked in the order state task, write task, read task, and the first failure is
/// returned. A task that could not be joined is reported as
/// [`RunClientError::TokioError`].
#[inline]
pub async fn run_client(
    timout: Option<NonZeroU64>,
    writer: impl Writer + 'static,
    reader: impl Reader + 'static,
    repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
    states: impl AsyncChannelSender<ClientStateView> + 'static,
    umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
) -> Result<(), RunClientError> {
    let (smsg_send, smsg_recv) = channel::<SimplifiedServerMessage>(1);
    let (cmsg_send, cmsg_recv) = channel::<ClientMessage>(1);

    // Spawn in a fixed order: state manager, writer, reader.
    let state_task = tokio::spawn(log_result!(manage_state(
        timout, smsg_recv, states, repeat, umoves, cmsg_send
    )));
    let write_task = tokio::spawn(log_result!(send_messages(writer, cmsg_recv)));
    let read_task = tokio::spawn(log_result!(read_server_messages(reader, smsg_send)));

    let (state_outcome, write_outcome, read_outcome) =
        join3(state_task, write_task, read_task).await;

    // First `?` surfaces a join failure, second `?` surfaces the task's own error.
    let _: () = state_outcome??;
    let _: () = write_outcome??;
    let _: () = read_outcome??;
    Ok(())
}
/// Error type returned by `manage_state`.
#[derive(thiserror::Error, Debug)]
// Every variant name ends in `Error`, which is what this clippy lint complains about.
#[allow(clippy::enum_variant_names)]
pub enum ManageStateError {
/// Wraps a [`SendError`]; displayed with a "Failed to send to channel: " prefix.
#[error("Failed to send to channel: {}", .0)]
SendError(#[from] SendError),
/// Wraps a [`MoveError`]; displays as the inner error.
#[error("{}", .0)]
MoveError(#[from] MoveError),
/// Carries a plain-text error message; displays as that message.
/// Has no `#[from]`: the conversion from `String` is a hand-written `From` impl.
#[error("{}", .0)]
UpdateError(String),
}
impl From<String> for ManageStateError {
fn from(value: String) -> Self {
ManageStateError::UpdateError(value)
}
}
async fn manage_state(
timout: Option<NonZeroU64>,
smesgs: impl AsyncChannelReceiver<SimplifiedServerMessage> + 'static,
output: impl AsyncChannelSender<ClientStateView> + 'static,
repeat: impl AsyncChannelSender<SimplifiedServerMessage> + 'static + Clone,
umoves: impl AsyncChannelReceiver<SimpleUserMove> + 'static,
sender: impl AsyncChannelSender<ClientMessage> + 'static + Clone,
) -> Result<(), ManageStateError> {
let mut smesgs = smesgs;
let mut umoves = umoves;
let timeout = timout.map(NonZeroU64::get).map(Duration::from_secs);
let mut expected_pong: Option<u8> = None;
let mut state: ClientState = Default::default();
output.send(Box::new(state.clone().into())).await?; loop {
let (ping_in_fut, timeout_fut) = if let Some(timeout) = timeout {
if state.timed_out {
(sleep(Duration::MAX), sleep(Duration::MAX))
} else {
let timing_out_at = state
.last_server_message
.checked_add(timeout)
.ok_or_else(|| format!("{timeout:?} is too long"))?;
if expected_pong.is_some() {
(sleep(Duration::MAX), sleep_until(timing_out_at))
} else {
let send_ping_at = state
.last_server_message
.checked_add(timeout.div(2))
.ok_or_else(|| format!("{timeout:?} is too long"))?;
(sleep_until(send_ping_at), sleep_until(timing_out_at))
}
}
} else {
(sleep(Duration::MAX), sleep(Duration::MAX))
};
let _: () = tokio::select! {
_ = ping_in_fut => {
let c: u8 = rand::random();
sender.send(ClientMessage::Ping { c }).await?;
expected_pong = Some(c);
},
_ = timeout_fut => {
log::warn!("timed out");
state.timed_out = true;
output.send(Box::new(state.clone().into())).await?;
},
Some(mesg) = smesgs.receive() => {
log::trace!("manage_state got {mesg:?}");
update_state(&mut state, &mut expected_pong, &mesg)?;
log::trace!("updated state");
output.send(Box::new(state.clone().into())).await?;
log::trace!("send state");
repeat.send(mesg).await?;
log::trace!("relayed server msg")
},
Some(umove) = umoves.receive() => {
log::trace!("User wants to move: {umove:?}");
do_move(&mut state, umove, sender.clone(), repeat.clone()).await?;
},
else => {
log::warn!("Returning from manage_state");
return Ok(());
}
};
}
}
/// Applies one server message to the client state.
///
/// Regardless of the message kind, the time of the last server message is set to "now"
/// and the timed-out flag is cleared. After that:
///
/// * `NewRoomState` is delegated to `handle_new_room_state`, whose result is returned;
/// * `Pong` consumes `expected_pong` (it is `None` afterwards in every case) and succeeds
///   only if a pong was expected and its value matches;
/// * `Error` is logged and stored in `state.last_error`.
///
/// # Errors
///
/// Returns a message string when a pong arrives that was not expected or carries the
/// wrong value, or when `handle_new_room_state` reports an error.
fn update_state(
    state: &mut ClientState,
    expected_pong: &mut Option<u8>,
    server_msg: &SimplifiedServerMessage,
) -> Result<(), String> {
    // Any message from the server counts as a sign of life.
    state.timed_out = false;
    state.last_server_message = Instant::now();

    match server_msg {
        SimplifiedServerMessage::Error(error) => {
            log::error!("got error: {error:?}");
            state.last_error = Some(error.into());
            Ok(())
        }
        SimplifiedServerMessage::Pong(pong) => match expected_pong.take() {
            None => Err(String::from("Didn't expect pong")),
            Some(expected) if expected != *pong => {
                Err(format!("Got invalid pong {pong}, but expected {expected}"))
            }
            Some(_) => Ok(()),
        },
        SimplifiedServerMessage::NewRoomState(room_state) => {
            handle_new_room_state(state, room_state)
        }
    }
}