1use std::{
8 collections::HashMap,
9 io::Read,
10 sync::{
11 atomic::{AtomicI32, Ordering},
12 Arc,
13 },
14};
15
16use flate2::read::GzDecoder;
17use futures_util::{SinkExt, StreamExt};
18use prost::Message;
19use steam_cm_provider::{CmServerProvider, HttpCmServerProvider};
20use steam_protos::{CMsgClientHello, CMsgClientServiceMethodLegacy, CMsgClientServiceMethodLegacyResponse, CMsgMulti, CMsgProtoBufHeader};
21use tokio::sync::{oneshot, Mutex};
22use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
23
24use crate::{
25 error::SessionError,
26 transport::{ApiRequest, ApiResponse},
27};
28
/// Numeric Steam message identifiers (`EMsg`) this transport sends or reacts to.
mod emsg {
    /// Container message whose body holds a sequence of length-prefixed
    /// nested messages (optionally gzip-compressed).
    pub const MULTI: u32 = 1;

    /// Outgoing service-method call.
    pub const SERVICE_METHOD: u32 = 146;

    /// Incoming reply to a service-method call.
    pub const SERVICE_METHOD_RESPONSE: u32 = 147;

    /// Greeting sent right after the socket is opened.
    // NOTE(review): public EMsg listings (e.g. SteamKit) give ClientHello as
    // 9805 — confirm that 4006 is the intended value here.
    pub const CLIENT_HELLO: u32 = 4006;
}
36
37struct MsgHdrProtoBuf {
38 pub msg: u32,
39 pub proto: CMsgProtoBufHeader,
40}
41
42impl MsgHdrProtoBuf {
43 fn encode(&self) -> Vec<u8> {
44 let proto_bytes = self.proto.encode_to_vec();
45 let mut result = Vec::new();
46
47 result.extend_from_slice(&(self.msg | 0x80000000).to_le_bytes());
49 result.extend_from_slice(&(proto_bytes.len() as u32).to_le_bytes());
51 result.extend_from_slice(&proto_bytes);
53
54 result
55 }
56
57 fn decode(data: &[u8]) -> Result<(Self, usize), SessionError> {
58 if data.len() < 8 {
59 return Err(SessionError::ProtocolError("Header too short".into()));
60 }
61
62 let msg = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) & 0x7FFFFFFF;
63 let header_length = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
64
65 if data.len() < 8 + header_length {
66 return Err(SessionError::ProtocolError("Header incomplete".into()));
67 }
68
69 let proto = CMsgProtoBufHeader::decode(&data[8..8 + header_length])?;
70
71 Ok((Self { msg, proto }, 8 + header_length))
72 }
73}
74
75struct ConnectionState {
77 session_id: AtomicI32,
78 job_id_counter: AtomicI32,
79 pending_jobs: Mutex<HashMap<u64, oneshot::Sender<ApiResponse>>>,
80}
81
82#[allow(clippy::type_complexity)]
84pub struct WebSocketCMTransport {
85 ws_sender: Arc<Mutex<Option<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, WsMessage>>>>,
86 state: Arc<ConnectionState>,
87 connected: Arc<Mutex<bool>>,
88 cm_provider: Arc<dyn CmServerProvider>,
89}
90
91impl std::fmt::Debug for WebSocketCMTransport {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("WebSocketCMTransport").field("connected", &self.connected).finish()
94 }
95}
96
97impl Clone for WebSocketCMTransport {
98 fn clone(&self) -> Self {
99 Self {
100 ws_sender: Arc::clone(&self.ws_sender),
101 state: Arc::clone(&self.state),
102 connected: Arc::clone(&self.connected),
103 cm_provider: Arc::clone(&self.cm_provider),
104 }
105 }
106}
107
108impl WebSocketCMTransport {
109 pub async fn new() -> Result<Self, SessionError> {
111 Self::with_options(None).await
112 }
113
114 pub async fn with_options(cm_provider: Option<Arc<dyn CmServerProvider>>) -> Result<Self, SessionError> {
117 let cm_provider = cm_provider.unwrap_or_else(|| Arc::new(HttpCmServerProvider::new_default()));
118
119 let transport = Self {
120 ws_sender: Arc::new(Mutex::new(None)),
121 state: Arc::new(ConnectionState { session_id: AtomicI32::new(0), job_id_counter: AtomicI32::new(0), pending_jobs: Mutex::new(HashMap::new()) }),
122 connected: Arc::new(Mutex::new(false)),
123 cm_provider,
124 };
125
126 transport.connect().await?;
127
128 Ok(transport)
129 }
130
131 async fn connect(&self) -> Result<(), SessionError> {
133 let server = self.cm_provider.get_server().await.map_err(|e| SessionError::NetworkError(format!("Failed to get CM server: {}", e)))?;
135
136 let url = format!("wss://{}/cmsocket/", server.endpoint);
137
138 tracing::debug!("Connecting to CM server: {}", url);
139
140 let (ws_stream, _) = connect_async(&url).await?;
141
142 let (write, mut read) = ws_stream.split();
143
144 *self.ws_sender.lock().await = Some(write);
145
146 self.send_hello().await?;
148
149 let state = self.state.clone();
151 let connected = self.connected.clone();
152
153 tokio::spawn(async move {
154 while let Some(msg) = read.next().await {
155 match msg {
156 Ok(WsMessage::Binary(data)) => {
157 if let Err(e) = Self::handle_message(&state, &data, 0).await {
158 tracing::error!("Error handling message: {}", e);
159 }
160 }
161 Ok(WsMessage::Close(_)) => {
162 *connected.lock().await = false;
163 break;
164 }
165 Err(e) => {
166 tracing::error!("WebSocket error: {}", e);
167 *connected.lock().await = false;
168 break;
169 }
170 _ => {}
171 }
172 }
173 });
174
175 *self.connected.lock().await = true;
176
177 Ok(())
178 }
179
180 async fn send_hello(&self) -> Result<(), SessionError> {
182 let header = MsgHdrProtoBuf { msg: emsg::CLIENT_HELLO, proto: CMsgProtoBufHeader { client_sessionid: Some(0), ..Default::default() } };
183
184 let body = CMsgClientHello { protocol_version: Some(65580) };
185
186 let mut data = header.encode();
187 data.extend_from_slice(&body.encode_to_vec());
188
189 self.send_raw(&data).await
190 }
191
192 async fn send_raw(&self, data: &[u8]) -> Result<(), SessionError> {
194 let mut sender = self.ws_sender.lock().await;
195 if let Some(ref mut ws) = *sender {
196 ws.send(WsMessage::Binary(data.to_vec())).await?;
197 } else {
198 return Err(SessionError::ProtocolError("Not connected".into()));
199 }
200 Ok(())
201 }
202
203 async fn handle_message(state: &ConnectionState, data: &[u8], depth: usize) -> Result<(), SessionError> {
205 if depth > 5 {
206 return Err(SessionError::ProtocolError("Message recursion depth exceeded".into()));
207 }
208
209 let (header, body_offset) = MsgHdrProtoBuf::decode(data)?;
210
211 match header.msg {
212 emsg::MULTI => {
213 let body = &data[body_offset..];
215 let multi = CMsgMulti::decode(body)?;
216
217 if let Some(message_body) = multi.message_body {
218 let decompressed = if multi.size_unzipped.is_some() {
219 let mut decoder = GzDecoder::new(message_body.as_slice());
221 let mut result = Vec::new();
222 decoder.read_to_end(&mut result).map_err(|e| SessionError::ProtocolError(format!("Gzip decompression failed: {}", e)))?;
223 result
224 } else {
225 message_body
226 };
227
228 let mut offset = 0;
230 while offset < decompressed.len() {
231 if offset + 4 > decompressed.len() {
232 break;
233 }
234 let size = u32::from_le_bytes([decompressed[offset], decompressed[offset + 1], decompressed[offset + 2], decompressed[offset + 3]]) as usize;
235 offset += 4;
236
237 if offset + size > decompressed.len() {
238 break;
239 }
240
241 let nested = &decompressed[offset..offset + size];
242 Box::pin(Self::handle_message(state, nested, depth + 1)).await?;
243 offset += size;
244 }
245 }
246 }
247 emsg::SERVICE_METHOD_RESPONSE => {
248 let body = &data[body_offset..];
249 let response = CMsgClientServiceMethodLegacyResponse::decode(body)?;
250
251 if let Some(job_id) = header.proto.jobid_target {
252 let mut pending = state.pending_jobs.lock().await;
253 if let Some(sender) = pending.remove(&job_id) {
254 let api_response = ApiResponse {
255 result: header.proto.eresult,
256 error_message: header.proto.error_message,
257 response_data: response.serialized_method_response,
258 };
259 let _ = sender.send(api_response);
260 }
261 }
262 }
263 _ => {
264 tracing::trace!("Unhandled message type: {}", header.msg);
265 }
266 }
267
268 Ok(())
269 }
270
271 async fn send_service_method(&self, method_name: &str, body: &[u8], access_token: Option<&str>) -> Result<ApiResponse, SessionError> {
273 let job_id = self.state.job_id_counter.fetch_add(1, Ordering::SeqCst) as u64 + 1;
274 let session_id = self.state.session_id.load(Ordering::SeqCst);
275
276 let header_proto = CMsgProtoBufHeader {
278 client_sessionid: Some(session_id),
279 jobid_source: Some(job_id),
280 target_job_name: Some(method_name.to_string()),
281 realm: Some(1),
282 ..Default::default()
283 };
284
285 if let Some(_token) = access_token {
287 }
290
291 let header = MsgHdrProtoBuf { msg: emsg::SERVICE_METHOD, proto: header_proto };
292
293 let service_method = CMsgClientServiceMethodLegacy {
294 method_name: Some(method_name.to_string()),
295 serialized_method: Some(body.to_vec()),
296 is_notification: Some(false),
297 };
298
299 let mut data = header.encode();
300 data.extend_from_slice(&service_method.encode_to_vec());
301
302 let (tx, rx) = oneshot::channel();
304 {
305 let mut pending = self.state.pending_jobs.lock().await;
306 pending.insert(job_id, tx);
307 }
308
309 self.send_raw(&data).await?;
311
312 let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx).await.map_err(|_| SessionError::Timeout)?.map_err(|_| SessionError::ProtocolError("Response channel closed".into()))?;
314
315 Ok(response)
316 }
317}
318
319impl WebSocketCMTransport {
320 pub async fn send_request(&self, request: ApiRequest) -> Result<ApiResponse, SessionError> {
322 let method_name = format!("I{}Service.{}/v{}", request.api_interface, request.api_method, request.api_version);
323
324 let body = request.request_data.unwrap_or_default();
325
326 self.send_service_method(&method_name, &body, request.access_token.as_deref()).await
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a header and decoding the bytes again yields the same message
    /// id and protobuf header fields.
    #[test]
    fn test_msg_hdr_encode_decode() {
        let proto = CMsgProtoBufHeader { client_sessionid: Some(12345), jobid_source: Some(1), ..Default::default() };
        let bytes = MsgHdrProtoBuf { msg: emsg::SERVICE_METHOD, proto }.encode();

        let (round_tripped, _) = MsgHdrProtoBuf::decode(&bytes).unwrap();
        let MsgHdrProtoBuf { msg, proto } = round_tripped;

        assert_eq!(msg, emsg::SERVICE_METHOD);
        assert_eq!(proto.client_sessionid, Some(12345));
        assert_eq!(proto.jobid_source, Some(1));
    }
}