vapour_protocol/
message.rs1use std::io::{Cursor, Read};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use flate2::read::GzDecoder;
5use prost::Message;
6
7use crate::{
8 emsg::{EMsg, PROTO_MASK},
9 error::{Error, Result},
10 protobuf::{CMsgMulti, CMsgProtoBufHeader},
11};
12
/// Sentinel job id (`u64::MAX`). [`Packet::jobid_target`] and
/// [`Packet::jobid_source`] report a header field holding this value as `None`.
pub const NO_JOB_ID: u64 = u64::MAX;
14
/// One decoded protobuf message: its message id, its protobuf header, and the
/// still-encoded body bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Packet {
    /// Message id with the `PROTO_MASK` bits cleared (this is how
    /// `decode_packet` fills the field).
    pub emsg: u32,
    /// Protobuf header decoded from the length-prefixed header section of the frame.
    pub header: CMsgProtoBufHeader,
    /// Undecoded body bytes that followed the header; see `decode_body`.
    pub body: Bytes,
}
21
22impl Packet {
23 pub fn jobid_target(&self) -> Option<u64> {
24 self.header
25 .jobid_target
26 .filter(|job_id| *job_id != NO_JOB_ID)
27 }
28
29 pub fn jobid_source(&self) -> Option<u64> {
30 self.header
31 .jobid_source
32 .filter(|job_id| *job_id != NO_JOB_ID)
33 }
34
35 pub fn target_job_name(&self) -> Option<&str> {
36 self.header.target_job_name.as_deref()
37 }
38
39 pub fn decode_body<M>(&self) -> Result<M>
40 where
41 M: Message + Default,
42 {
43 M::decode(self.body.clone()).map_err(Error::from)
44 }
45}
46
47pub fn encode_message<M>(emsg: EMsg, header: &CMsgProtoBufHeader, body: &M) -> Result<Bytes>
48where
49 M: Message,
50{
51 let body_bytes = body.encode_to_vec();
52 encode_raw(emsg.protobuf(), header, &body_bytes)
53}
54
55pub fn encode_raw(emsg: u32, header: &CMsgProtoBufHeader, body: &[u8]) -> Result<Bytes> {
56 let header_bytes = header.encode_to_vec();
57 let mut frame = BytesMut::with_capacity(8 + header_bytes.len() + body.len());
58 frame.put_u32_le(emsg);
59 frame.put_u32_le(header_bytes.len() as u32);
60 frame.extend_from_slice(&header_bytes);
61 frame.extend_from_slice(body);
62 Ok(frame.freeze())
63}
64
65pub fn decode_frame(frame: &[u8]) -> Result<Vec<Packet>> {
66 if frame.len() >= 4 {
67 let raw_emsg = u32::from_le_bytes(frame[0..4].try_into().expect("slice length checked"));
68 if raw_emsg & PROTO_MASK == 0 {
69 tracing::trace!(raw_emsg, "skipping non-protobuf frame");
70 return Ok(vec![]);
71 }
72 }
73 let packet = decode_packet(frame)?;
74 expand_packet(packet)
75}
76
77pub fn decode_packet(frame: &[u8]) -> Result<Packet> {
78 if frame.len() < 8 {
79 return Err(Error::InvalidPacket("frame too short"));
80 }
81
82 let raw_emsg = u32::from_le_bytes(frame[0..4].try_into().expect("slice length checked"));
83 if raw_emsg & PROTO_MASK == 0 {
84 return Err(Error::InvalidPacket("non-protobuf packet is unsupported"));
85 }
86
87 let header_len =
88 u32::from_le_bytes(frame[4..8].try_into().expect("slice length checked")) as usize;
89 if frame.len() < 8 + header_len {
90 return Err(Error::InvalidPacket("truncated protobuf header"));
91 }
92
93 let header = CMsgProtoBufHeader::decode(&frame[8..8 + header_len])?;
94 let body = Bytes::copy_from_slice(&frame[8 + header_len..]);
95
96 Ok(Packet {
97 emsg: raw_emsg & !PROTO_MASK,
98 header,
99 body,
100 })
101}
102
103fn expand_packet(packet: Packet) -> Result<Vec<Packet>> {
104 if packet.emsg != EMsg::Multi.raw() {
105 return Ok(vec![packet]);
106 }
107
108 let multi = packet.decode_body::<CMsgMulti>()?;
109 let payload = multi
110 .message_body
111 .ok_or(Error::MissingField("CMsgMulti.message_body"))?;
112
113 let data = if multi.size_unzipped.unwrap_or_default() > 0 {
114 let mut decoder = GzDecoder::new(payload.as_slice());
115 let mut uncompressed = Vec::with_capacity(multi.size_unzipped.unwrap_or_default() as usize);
116 decoder.read_to_end(&mut uncompressed)?;
117 uncompressed
118 } else {
119 payload
120 };
121
122 split_multi_payload(&data)
123}
124
125fn split_multi_payload(payload: &[u8]) -> Result<Vec<Packet>> {
126 let mut cursor = Cursor::new(payload);
127 let mut packets = Vec::new();
128
129 while (cursor.position() as usize) < payload.len() {
130 let mut len_bytes = [0_u8; 4];
131 cursor.read_exact(&mut len_bytes)?;
132 let packet_len = u32::from_le_bytes(len_bytes) as usize;
133 if packet_len == 0 {
134 return Err(Error::InvalidPacket("multi payload contained empty packet"));
135 }
136
137 let start = cursor.position() as usize;
138 let end = start + packet_len;
139 if end > payload.len() {
140 return Err(Error::InvalidPacket("multi payload packet length overflow"));
141 }
142
143 packets.extend(decode_frame(&payload[start..end])?);
144 cursor.set_position(end as u64);
145 }
146
147 Ok(packets)
148}
149
#[cfg(test)]
mod tests {
    use std::io::Write;

    use flate2::{Compression, write::GzEncoder};

    use super::{decode_frame, encode_message};
    use crate::{
        emsg::EMsg,
        protobuf::{CMsgClientHeartBeat, CMsgMulti, CMsgProtoBufHeader},
    };

    /// Encodes a heartbeat frame whose header carries only `jobid_source`.
    fn heartbeat_frame(job_id: u64, send_reply: bool) -> Vec<u8> {
        let header = CMsgProtoBufHeader {
            jobid_source: Some(job_id),
            ..Default::default()
        };
        let body = CMsgClientHeartBeat {
            send_reply: Some(send_reply),
        };
        encode_message(EMsg::ClientHeartBeat, &header, &body)
            .unwrap()
            .to_vec()
    }

    /// Gzip-compresses `data` with the default compression level.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    #[test]
    fn non_protobuf_frame_is_skipped() {
        // Message id 703 without the protobuf flag, followed by 32 zero bytes.
        let mut frame = 703u32.to_le_bytes().to_vec();
        frame.resize(frame.len() + 32, 0u8);

        let packets = decode_frame(&frame).unwrap();
        assert!(packets.is_empty());
    }

    #[test]
    fn protobuf_packet_roundtrip() {
        let header = CMsgProtoBufHeader {
            steamid: Some(76561197960287930),
            client_sessionid: Some(42),
            jobid_source: Some(7),
            target_job_name: Some("Authentication.BeginAuthSessionViaQR#1".to_owned()),
            ..Default::default()
        };
        let body = CMsgClientHeartBeat {
            send_reply: Some(true),
        };

        let wire = encode_message(EMsg::ClientHeartBeat, &header, &body).unwrap();
        let packets = decode_frame(&wire).unwrap();
        assert_eq!(packets.len(), 1);

        let packet = &packets[0];
        assert_eq!(packet.emsg, EMsg::ClientHeartBeat.raw());
        assert_eq!(packet.header.client_sessionid, Some(42));
        assert_eq!(packet.header.jobid_source, Some(7));

        let heartbeat = packet.decode_body::<CMsgClientHeartBeat>().unwrap();
        assert_eq!(heartbeat.send_reply, Some(true));
    }

    #[test]
    fn multi_packet_split_handles_gzip_payload() {
        let frames = [heartbeat_frame(1, false), heartbeat_frame(2, true)];

        // Each inner frame is preceded by its little-endian u32 length.
        let mut payload = Vec::new();
        for frame in &frames {
            payload.extend_from_slice(&(frame.len() as u32).to_le_bytes());
            payload.extend_from_slice(frame);
        }

        let multi = CMsgMulti {
            size_unzipped: Some(payload.len() as u32),
            message_body: Some(gzip(&payload)),
        };
        let wire = encode_message(EMsg::Multi, &CMsgProtoBufHeader::default(), &multi).unwrap();

        let packets = decode_frame(&wire).unwrap();
        assert_eq!(packets.len(), 2);
        assert_eq!(packets[0].header.jobid_source, Some(1));
        assert_eq!(packets[1].header.jobid_source, Some(2));
    }
}