Skip to main content

steam_auth/transport/
websocket_cm.rs

1//! WebSocket CM transport for SteamClient platform authentication.
2//!
3//! This transport connects to Steam's Connection Manager (CM) servers via
4//! WebSocket and is required for authenticating with the SteamClient platform
5//! type.
6
7use std::{
8    collections::HashMap,
9    io::Read,
10    sync::{
11        atomic::{AtomicI32, Ordering},
12        Arc,
13    },
14};
15
16use flate2::read::GzDecoder;
17use futures_util::{SinkExt, StreamExt};
18use prost::Message;
19use steam_cm_provider::{CmServerProvider, HttpCmServerProvider};
20use steam_protos::{CMsgClientHello, CMsgClientServiceMethodLegacy, CMsgClientServiceMethodLegacyResponse, CMsgMulti, CMsgProtoBufHeader};
21use tokio::sync::{oneshot, Mutex};
22use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
23
24use crate::{
25    error::SessionError,
26    transport::{ApiRequest, ApiResponse},
27};
28
/// Steam message IDs (EMsg) used by this transport.
///
/// The values are stored without the protobuf flag bit; the header
/// encoder/decoder (`MsgHdrProtoBuf`) adds and strips that bit.
mod emsg {
    /// Container message carrying one or more length-prefixed messages,
    /// optionally gzip-compressed.
    pub const MULTI: u32 = 1;

    /// Legacy service method call sent by the client.
    pub const SERVICE_METHOD: u32 = 146;

    /// Response to a legacy service method call.
    pub const SERVICE_METHOD_RESPONSE: u32 = 147;

    /// Hello message sent right after the socket is opened.
    // NOTE(review): SteamKit2 / steam-session list ClientHello as 9805;
    // confirm that 4006 is the intended value here.
    pub const CLIENT_HELLO: u32 = 4006;
}
36
37struct MsgHdrProtoBuf {
38    pub msg: u32,
39    pub proto: CMsgProtoBufHeader,
40}
41
42impl MsgHdrProtoBuf {
43    fn encode(&self) -> Vec<u8> {
44        let proto_bytes = self.proto.encode_to_vec();
45        let mut result = Vec::new();
46
47        // EMsg with protobuf flag
48        result.extend_from_slice(&(self.msg | 0x80000000).to_le_bytes());
49        // Header length
50        result.extend_from_slice(&(proto_bytes.len() as u32).to_le_bytes());
51        // Proto header
52        result.extend_from_slice(&proto_bytes);
53
54        result
55    }
56
57    fn decode(data: &[u8]) -> Result<(Self, usize), SessionError> {
58        if data.len() < 8 {
59            return Err(SessionError::ProtocolError("Header too short".into()));
60        }
61
62        let msg = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) & 0x7FFFFFFF;
63        let header_length = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
64
65        if data.len() < 8 + header_length {
66            return Err(SessionError::ProtocolError("Header incomplete".into()));
67        }
68
69        let proto = CMsgProtoBufHeader::decode(&data[8..8 + header_length])?;
70
71        Ok((Self { msg, proto }, 8 + header_length))
72    }
73}
74
75/// Internal connection state.
76struct ConnectionState {
77    session_id: AtomicI32,
78    job_id_counter: AtomicI32,
79    pending_jobs: Mutex<HashMap<u64, oneshot::Sender<ApiResponse>>>,
80}
81
82/// WebSocket CM transport for SteamClient authentication.
83#[allow(clippy::type_complexity)]
84pub struct WebSocketCMTransport {
85    ws_sender: Arc<Mutex<Option<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, WsMessage>>>>,
86    state: Arc<ConnectionState>,
87    connected: Arc<Mutex<bool>>,
88    cm_provider: Arc<dyn CmServerProvider>,
89}
90
91impl std::fmt::Debug for WebSocketCMTransport {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("WebSocketCMTransport").field("connected", &self.connected).finish()
94    }
95}
96
97impl Clone for WebSocketCMTransport {
98    fn clone(&self) -> Self {
99        Self {
100            ws_sender: Arc::clone(&self.ws_sender),
101            state: Arc::clone(&self.state),
102            connected: Arc::clone(&self.connected),
103            cm_provider: Arc::clone(&self.cm_provider),
104        }
105    }
106}
107
108impl WebSocketCMTransport {
109    /// Create a new WebSocket CM transport and connect to Steam.
110    pub async fn new() -> Result<Self, SessionError> {
111        Self::with_options(None).await
112    }
113
114    /// Create a new WebSocket CM transport with optional provider
115    /// configuration.
116    pub async fn with_options(cm_provider: Option<Arc<dyn CmServerProvider>>) -> Result<Self, SessionError> {
117        let cm_provider = cm_provider.unwrap_or_else(|| Arc::new(HttpCmServerProvider::new_default()));
118
119        let transport = Self {
120            ws_sender: Arc::new(Mutex::new(None)),
121            state: Arc::new(ConnectionState { session_id: AtomicI32::new(0), job_id_counter: AtomicI32::new(0), pending_jobs: Mutex::new(HashMap::new()) }),
122            connected: Arc::new(Mutex::new(false)),
123            cm_provider,
124        };
125
126        transport.connect().await?;
127
128        Ok(transport)
129    }
130
131    /// Connect to a Steam CM server.
132    async fn connect(&self) -> Result<(), SessionError> {
133        // Get CM server
134        let server = self.cm_provider.get_server().await.map_err(|e| SessionError::NetworkError(format!("Failed to get CM server: {}", e)))?;
135
136        let url = format!("wss://{}/cmsocket/", server.endpoint);
137
138        tracing::debug!("Connecting to CM server: {}", url);
139
140        let (ws_stream, _) = connect_async(&url).await?;
141
142        let (write, mut read) = ws_stream.split();
143
144        *self.ws_sender.lock().await = Some(write);
145
146        // Send hello message
147        self.send_hello().await?;
148
149        // Spawn message receiver
150        let state = self.state.clone();
151        let connected = self.connected.clone();
152
153        tokio::spawn(async move {
154            while let Some(msg) = read.next().await {
155                match msg {
156                    Ok(WsMessage::Binary(data)) => {
157                        if let Err(e) = Self::handle_message(&state, &data, 0).await {
158                            tracing::error!("Error handling message: {}", e);
159                        }
160                    }
161                    Ok(WsMessage::Close(_)) => {
162                        *connected.lock().await = false;
163                        break;
164                    }
165                    Err(e) => {
166                        tracing::error!("WebSocket error: {}", e);
167                        *connected.lock().await = false;
168                        break;
169                    }
170                    _ => {}
171                }
172            }
173        });
174
175        *self.connected.lock().await = true;
176
177        Ok(())
178    }
179
180    /// Send the client hello message.
181    async fn send_hello(&self) -> Result<(), SessionError> {
182        let header = MsgHdrProtoBuf { msg: emsg::CLIENT_HELLO, proto: CMsgProtoBufHeader { client_sessionid: Some(0), ..Default::default() } };
183
184        let body = CMsgClientHello { protocol_version: Some(65580) };
185
186        let mut data = header.encode();
187        data.extend_from_slice(&body.encode_to_vec());
188
189        self.send_raw(&data).await
190    }
191
192    /// Send raw data over WebSocket.
193    async fn send_raw(&self, data: &[u8]) -> Result<(), SessionError> {
194        let mut sender = self.ws_sender.lock().await;
195        if let Some(ref mut ws) = *sender {
196            ws.send(WsMessage::Binary(data.to_vec())).await?;
197        } else {
198            return Err(SessionError::ProtocolError("Not connected".into()));
199        }
200        Ok(())
201    }
202
203    /// Handle an incoming message.
204    async fn handle_message(state: &ConnectionState, data: &[u8], depth: usize) -> Result<(), SessionError> {
205        if depth > 5 {
206            return Err(SessionError::ProtocolError("Message recursion depth exceeded".into()));
207        }
208
209        let (header, body_offset) = MsgHdrProtoBuf::decode(data)?;
210
211        match header.msg {
212            emsg::MULTI => {
213                // Handle CMsgMulti - decompress and process nested messages
214                let body = &data[body_offset..];
215                let multi = CMsgMulti::decode(body)?;
216
217                if let Some(message_body) = multi.message_body {
218                    let decompressed = if multi.size_unzipped.is_some() {
219                        // Decompress gzip
220                        let mut decoder = GzDecoder::new(message_body.as_slice());
221                        let mut result = Vec::new();
222                        decoder.read_to_end(&mut result).map_err(|e| SessionError::ProtocolError(format!("Gzip decompression failed: {}", e)))?;
223                        result
224                    } else {
225                        message_body
226                    };
227
228                    // Process nested messages
229                    let mut offset = 0;
230                    while offset < decompressed.len() {
231                        if offset + 4 > decompressed.len() {
232                            break;
233                        }
234                        let size = u32::from_le_bytes([decompressed[offset], decompressed[offset + 1], decompressed[offset + 2], decompressed[offset + 3]]) as usize;
235                        offset += 4;
236
237                        if offset + size > decompressed.len() {
238                            break;
239                        }
240
241                        let nested = &decompressed[offset..offset + size];
242                        Box::pin(Self::handle_message(state, nested, depth + 1)).await?;
243                        offset += size;
244                    }
245                }
246            }
247            emsg::SERVICE_METHOD_RESPONSE => {
248                let body = &data[body_offset..];
249                let response = CMsgClientServiceMethodLegacyResponse::decode(body)?;
250
251                if let Some(job_id) = header.proto.jobid_target {
252                    let mut pending = state.pending_jobs.lock().await;
253                    if let Some(sender) = pending.remove(&job_id) {
254                        let api_response = ApiResponse {
255                            result: header.proto.eresult,
256                            error_message: header.proto.error_message,
257                            response_data: response.serialized_method_response,
258                        };
259                        let _ = sender.send(api_response);
260                    }
261                }
262            }
263            _ => {
264                tracing::trace!("Unhandled message type: {}", header.msg);
265            }
266        }
267
268        Ok(())
269    }
270
271    /// Send a service method call.
272    async fn send_service_method(&self, method_name: &str, body: &[u8], access_token: Option<&str>) -> Result<ApiResponse, SessionError> {
273        let job_id = self.state.job_id_counter.fetch_add(1, Ordering::SeqCst) as u64 + 1;
274        let session_id = self.state.session_id.load(Ordering::SeqCst);
275
276        // Create header
277        let header_proto = CMsgProtoBufHeader {
278            client_sessionid: Some(session_id),
279            jobid_source: Some(job_id),
280            target_job_name: Some(method_name.to_string()),
281            realm: Some(1),
282            ..Default::default()
283        };
284
285        // Add auth token if provided
286        if let Some(_token) = access_token {
287            // Note: For WebSocket transport, tokens are typically set
288            // differently This is a simplified implementation
289        }
290
291        let header = MsgHdrProtoBuf { msg: emsg::SERVICE_METHOD, proto: header_proto };
292
293        let service_method = CMsgClientServiceMethodLegacy {
294            method_name: Some(method_name.to_string()),
295            serialized_method: Some(body.to_vec()),
296            is_notification: Some(false),
297        };
298
299        let mut data = header.encode();
300        data.extend_from_slice(&service_method.encode_to_vec());
301
302        // Create response channel
303        let (tx, rx) = oneshot::channel();
304        {
305            let mut pending = self.state.pending_jobs.lock().await;
306            pending.insert(job_id, tx);
307        }
308
309        // Send request
310        self.send_raw(&data).await?;
311
312        // Wait for response with timeout
313        let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx).await.map_err(|_| SessionError::Timeout)?.map_err(|_| SessionError::ProtocolError("Response channel closed".into()))?;
314
315        Ok(response)
316    }
317}
318
319impl WebSocketCMTransport {
320    /// Send a request and receive a response.
321    pub async fn send_request(&self, request: ApiRequest) -> Result<ApiResponse, SessionError> {
322        let method_name = format!("I{}Service.{}/v{}", request.api_interface, request.api_method, request.api_version);
323
324        let body = request.request_data.unwrap_or_default();
325
326        self.send_service_method(&method_name, &body, request.access_token.as_deref()).await
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Header with a non-empty protobuf part, so truncation can be tested.
    fn sample_header() -> MsgHdrProtoBuf {
        MsgHdrProtoBuf {
            msg: emsg::SERVICE_METHOD,
            proto: CMsgProtoBufHeader { client_sessionid: Some(12345), jobid_source: Some(1), ..Default::default() },
        }
    }

    #[test]
    fn test_msg_hdr_encode_decode() {
        let encoded = sample_header().encode();
        let (decoded, consumed) = MsgHdrProtoBuf::decode(&encoded).unwrap();

        // With no body appended, the header must consume every byte.
        assert_eq!(consumed, encoded.len());
        assert_eq!(decoded.msg, emsg::SERVICE_METHOD);
        assert_eq!(decoded.proto.client_sessionid, Some(12345));
        assert_eq!(decoded.proto.jobid_source, Some(1));
    }

    #[test]
    fn test_msg_hdr_sets_protobuf_flag() {
        let encoded = sample_header().encode();
        let raw = u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(raw, emsg::SERVICE_METHOD | 0x8000_0000);
    }

    #[test]
    fn test_msg_hdr_decode_rejects_short_input() {
        assert!(matches!(MsgHdrProtoBuf::decode(&[0u8; 7]), Err(SessionError::ProtocolError(_))));
    }

    #[test]
    fn test_msg_hdr_decode_rejects_truncated_header() {
        let encoded = sample_header().encode();
        let truncated = &encoded[..encoded.len() - 1];
        assert!(matches!(MsgHdrProtoBuf::decode(truncated), Err(SessionError::ProtocolError(_))));
    }

    #[test]
    fn test_msg_hdr_decode_rejects_oversized_length() {
        let mut data = Vec::new();
        data.extend_from_slice(&(emsg::MULTI | 0x8000_0000).to_le_bytes());
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(MsgHdrProtoBuf::decode(&data), Err(SessionError::ProtocolError(_))));
    }
}