Skip to main content

spvirit_client/
transport.rs

1use std::time::Duration;
2
3use tokio::io::AsyncReadExt;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6
7use crate::types::PvGetError;
8use spvirit_codec::epics_decode::{PvaHeader, PvaPacket, PvaPacketCommand};
9
10pub async fn read_packet(
11    stream: &mut TcpStream,
12    timeout_dur: Duration,
13) -> Result<Vec<u8>, PvGetError> {
14    let mut header = [0u8; 8];
15    timeout(timeout_dur, stream.read_exact(&mut header))
16        .await
17        .map_err(|_| PvGetError::Timeout("read header"))??;
18
19    let header_parsed = PvaHeader::new(&header);
20    let payload_len = if header_parsed.flags.is_control {
21        0usize
22    } else {
23        header_parsed.payload_length as usize
24    };
25
26    let mut payload = vec![0u8; payload_len];
27    if payload_len > 0 {
28        timeout(timeout_dur, stream.read_exact(&mut payload))
29            .await
30            .map_err(|_| PvGetError::Timeout("read payload"))??;
31    }
32
33    let mut full = header.to_vec();
34    full.extend_from_slice(&payload);
35    Ok(full)
36}
37
38pub async fn read_until<F>(
39    stream: &mut TcpStream,
40    timeout_dur: Duration,
41    mut predicate: F,
42) -> Result<Vec<u8>, PvGetError>
43where
44    F: FnMut(&PvaPacketCommand) -> bool,
45{
46    let deadline = tokio::time::Instant::now() + timeout_dur;
47    loop {
48        let now = tokio::time::Instant::now();
49        if now >= deadline {
50            return Err(PvGetError::Timeout("read_until"));
51        }
52        let remaining = deadline - now;
53        let bytes = read_packet(stream, remaining).await?;
54        let mut pkt = PvaPacket::new(&bytes);
55        if let Some(cmd) = pkt.decode_payload() {
56            if predicate(&cmd) {
57                return Ok(bytes);
58            }
59        }
60    }
61}