1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license OR Apache 2.0

/// Identifies which part of a length-prefixed frame a reader was still waiting
/// for when the underlying stream failed or ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NextExpected {
  /// The `u32` length prefix that starts every frame.
  LengthSpecifier,
  /// The frame payload; `length` is the byte count announced by the prefix.
  Content { length: usize },
}

#[derive(thiserror::Error, Debug)]
pub enum ReadError {
  #[error("Frame length exceeded expectation of {expected} bytes with {received}")]
  MaxLengthExceeded { expected: usize, received: usize },
  #[error("Unexpected end of frame; expected {expected:?}")]
  UnexpectedEnd {
    expected: NextExpected,
    error: ::std::io::Error,
  },
}

#[derive(thiserror::Error, Debug)]
pub enum JsonReadError {
  #[error("Failure reading JSON from frame: {0}")]
  Read(#[from] ReadError),
  #[error("Failure deserializing JSON from frame: {0}")]
  Deserialization(#[from] ::serde_json::Error),
}

#[derive(thiserror::Error, Debug)]
pub enum WriteError {
  #[error("Frame write failure: {0:?}")]
  UnexpectedEnd(#[from] ::std::io::Error),
}

#[derive(thiserror::Error, Debug)]
pub enum JsonWriteError {
  #[error("Failure writing JSON into frame: {0}")]
  Write(#[from] WriteError),
  #[error("Failure serializing JSON for frame: {0}")]
  Serialization(#[from] ::serde_json::Error),
  /// Since the output is generated automatically, we return before
  /// risking corruption of the stream, skipping any write actions.
  ///
  /// Will never occur when a maximum length of `None` is provided.
  #[error("Frame length exceeded expectation of {expected} bytes with {produced}")]
  MaxLengthExceeded { expected: usize, produced: usize },
}

/// Reads one length-prefixed frame from `s` and returns its payload.
///
/// A frame is a `u32` length prefix (read with `AsyncReadExt::read_u32`, i.e.
/// big-endian) followed by exactly that many payload bytes.
///
/// `max_length`, when `Some`, bounds the *payload* length only; the 4-byte prefix
/// is not counted. Note that `write_framed_json` counts the prefix in its limit.
/// The bound is checked before the payload buffer is allocated, so an oversized
/// announcement is rejected without allocating. With `None`, up to `u32::MAX`
/// bytes may be allocated.
///
/// # Errors
/// - [`ReadError::UnexpectedEnd`] with [`NextExpected::LengthSpecifier`] if the
///   prefix cannot be read in full.
/// - [`ReadError::MaxLengthExceeded`] if the announced length exceeds `max_length`.
/// - [`ReadError::UnexpectedEnd`] with [`NextExpected::Content`] if the stream
///   fails or ends before the whole payload has arrived.
pub async fn read_frame<T: tokio::io::AsyncRead + Unpin>(
  mut s: T,
  max_length: Option<usize>,
) -> Result<Vec<u8>, ReadError> {
  use tokio::io::AsyncReadExt;
  let length = s
    .read_u32()
    .await
    .map_err(|error| ReadError::UnexpectedEnd {
      expected: NextExpected::LengthSpecifier,
      error,
    })? as usize;
  if let Some(max_length) = max_length {
    if length > max_length {
      return Err(ReadError::MaxLengthExceeded {
        expected: max_length,
        received: length,
      });
    }
  }
  // A zero-filled buffer of exactly `length` bytes for `read_exact` to fill;
  // `vec![0; n]` allocates it in one step instead of reserving and then filling.
  let mut buffer = vec![0u8; length];
  s.read_exact(&mut buffer)
    .await
    .map_err(|error| ReadError::UnexpectedEnd {
      expected: NextExpected::Content { length },
      error,
    })?;
  Ok(buffer)
}

/// Writes `buffer` to `s` as a single length-prefixed frame.
///
/// The frame is the payload length as a `u32` (written with
/// `AsyncWriteExt::write_u32`, i.e. big-endian) followed by the payload bytes.
/// This function does not flush `s`. Callers using a buffered writer must flush
/// it themselves before the frame is guaranteed to leave the buffer.
///
/// # Errors
/// Returns [`WriteError::UnexpectedEnd`] if either write fails.
///
/// It also returns that variant, wrapping an `io::ErrorKind::InvalidInput` error,
/// if `buffer` is longer than `u32::MAX` bytes. This case is detected before
/// anything is written, because the prefix cannot represent such a length.
pub async fn write_frame<T: tokio::io::AsyncWrite + Unpin>(
  mut s: T,
  buffer: &[u8],
) -> Result<(), WriteError> {
  use tokio::io::AsyncWriteExt;
  // A plain `as u32` cast would silently truncate the prefix for oversized
  // buffers and then write every payload byte anyway, desynchronising the
  // stream for the reader. Fail instead, before touching the stream.
  let length: u32 = ::std::convert::TryFrom::try_from(buffer.len()).map_err(|_| {
    ::std::io::Error::new(
      ::std::io::ErrorKind::InvalidInput,
      format!(
        "frame payload of {} bytes does not fit in a u32 length prefix",
        buffer.len()
      ),
    )
  })?;
  s.write_u32(length).await?;
  s.write_all(buffer).await?;
  Ok(())
}

pub async fn read_framed_json<
  TStream: tokio::io::AsyncRead + Unpin,
  TOutput: serde::de::DeserializeOwned,
>(
  s: TStream,
  max_length: Option<usize>,
) -> Result<TOutput, JsonReadError> {
  let buffer = read_frame(s, max_length).await?;
  let x = serde_json::from_slice::<TOutput>(&buffer)?;
  Ok(x)
}

/// Serializes `value` as JSON and writes it to `s` as one length-prefixed frame.
///
/// `max_length`, when `Some`, bounds the whole frame: the JSON payload plus the
/// `u32` length prefix. The value is fully serialized first, and the limit is
/// checked before anything is written, so an oversized value leaves `s` untouched.
///
/// # Errors
/// - [`JsonWriteError::Serialization`] if `value` cannot be encoded.
/// - [`JsonWriteError::MaxLengthExceeded`] if the frame would exceed `max_length`.
/// - [`JsonWriteError::Write`] if writing the frame fails.
pub async fn write_framed_json<TStream: tokio::io::AsyncWrite + Unpin, TInput: serde::Serialize>(
  s: TStream,
  value: TInput,
  max_length: Option<usize>,
) -> Result<(), JsonWriteError> {
  // Fixed-size boxed slice: the encoded payload is not resized after this point.
  let payload: Box<[u8]> = serde_json::to_vec(&value)?.into();
  let framed_len = payload.len() + std::mem::size_of::<u32>();
  match max_length {
    Some(limit) if framed_len > limit => Err(JsonWriteError::MaxLengthExceeded {
      expected: limit,
      produced: framed_len,
    }),
    _ => Ok(write_frame(s, &payload).await?),
  }
}

#[cfg(test)]
mod tests {
  use super::{
    read_frame, read_framed_json, write_frame, write_framed_json, JsonWriteError, ReadError,
  };

  /// Size in bytes of the `u32` length prefix that starts every frame.
  const PREFIX_LEN: usize = std::mem::size_of::<u32>();

  #[tokio::test]
  async fn stream_framed_roundtrip() {
    const TEST_BLOB_LENGTH: usize = 1234;
    let mut buffer: Vec<u8> = Vec::with_capacity(TEST_BLOB_LENGTH + PREFIX_LEN);
    {
      let mut cursor = std::io::Cursor::new(&mut buffer);
      // Test data is TEST_BLOB_LENGTH bytes counting 0, 1, ..., 254 and wrapping
      // (index modulo u8::MAX). The length is taken from the constant rather than
      // from `Vec::capacity()`, which is only guaranteed to be at least the request.
      let test_data: Vec<u8> = (0..TEST_BLOB_LENGTH)
        .map(|i| (i % u8::MAX as usize) as u8)
        .collect();
      write_frame(&mut cursor, &test_data)
        .await
        .expect("Writing frame to stream must succeed");
      cursor.set_position(0);
      let deserialized = read_frame(&mut cursor, None)
        .await
        .expect("Reading frame from stream must succeed");
      // Input and output data should be the same
      assert_eq!(test_data, deserialized);
      // After the length prefix, the stream should be equal to the content
      assert_eq!(&buffer[PREFIX_LEN..], &test_data[..]);
    }
    // Stream must receive the length prefix plus the content, nothing more
    assert_eq!(buffer.len(), TEST_BLOB_LENGTH + PREFIX_LEN);
    // Verify function on zero-length frames
    buffer.clear();
    {
      let mut cursor = std::io::Cursor::new(&mut buffer);
      let test_data: Vec<u8> = Vec::new();
      write_frame(&mut cursor, &test_data).await.unwrap();
      cursor.set_position(0);
      let result = read_frame(&mut cursor, None).await.unwrap();
      assert_eq!(&test_data, &result);
    }
    assert_eq!(buffer.len(), PREFIX_LEN);
  }

  #[tokio::test]
  async fn read_rejects_frame_longer_than_maximum() {
    let mut buffer: Vec<u8> = Vec::new();
    write_frame(&mut buffer, &[0u8; 10])
      .await
      .expect("Writing frame to stream must succeed");
    match read_frame(buffer.as_slice(), Some(9)).await {
      Err(ReadError::MaxLengthExceeded {
        expected: 9,
        received: 10,
      }) => {}
      other => panic!("expected MaxLengthExceeded {{ 9, 10 }}, got {other:?}"),
    }
  }

  #[tokio::test]
  async fn exceeding_maximum_length_is_no_op() {
    let mut buffer: Vec<u8> = Vec::new();
    // A single-character string in JSON UTF-8 is 3 bytes long due to quotes, so
    // the full frame needs PREFIX_LEN + 3 bytes; allow one byte fewer.
    // Stable `match` instead of the nightly-only `std::assert_matches`, still
    // reporting the unexpected value on failure.
    match write_framed_json(&mut buffer, "a", Some(PREFIX_LEN + 2)).await {
      Err(JsonWriteError::MaxLengthExceeded { expected, produced })
        if expected == PREFIX_LEN + 2 && produced == PREFIX_LEN + 3 => {}
      other => panic!("expected MaxLengthExceeded, got {other:?}"),
    }
    assert_eq!(
      buffer.len(),
      0,
      "Buffer must not have been written to during a max length error"
    );
  }

  #[tokio::test]
  async fn stream_json_serialization_roundtrip() {
    let buffer: Vec<u8> = Vec::new();
    let mut cursor = std::io::Cursor::new(buffer);
    let original = (6f32, String::from("a"), 2u8, 12f64);
    write_framed_json(&mut cursor, &original, None)
      .await
      .expect("Writing to stream must succeed");
    cursor.set_position(0);
    let deserialized: (f32, String, u8, f64) = read_framed_json(&mut cursor, None)
      .await
      .expect("Reading header from stream must succeed");
    assert_eq!(original, deserialized);
  }
}