Skip to main content

vapour_protocol/
connection.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        Arc,
5        atomic::{AtomicU64, Ordering},
6    },
7    time::Duration,
8};
9
10use futures_util::{SinkExt, StreamExt, stream::SplitSink};
11use prost::Message;
12use tokio::{
13    sync::{Mutex, Notify, RwLock, mpsc, oneshot},
14    task::JoinHandle,
15    time,
16};
17use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
18
19use crate::{
20    emsg::EMsg,
21    error::{Error, Result},
22    message::{NO_JOB_ID, Packet, decode_frame, encode_message},
23    protobuf::{CMsgClientHeartBeat, CMsgProtoBufHeader},
24    transport::websocket::{SteamWebSocket, connect},
25};
26
27type PendingJobs = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Packet>>>>>;
28type PendingStreams = Arc<Mutex<HashMap<u64, mpsc::UnboundedSender<Result<Packet>>>>>;
29type IncomingEvents = mpsc::UnboundedReceiver<Result<Packet>>;
30type WriteHalf = SplitSink<SteamWebSocket, WebSocketMessage>;
31
/// Mutable session state shared between a `Connection` and its background tasks.
///
/// Every field starts out empty (`None`, `false`, or an empty list) via `Default`.
#[derive(Clone, Debug, Default)]
pub struct ConnectionState {
    /// Steam id recorded by `Connection::set_logged_on`.
    pub steamid: Option<u64>,
    /// Client session id recorded by `Connection::set_logged_on`.
    pub client_session_id: Option<i32>,
    /// Heartbeat interval in seconds, as passed to `Connection::set_logged_on`.
    pub heartbeat_seconds: Option<i32>,
    /// Why the connection stopped being usable; `Some` means the connection is closed.
    pub close_reason: Option<String>,
    /// Flipped to `true` by `Connection::set_package_ids`.
    pub license_list_received: bool,
    /// Package ids stored by `Connection::set_package_ids`.
    pub package_ids: Vec<u32>,
}
41
42#[derive(Debug)]
43pub struct Connection {
44    sender: Arc<Mutex<WriteHalf>>,
45    pending_jobs: PendingJobs,
46    pending_streams: PendingStreams,
47    incoming: IncomingEvents,
48    next_job_id: AtomicU64,
49    state: Arc<RwLock<ConnectionState>>,
50    license_notify: Arc<Notify>,
51    read_task: JoinHandle<()>,
52    heartbeat_task: Option<JoinHandle<()>>,
53}
54
55impl Connection {
56    pub async fn connect(url: &str) -> Result<Self> {
57        let socket = connect(url).await?;
58        let (writer, mut reader) = socket.split();
59        let sender = Arc::new(Mutex::new(writer));
60        let pending_jobs = Arc::new(Mutex::new(
61            HashMap::<u64, oneshot::Sender<Result<Packet>>>::new(),
62        ));
63        let pending_streams = Arc::new(Mutex::new(HashMap::<
64            u64,
65            mpsc::UnboundedSender<Result<Packet>>,
66        >::new()));
67        let state = Arc::new(RwLock::new(ConnectionState::default()));
68        let license_notify = Arc::new(Notify::new());
69        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
70        let pending_jobs_for_read = Arc::clone(&pending_jobs);
71        let pending_streams_for_read = Arc::clone(&pending_streams);
72        let state_for_read = Arc::clone(&state);
73
74        let read_task = tokio::spawn(async move {
75            while let Some(frame) = reader.next().await {
76                let binary = match frame {
77                    Ok(WebSocketMessage::Binary(payload)) => payload,
78                    Ok(WebSocketMessage::Close(_)) => {
79                        mark_closed(&state_for_read, "Steam CM closed the connection").await;
80                        fail_pending_jobs(
81                            &pending_jobs_for_read,
82                            &pending_streams_for_read,
83                            Error::Closed,
84                        )
85                        .await;
86                        let _ = incoming_tx.send(Err(Error::Closed));
87                        break;
88                    }
89                    Ok(WebSocketMessage::Ping(_))
90                    | Ok(WebSocketMessage::Pong(_))
91                    | Ok(WebSocketMessage::Text(_))
92                    | Ok(WebSocketMessage::Frame(_)) => {
93                        continue;
94                    }
95                    Err(error) => {
96                        let message = error.to_string();
97                        let wrapped = Error::from(error);
98                        mark_closed(&state_for_read, message.clone()).await;
99                        fail_pending_jobs(
100                            &pending_jobs_for_read,
101                            &pending_streams_for_read,
102                            Error::Transport(message),
103                        )
104                        .await;
105                        let _ = incoming_tx.send(Err(wrapped));
106                        break;
107                    }
108                };
109
110                match decode_frame(&binary) {
111                    Ok(packets) => {
112                        for packet in packets {
113                            // Server-initiated service notifications (`ServiceMethod` 146, e.g.
114                            // `FriendMessagesClient.IncomingMessage#1`, and the rarer
115                            // `ServiceMethodSendToClient` 152) are pushes, not responses to a
116                            // pending request. Never route them to pending jobs even if jobid_target
117                            // happens to match — doing so would consume the pending slot and silently
118                            // drop the real response. (Responses are `ServiceMethodResponse` 147.)
119                            let is_server_push = packet.emsg
120                                == crate::emsg::EMsg::ServiceMethod.raw()
121                                || packet.emsg
122                                    == crate::emsg::EMsg::ServiceMethodSendToClient.raw();
123
124                            if !is_server_push && let Some(job_id) = packet.jobid_target() {
125                                let waiter = {
126                                    let mut pending = pending_jobs_for_read.lock().await;
127                                    pending.remove(&job_id)
128                                };
129                                if let Some(waiter) = waiter {
130                                    let _ = waiter.send(Ok(packet));
131                                    continue;
132                                }
133
134                                let stream = {
135                                    let pending = pending_streams_for_read.lock().await;
136                                    pending.get(&job_id).cloned()
137                                };
138                                if let Some(stream) = stream {
139                                    if stream.send(Ok(packet)).is_err() {
140                                        let mut pending = pending_streams_for_read.lock().await;
141                                        pending.remove(&job_id);
142                                    }
143                                    continue;
144                                }
145                            }
146
147                            let _ = incoming_tx.send(Ok(packet));
148                        }
149                    }
150                    Err(error) => {
151                        let message = error.to_string();
152                        mark_closed(&state_for_read, message.clone()).await;
153                        fail_pending_jobs(
154                            &pending_jobs_for_read,
155                            &pending_streams_for_read,
156                            Error::Transport(message),
157                        )
158                        .await;
159                        let _ = incoming_tx.send(Err(error));
160                        break;
161                    }
162                }
163            }
164
165            mark_closed_if_unset(&state_for_read, "Steam CM read loop ended").await;
166            fail_pending_jobs(
167                &pending_jobs_for_read,
168                &pending_streams_for_read,
169                Error::Closed,
170            )
171            .await;
172        });
173
174        Ok(Self {
175            sender,
176            pending_jobs,
177            pending_streams,
178            incoming: incoming_rx,
179            next_job_id: AtomicU64::new(1),
180            state,
181            license_notify,
182            read_task,
183            heartbeat_task: None,
184        })
185    }
186
187    pub async fn send_message<M>(
188        &self,
189        emsg: EMsg,
190        header: &CMsgProtoBufHeader,
191        body: &M,
192    ) -> Result<()>
193    where
194        M: Message,
195    {
196        let payload = encode_message(emsg, header, body)?;
197        self.send_frame(payload).await
198    }
199
200    pub async fn request<M>(
201        &self,
202        emsg: EMsg,
203        header: CMsgProtoBufHeader,
204        body: &M,
205    ) -> Result<Packet>
206    where
207        M: Message,
208    {
209        let rx = self.send_request(emsg, header, body).await?;
210        rx.await
211            .map_err(|_| self.closed_error())
212            .and_then(|result| result)
213    }
214
215    /// Send a request and return the response receiver without awaiting it.
216    /// The caller can release any held locks before awaiting the receiver.
217    pub async fn send_request<M>(
218        &self,
219        emsg: EMsg,
220        mut header: CMsgProtoBufHeader,
221        body: &M,
222    ) -> Result<oneshot::Receiver<Result<Packet>>>
223    where
224        M: Message,
225    {
226        let job_id = self.next_job_id.fetch_add(1, Ordering::Relaxed);
227        header.jobid_source = Some(job_id);
228        if header.jobid_target.is_none() {
229            header.jobid_target = Some(NO_JOB_ID);
230        }
231
232        let (tx, rx) = oneshot::channel();
233        self.pending_jobs.lock().await.insert(job_id, tx);
234
235        if let Err(error) = self.send_message(emsg, &header, body).await {
236            self.pending_jobs.lock().await.remove(&job_id);
237            return Err(error);
238        }
239
240        Ok(rx)
241    }
242
243    /// Send a request whose response may arrive in multiple packets with the
244    /// same jobid_target. The caller must call `end_stream` after it sees the
245    /// protocol-level end marker.
246    pub async fn send_request_stream<M>(
247        &self,
248        emsg: EMsg,
249        mut header: CMsgProtoBufHeader,
250        body: &M,
251    ) -> Result<(u64, mpsc::UnboundedReceiver<Result<Packet>>)>
252    where
253        M: Message,
254    {
255        let job_id = self.next_job_id.fetch_add(1, Ordering::Relaxed);
256        header.jobid_source = Some(job_id);
257        if header.jobid_target.is_none() {
258            header.jobid_target = Some(NO_JOB_ID);
259        }
260
261        let (tx, rx) = mpsc::unbounded_channel();
262        self.pending_streams.lock().await.insert(job_id, tx);
263
264        if let Err(error) = self.send_message(emsg, &header, body).await {
265            self.pending_streams.lock().await.remove(&job_id);
266            return Err(error);
267        }
268
269        Ok((job_id, rx))
270    }
271
272    pub async fn end_stream(&self, job_id: u64) {
273        self.pending_streams.lock().await.remove(&job_id);
274    }
275
276    pub async fn next_event(&mut self) -> Option<Result<Packet>> {
277        self.incoming.recv().await
278    }
279
280    pub async fn set_logged_on(
281        &mut self,
282        steamid: u64,
283        client_session_id: i32,
284        heartbeat_seconds: i32,
285    ) -> Result<()> {
286        {
287            let mut state = self.state.write().await;
288            state.steamid = Some(steamid);
289            state.client_session_id = Some(client_session_id);
290            state.heartbeat_seconds = Some(heartbeat_seconds);
291        }
292
293        self.start_heartbeat(Duration::from_secs(heartbeat_seconds as u64))
294            .await
295    }
296
297    /// Extract the incoming event receiver so `run()` can select over it
298    /// without holding `&mut self`. Replaces `self.incoming` with a dead channel.
299    pub fn take_incoming(&mut self) -> IncomingEvents {
300        let (_dead_tx, dead_rx) = mpsc::unbounded_channel();
301        std::mem::replace(&mut self.incoming, dead_rx)
302    }
303
304    pub async fn state_snapshot(&self) -> ConnectionState {
305        self.state.read().await.clone()
306    }
307
308    pub async fn set_package_ids(&self, package_ids: Vec<u32>) {
309        {
310            let mut state = self.state.write().await;
311            state.license_list_received = true;
312            state.package_ids = package_ids;
313        }
314        self.license_notify.notify_waiters();
315    }
316
317    pub fn license_notify(&self) -> Arc<Notify> {
318        Arc::clone(&self.license_notify)
319    }
320
321    pub async fn is_closed(&self) -> bool {
322        self.state.read().await.close_reason.is_some()
323    }
324
325    async fn send_frame(&self, payload: bytes::Bytes) -> Result<()> {
326        if let Some(reason) = self.state.read().await.close_reason.clone() {
327            return Err(Error::Transport(reason));
328        }
329
330        let mut sender = self.sender.lock().await;
331        if let Err(error) = sender.send(WebSocketMessage::Binary(payload)).await {
332            let message = error.to_string();
333            {
334                let mut state = self.state.write().await;
335                state.close_reason = Some(message.clone());
336            }
337            return Err(Error::Transport(message));
338        }
339        Ok(())
340    }
341
342    async fn start_heartbeat(&mut self, interval: Duration) -> Result<()> {
343        if let Some(task) = self.heartbeat_task.take() {
344            task.abort();
345        }
346
347        let sender = Arc::clone(&self.sender);
348        let state = Arc::clone(&self.state);
349        self.heartbeat_task = Some(tokio::spawn(async move {
350            let mut ticker = time::interval(interval);
351            loop {
352                ticker.tick().await;
353
354                let state_snapshot = state.read().await.clone();
355                let header = CMsgProtoBufHeader {
356                    steamid: state_snapshot.steamid,
357                    client_sessionid: state_snapshot.client_session_id,
358                    ..Default::default()
359                };
360                let payload = match encode_message(
361                    EMsg::ClientHeartBeat,
362                    &header,
363                    &CMsgClientHeartBeat {
364                        send_reply: Some(false),
365                    },
366                ) {
367                    Ok(payload) => payload,
368                    Err(_) => break,
369                };
370
371                let mut writer = sender.lock().await;
372                if writer
373                    .send(WebSocketMessage::Binary(payload))
374                    .await
375                    .is_err()
376                {
377                    break;
378                }
379            }
380        }));
381
382        Ok(())
383    }
384}
385
386impl Drop for Connection {
387    fn drop(&mut self) {
388        self.read_task.abort();
389        if let Some(task) = self.heartbeat_task.take() {
390            task.abort();
391        }
392    }
393}
394
395impl Connection {
396    fn closed_error(&self) -> Error {
397        match self.state.try_read() {
398            Ok(state) => state
399                .close_reason
400                .clone()
401                .map(Error::Transport)
402                .unwrap_or(Error::Closed),
403            Err(_) => Error::Closed,
404        }
405    }
406}
407
408async fn mark_closed(state: &Arc<RwLock<ConnectionState>>, reason: impl Into<String>) {
409    let mut state = state.write().await;
410    state.close_reason = Some(reason.into());
411}
412
413async fn mark_closed_if_unset(state: &Arc<RwLock<ConnectionState>>, reason: impl Into<String>) {
414    let mut state = state.write().await;
415    if state.close_reason.is_none() {
416        state.close_reason = Some(reason.into());
417    }
418}
419
420async fn fail_pending_jobs(
421    pending_jobs: &PendingJobs,
422    pending_streams: &PendingStreams,
423    error: Error,
424) {
425    let waiters = {
426        let mut pending = pending_jobs.lock().await;
427        pending
428            .drain()
429            .map(|(_, waiter)| waiter)
430            .collect::<Vec<_>>()
431    };
432    let streams = {
433        let mut pending = pending_streams.lock().await;
434        pending
435            .drain()
436            .map(|(_, stream)| stream)
437            .collect::<Vec<_>>()
438    };
439
440    let error_message = error.to_string();
441    for waiter in waiters {
442        let _ = waiter.send(Err(Error::Transport(error_message.clone())));
443    }
444    for stream in streams {
445        let _ = stream.send(Err(Error::Transport(error_message.clone())));
446    }
447}