Skip to main content

trustless_protocol/
codec.rs

1/// Wrap an `AsyncRead` with length-delimited codec framing for reading protocol messages.
2pub fn framed_read<R: tokio::io::AsyncRead>(
3    reader: R,
4) -> tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec> {
5    tokio_util::codec::FramedRead::new(reader, tokio_util::codec::LengthDelimitedCodec::new())
6}
7
8/// Wrap an `AsyncWrite` with length-delimited codec framing for writing protocol messages.
9pub fn framed_write<W: tokio::io::AsyncWrite>(
10    writer: W,
11) -> tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec> {
12    tokio_util::codec::FramedWrite::new(writer, tokio_util::codec::LengthDelimitedCodec::new())
13}
14
15/// Serialize a message as JSON and send it over a framed writer.
16pub async fn send_message<W>(
17    writer: &mut tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec>,
18    msg: &impl serde::Serialize,
19) -> Result<(), crate::error::Error>
20where
21    W: tokio::io::AsyncWrite + Unpin,
22{
23    use futures_util::SinkExt as _;
24
25    let json = serde_json::to_vec(msg)?;
26    writer.send(bytes::Bytes::from(json)).await?;
27    Ok(())
28}
29
30/// Read and deserialize a JSON message from a framed reader.
31///
32/// Returns [`Error::ProcessExited`](crate::error::Error::ProcessExited) when the stream reaches EOF.
33pub async fn recv_message<R, M>(
34    reader: &mut tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec>,
35) -> Result<M, crate::error::Error>
36where
37    R: tokio::io::AsyncRead + Unpin,
38    M: serde::de::DeserializeOwned,
39{
40    use futures_util::StreamExt as _;
41
42    let frame = reader
43        .next()
44        .await
45        .ok_or(crate::error::Error::ProcessExited)??;
46    let msg = serde_json::from_slice(&frame)?;
47    Ok(msg)
48}
49
#[cfg(test)]
mod tests {
    use secrecy::ExposeSecret as _;

    use super::{framed_read, framed_write, recv_message, send_message};
    use crate::error::Error;
    use crate::message::{
        Base64Bytes, InitializeParams, InitializeResult, Request, Response, SignParams,
        SuccessResponse,
    };

    /// A request written on one side of a duplex pipe is decoded intact on the other side,
    /// and a response can then travel back in the opposite direction.
    #[tokio::test]
    async fn round_trip_message() {
        let (client, server) = tokio::io::duplex(4096);
        let (server_rx, server_tx) = tokio::io::split(server);
        let (client_rx, client_tx) = tokio::io::split(client);

        // Client -> server.
        let mut client_writer = framed_write(client_tx);
        let mut server_reader = framed_read(server_rx);

        let request = Request::Initialize {
            id: 1,
            params: InitializeParams {},
        };
        send_message(&mut client_writer, &request).await.unwrap();

        let got: Request = recv_message(&mut server_reader).await.unwrap();
        assert_eq!(got.id(), 1);
        assert!(matches!(got, Request::Initialize { .. }));

        // Server -> client.
        let mut server_writer = framed_write(server_tx);
        let mut client_reader = framed_read(client_rx);

        let response = Response::Success(SuccessResponse::Initialize {
            id: 1,
            result: InitializeResult {
                default: "cert1".to_owned(),
                certificates: vec![],
            },
        });
        send_message(&mut server_writer, &response).await.unwrap();

        let got: Response = recv_message(&mut client_reader).await.unwrap();
        assert_eq!(got.id(), 1);
        if let Response::Success(SuccessResponse::Initialize { result, .. }) = got {
            assert_eq!(result.default, "cert1");
        } else {
            panic!("expected Initialize Result");
        }
    }

    /// Reading after the peer end of the pipe has been dropped reports `ProcessExited`.
    #[tokio::test]
    async fn eof_returns_process_exited() {
        let (client, server) = tokio::io::duplex(4096);
        drop(client);

        let mut reader = framed_read(server);
        let outcome = recv_message::<_, Request>(&mut reader).await;
        assert!(matches!(outcome, Err(Error::ProcessExited)));
    }

    /// Several frames written back-to-back are read back in order with their payloads intact.
    #[tokio::test]
    async fn multiple_messages_in_sequence() {
        let (client, server) = tokio::io::duplex(4096);
        let (server_rx, _server_tx) = tokio::io::split(server);
        // The client's read half is never used; the underscore binding keeps it alive to the end.
        let (_client_rx, client_tx) = tokio::io::split(client);

        let mut writer = framed_write(client_tx);
        let mut reader = framed_read(server_rx);

        // Every write happens before any read, so this presumably relies on the five encoded
        // requests fitting in the 4096-byte duplex buffer.
        for id in 1..=5 {
            let request = Request::Sign {
                id,
                params: SignParams {
                    certificate_id: format!("cert{id}"),
                    scheme: "ECDSA_NISTP256_SHA256".to_owned(),
                    blob: Base64Bytes::from(vec![id as u8; 16]).into_secret(),
                },
            };
            send_message(&mut writer, &request).await.unwrap();
        }

        for expected_id in 1..=5 {
            let got: Request = recv_message(&mut reader).await.unwrap();
            assert_eq!(got.id(), expected_id);
            if let Request::Sign { params, .. } = &got {
                assert_eq!(params.certificate_id, format!("cert{expected_id}"));
                assert_eq!(
                    params.blob.expose_secret().as_slice(),
                    vec![expected_id as u8; 16].as_slice()
                );
            } else {
                panic!("expected Sign");
            }
        }
    }
}