Skip to main content

vapour_protocol/
message.rs

1use std::io::{Cursor, Read};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use flate2::read::GzDecoder;
5use prost::Message;
6
7use crate::{
8    emsg::{EMsg, PROTO_MASK},
9    error::{Error, Result},
10    protobuf::{CMsgMulti, CMsgProtoBufHeader},
11};
12
/// Sentinel job id meaning "no job".
///
/// Every bit is set, so this equals `u64::MAX`. `Packet::jobid_target` and
/// `Packet::jobid_source` report a header carrying this value as `None`.
pub const NO_JOB_ID: u64 = !0;
14
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct Packet {
17    pub emsg: u32,
18    pub header: CMsgProtoBufHeader,
19    pub body: Bytes,
20}
21
22impl Packet {
23    pub fn jobid_target(&self) -> Option<u64> {
24        self.header
25            .jobid_target
26            .filter(|job_id| *job_id != NO_JOB_ID)
27    }
28
29    pub fn jobid_source(&self) -> Option<u64> {
30        self.header
31            .jobid_source
32            .filter(|job_id| *job_id != NO_JOB_ID)
33    }
34
35    pub fn target_job_name(&self) -> Option<&str> {
36        self.header.target_job_name.as_deref()
37    }
38
39    pub fn decode_body<M>(&self) -> Result<M>
40    where
41        M: Message + Default,
42    {
43        M::decode(self.body.clone()).map_err(Error::from)
44    }
45}
46
47pub fn encode_message<M>(emsg: EMsg, header: &CMsgProtoBufHeader, body: &M) -> Result<Bytes>
48where
49    M: Message,
50{
51    let body_bytes = body.encode_to_vec();
52    encode_raw(emsg.protobuf(), header, &body_bytes)
53}
54
55pub fn encode_raw(emsg: u32, header: &CMsgProtoBufHeader, body: &[u8]) -> Result<Bytes> {
56    let header_bytes = header.encode_to_vec();
57    let mut frame = BytesMut::with_capacity(8 + header_bytes.len() + body.len());
58    frame.put_u32_le(emsg);
59    frame.put_u32_le(header_bytes.len() as u32);
60    frame.extend_from_slice(&header_bytes);
61    frame.extend_from_slice(body);
62    Ok(frame.freeze())
63}
64
65pub fn decode_frame(frame: &[u8]) -> Result<Vec<Packet>> {
66    if frame.len() >= 4 {
67        let raw_emsg = u32::from_le_bytes(frame[0..4].try_into().expect("slice length checked"));
68        if raw_emsg & PROTO_MASK == 0 {
69            tracing::trace!(raw_emsg, "skipping non-protobuf frame");
70            return Ok(vec![]);
71        }
72    }
73    let packet = decode_packet(frame)?;
74    expand_packet(packet)
75}
76
77pub fn decode_packet(frame: &[u8]) -> Result<Packet> {
78    if frame.len() < 8 {
79        return Err(Error::InvalidPacket("frame too short"));
80    }
81
82    let raw_emsg = u32::from_le_bytes(frame[0..4].try_into().expect("slice length checked"));
83    if raw_emsg & PROTO_MASK == 0 {
84        return Err(Error::InvalidPacket("non-protobuf packet is unsupported"));
85    }
86
87    let header_len =
88        u32::from_le_bytes(frame[4..8].try_into().expect("slice length checked")) as usize;
89    if frame.len() < 8 + header_len {
90        return Err(Error::InvalidPacket("truncated protobuf header"));
91    }
92
93    let header = CMsgProtoBufHeader::decode(&frame[8..8 + header_len])?;
94    let body = Bytes::copy_from_slice(&frame[8 + header_len..]);
95
96    Ok(Packet {
97        emsg: raw_emsg & !PROTO_MASK,
98        header,
99        body,
100    })
101}
102
103fn expand_packet(packet: Packet) -> Result<Vec<Packet>> {
104    if packet.emsg != EMsg::Multi.raw() {
105        return Ok(vec![packet]);
106    }
107
108    let multi = packet.decode_body::<CMsgMulti>()?;
109    let payload = multi
110        .message_body
111        .ok_or(Error::MissingField("CMsgMulti.message_body"))?;
112
113    let data = if multi.size_unzipped.unwrap_or_default() > 0 {
114        let mut decoder = GzDecoder::new(payload.as_slice());
115        let mut uncompressed = Vec::with_capacity(multi.size_unzipped.unwrap_or_default() as usize);
116        decoder.read_to_end(&mut uncompressed)?;
117        uncompressed
118    } else {
119        payload
120    };
121
122    split_multi_payload(&data)
123}
124
125fn split_multi_payload(payload: &[u8]) -> Result<Vec<Packet>> {
126    let mut cursor = Cursor::new(payload);
127    let mut packets = Vec::new();
128
129    while (cursor.position() as usize) < payload.len() {
130        let mut len_bytes = [0_u8; 4];
131        cursor.read_exact(&mut len_bytes)?;
132        let packet_len = u32::from_le_bytes(len_bytes) as usize;
133        if packet_len == 0 {
134            return Err(Error::InvalidPacket("multi payload contained empty packet"));
135        }
136
137        let start = cursor.position() as usize;
138        let end = start + packet_len;
139        if end > payload.len() {
140            return Err(Error::InvalidPacket("multi payload packet length overflow"));
141        }
142
143        packets.extend(decode_frame(&payload[start..end])?);
144        cursor.set_position(end as u64);
145    }
146
147    Ok(packets)
148}
149
#[cfg(test)]
mod tests {
    use std::io::Write;

    use flate2::{Compression, write::GzEncoder};

    use super::{decode_frame, encode_message};
    use crate::{
        emsg::EMsg,
        protobuf::{CMsgClientHeartBeat, CMsgMulti, CMsgProtoBufHeader},
    };

    /// Encodes a `ClientHeartBeat` frame whose header carries `jobid_source = job_id`.
    fn heartbeat_frame(job_id: u64, send_reply: bool) -> Vec<u8> {
        encode_message(
            EMsg::ClientHeartBeat,
            &CMsgProtoBufHeader {
                jobid_source: Some(job_id),
                ..Default::default()
            },
            &CMsgClientHeartBeat {
                send_reply: Some(send_reply),
            },
        )
        .unwrap()
        .to_vec()
    }

    /// Builds a Multi payload: each frame prefixed by its length as a little-endian `u32`.
    fn length_prefixed(frames: &[Vec<u8>]) -> Vec<u8> {
        let mut payload = Vec::new();
        for frame in frames {
            payload.extend_from_slice(&(frame.len() as u32).to_le_bytes());
            payload.extend_from_slice(frame);
        }
        payload
    }

    /// Wraps `multi` in an `EMsg::Multi` frame with a default header.
    fn multi_frame(multi: &CMsgMulti) -> Vec<u8> {
        encode_message(EMsg::Multi, &CMsgProtoBufHeader::default(), multi)
            .unwrap()
            .to_vec()
    }

    #[test]
    fn non_protobuf_frame_is_skipped() {
        // Steam sends legacy (non-protobuf) frames after logon. They should be
        // silently dropped rather than killing the connection.
        let mut frame = Vec::new();
        frame.extend_from_slice(&703u32.to_le_bytes()); // ClientHeartBeat without PROTO_MASK
        frame.extend_from_slice(&[0u8; 32]);
        let result = decode_frame(&frame).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn protobuf_packet_roundtrip() {
        let header = CMsgProtoBufHeader {
            steamid: Some(76561197960287930),
            client_sessionid: Some(42),
            jobid_source: Some(7),
            target_job_name: Some("Authentication.BeginAuthSessionViaQR#1".to_owned()),
            ..Default::default()
        };
        let body = CMsgClientHeartBeat {
            send_reply: Some(true),
        };

        let encoded = encode_message(EMsg::ClientHeartBeat, &header, &body).unwrap();
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].emsg, EMsg::ClientHeartBeat.raw());
        assert_eq!(decoded[0].header.client_sessionid, Some(42));
        assert_eq!(decoded[0].header.jobid_source, Some(7));

        let decoded_body = decoded[0].decode_body::<CMsgClientHeartBeat>().unwrap();
        assert_eq!(decoded_body.send_reply, Some(true));
    }

    #[test]
    fn multi_packet_split_handles_gzip_payload() {
        let payload = length_prefixed(&[heartbeat_frame(1, false), heartbeat_frame(2, true)]);

        let mut gzip = GzEncoder::new(Vec::new(), Compression::default());
        gzip.write_all(&payload).unwrap();
        let compressed = gzip.finish().unwrap();

        let frame = multi_frame(&CMsgMulti {
            size_unzipped: Some(payload.len() as u32),
            message_body: Some(compressed),
        });

        let decoded = decode_frame(&frame).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].header.jobid_source, Some(1));
        assert_eq!(decoded[1].header.jobid_source, Some(2));
    }

    #[test]
    fn multi_packet_split_handles_plain_payload() {
        let payload = length_prefixed(&[heartbeat_frame(3, true), heartbeat_frame(4, false)]);
        let frame = multi_frame(&CMsgMulti {
            size_unzipped: None,
            message_body: Some(payload),
        });

        let decoded = decode_frame(&frame).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].header.jobid_source, Some(3));
        assert_eq!(decoded[1].header.jobid_source, Some(4));
    }

    #[test]
    fn frame_shorter_than_fixed_prefix_is_rejected() {
        assert!(decode_frame(&[0x01, 0x02, 0x03]).is_err());
    }

    #[test]
    fn truncated_protobuf_header_is_rejected() {
        let mut frame = heartbeat_frame(1, true);
        // Claim a header far longer than the bytes that follow.
        frame[4..8].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(decode_frame(&frame).is_err());
    }

    #[test]
    fn multi_payload_with_empty_entry_is_rejected() {
        let frame = multi_frame(&CMsgMulti {
            size_unzipped: None,
            message_body: Some(vec![0; 4]),
        });
        assert!(decode_frame(&frame).is_err());
    }

    #[test]
    fn multi_payload_with_overlong_entry_is_rejected() {
        let mut payload = length_prefixed(&[heartbeat_frame(1, true)]);
        // Drop the last byte so the length prefix promises more than is present.
        payload.pop();
        let frame = multi_frame(&CMsgMulti {
            size_unzipped: None,
            message_body: Some(payload),
        });
        assert!(decode_frame(&frame).is_err());
    }
}