tvdata_rs/transport/
websocket.rs1use std::sync::atomic::{AtomicU64, Ordering};
2
3use futures_util::SinkExt;
4use serde_json::{Value, json};
5use tokio::net::TcpStream;
6use tokio_tungstenite::tungstenite::client::IntoClientRequest;
7use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
8#[cfg(feature = "tracing")]
9use tracing::{debug, warn};
10
11use crate::client::Endpoints;
12use crate::error::{Error, Result};
13
/// Websocket stream type used for every TradingView connection in this crate:
/// a `tokio-tungstenite` `WebSocketStream` running over a `TcpStream` that may
/// or may not be wrapped in TLS (`MaybeTlsStream`).
pub type TradingViewWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
15
16pub(crate) async fn connect_socket(
17 endpoints: &Endpoints,
18 user_agent: &str,
19 session_id: Option<&str>,
20) -> Result<TradingViewWebSocket> {
21 #[cfg(feature = "tracing")]
22 debug!(
23 target: "tvdata_rs::transport",
24 url = endpoints.websocket_url().as_str(),
25 authenticated = session_id.is_some(),
26 "opening TradingView websocket",
27 );
28
29 let mut ws_request = endpoints
30 .websocket_url()
31 .as_str()
32 .into_client_request()
33 .map_err(Error::from)?;
34 ws_request.headers_mut().insert(
35 "Origin",
36 endpoints
37 .data_origin()
38 .as_str()
39 .parse()
40 .map_err(|_| Error::Protocol("failed to encode websocket origin header"))?,
41 );
42 ws_request.headers_mut().insert(
43 "User-Agent",
44 user_agent
45 .parse()
46 .map_err(|_| Error::Protocol("failed to encode websocket user agent header"))?,
47 );
48 if let Some(session_id) = session_id {
49 ws_request.headers_mut().insert(
50 "Cookie",
51 format!("sessionid={session_id}")
52 .parse()
53 .map_err(|_| Error::Protocol("failed to encode websocket cookie header"))?,
54 );
55 }
56
57 match connect_async(ws_request).await {
58 Ok((socket, _)) => {
59 #[cfg(feature = "tracing")]
60 debug!(
61 target: "tvdata_rs::transport",
62 url = endpoints.websocket_url().as_str(),
63 authenticated = session_id.is_some(),
64 "TradingView websocket connected",
65 );
66 Ok(socket)
67 }
68 Err(error) => {
69 #[cfg(feature = "tracing")]
70 warn!(
71 target: "tvdata_rs::transport",
72 url = endpoints.websocket_url().as_str(),
73 authenticated = session_id.is_some(),
74 error = %error,
75 "TradingView websocket connection failed",
76 );
77 Err(Error::from(error))
78 }
79 }
80}
81
/// Returns a process-unique identifier of the form `<prefix>_<16 lowercase hex digits>`.
///
/// The numeric suffix comes from a process-wide counter that starts at 1 and is bumped
/// on every call. Ids are therefore unique within the running process, until the `u64`
/// counter wraps (`fetch_add` wraps on overflow). They are sequential, not random, and
/// are not unique across processes.
pub(crate) fn next_session_id(prefix: &str) -> String {
    // Only uniqueness matters here, not ordering relative to other memory,
    // so a relaxed increment is sufficient.
    static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(1);

    let id = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("{prefix}_{id:016x}")
}
87
88pub(crate) async fn send_message(
89 socket: &mut TradingViewWebSocket,
90 method: &str,
91 params: Value,
92) -> Result<()> {
93 let payload = serde_json::to_string(&json!({ "m": method, "p": params }))?;
94 send_raw_frame(socket, payload).await
95}
96
97pub(crate) async fn send_raw_frame(
98 socket: &mut TradingViewWebSocket,
99 payload: String,
100) -> Result<()> {
101 let framed = format!("~m~{}~m~{payload}", payload.len());
102 socket.send(Message::Text(framed.into())).await?;
103 Ok(())
104}
105
106pub(crate) fn parse_framed_messages(frame: &str) -> Result<Vec<&str>> {
107 let mut rest = frame;
108 let mut payloads = Vec::new();
109
110 while !rest.is_empty() {
111 if let Some(next) = rest.strip_prefix("~m~") {
112 let Some((len, tail)) = next.split_once("~m~") else {
113 return Err(Error::Protocol("missing websocket frame length separator"));
114 };
115 let len: usize = len
116 .parse()
117 .map_err(|_| Error::Protocol("invalid websocket frame length"))?;
118 if tail.len() < len {
119 return Err(Error::Protocol(
120 "declared websocket frame length exceeds payload",
121 ));
122 }
123 let (payload, remainder) = tail.split_at(len);
124 payloads.push(payload);
125 rest = remainder;
126 continue;
127 }
128
129 if let Some((_, remainder)) = rest.split_once("~m~") {
130 rest = remainder;
131 continue;
132 }
133
134 return Err(Error::Protocol("unexpected websocket frame prefix"));
135 }
136
137 Ok(payloads)
138}
139
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};

    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio_tungstenite::tungstenite::handshake::derive_accept_key;

    use crate::client::Endpoints;

    use super::*;

    /// Reads from `stream` until the end of the HTTP header block (`\r\n\r\n`) and
    /// returns everything read as a UTF-8 string.
    async fn read_upgrade_request(stream: &mut TcpStream) -> String {
        let mut request = Vec::new();
        let mut buffer = [0_u8; 1024];

        loop {
            let read = stream.read(&mut buffer).await.unwrap();
            assert_ne!(read, 0, "client closed connection before websocket upgrade");
            request.extend_from_slice(&buffer[..read]);
            if request.windows(4).any(|window| window == b"\r\n\r\n") {
                break;
            }
        }

        String::from_utf8(request).unwrap()
    }

    /// Returns the trimmed value of the first header named `name` (case-insensitive).
    fn header_value<'a>(request: &'a str, name: &str) -> Option<&'a str> {
        request.lines().find_map(|line| {
            let (header, value) = line.split_once(':')?;
            header
                .trim()
                .eq_ignore_ascii_case(name)
                .then_some(value.trim())
        })
    }

    /// Accepts one connection on an ephemeral localhost port, records the raw upgrade
    /// request, and replies with a minimal `101 Switching Protocols` handshake.
    ///
    /// Returns the bound address, a slot holding the recorded request, and the server
    /// task handle. Tests await the handle so a panic inside the server fails the test
    /// instead of being silently dropped with a detached task.
    async fn spawn_handshake_server() -> (SocketAddr, Arc<Mutex<Option<String>>>, JoinHandle<()>)
    {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        let recorded = Arc::new(Mutex::new(None::<String>));
        let recorded_by_server = Arc::clone(&recorded);

        let server = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let request = read_upgrade_request(&mut stream).await;

            let key = header_value(&request, "sec-websocket-key")
                .expect("websocket upgrade request must contain Sec-WebSocket-Key");
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
                 Connection: Upgrade\r\n\
                 Upgrade: websocket\r\n\
                 Sec-WebSocket-Accept: {}\r\n\
                 \r\n",
                derive_accept_key(key.as_bytes())
            );

            // Record before answering so the request is visible once the client connects.
            *recorded_by_server.lock().unwrap() = Some(request);
            stream.write_all(response.as_bytes()).await.unwrap();
        });

        (address, recorded, server)
    }

    /// Default endpoints with the websocket URL pointed at the local test server.
    fn endpoints_for(address: SocketAddr) -> Endpoints {
        Endpoints::default()
            .with_websocket_url(format!("ws://{address}"))
            .unwrap()
    }

    #[test]
    fn parses_concatenated_websocket_frames() {
        let frames = parse_framed_messages("~m~9~m~{\"m\":\"a\"}~m~9~m~{\"m\":\"b\"}").unwrap();

        assert_eq!(frames, vec![r#"{"m":"a"}"#, r#"{"m":"b"}"#]);
    }

    #[test]
    fn parses_heartbeat_and_empty_input() {
        assert_eq!(parse_framed_messages("~m~4~m~~h~1").unwrap(), vec!["~h~1"]);
        assert!(parse_framed_messages("").unwrap().is_empty());
    }

    #[test]
    fn rejects_malformed_frames() {
        assert!(parse_framed_messages("~m~10~m~short").is_err());
        assert!(parse_framed_messages("~m~abc~m~payload").is_err());
        assert!(parse_framed_messages("~m~5").is_err());
        assert!(parse_framed_messages("no frame marker").is_err());
    }

    #[test]
    fn session_ids_are_unique_and_prefixed() {
        let first = next_session_id("qs");
        let second = next_session_id("qs");
        assert_ne!(first, second);

        for id in [&first, &second] {
            let suffix = id.strip_prefix("qs_").expect("prefix must be preserved");
            assert_eq!(suffix.len(), 16);
            assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[tokio::test]
    async fn connect_socket_includes_session_cookie_when_configured() {
        let (address, recorded, server) = spawn_handshake_server().await;
        let endpoints = endpoints_for(address);

        let _socket = connect_socket(&endpoints, "tvdata-rs/test", Some("abc123"))
            .await
            .unwrap();
        server.await.unwrap();

        let request = recorded
            .lock()
            .unwrap()
            .clone()
            .expect("server must record the upgrade request");
        assert_eq!(header_value(&request, "cookie"), Some("sessionid=abc123"));
        assert_eq!(header_value(&request, "user-agent"), Some("tvdata-rs/test"));
        assert!(header_value(&request, "origin").is_some());
    }

    #[tokio::test]
    async fn connect_socket_omits_cookie_without_session() {
        let (address, recorded, server) = spawn_handshake_server().await;
        let endpoints = endpoints_for(address);

        let _socket = connect_socket(&endpoints, "tvdata-rs/test", None)
            .await
            .unwrap();
        server.await.unwrap();

        let request = recorded
            .lock()
            .unwrap()
            .clone()
            .expect("server must record the upgrade request");
        assert_eq!(header_value(&request, "cookie"), None);
        assert_eq!(header_value(&request, "user-agent"), Some("tvdata-rs/test"));
    }
}