Skip to main content

tvdata_rs/transport/
websocket.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2
3use futures_util::SinkExt;
4use serde_json::{Value, json};
5use tokio::net::TcpStream;
6use tokio_tungstenite::tungstenite::client::IntoClientRequest;
7use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
8#[cfg(feature = "tracing")]
9use tracing::{debug, warn};
10
11use crate::client::Endpoints;
12use crate::error::{Error, Result};
13
14pub type TradingViewWebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
15
16pub(crate) async fn connect_socket(
17    endpoints: &Endpoints,
18    user_agent: &str,
19    session_id: Option<&str>,
20) -> Result<TradingViewWebSocket> {
21    #[cfg(feature = "tracing")]
22    debug!(
23        target: "tvdata_rs::transport",
24        url = endpoints.websocket_url().as_str(),
25        authenticated = session_id.is_some(),
26        "opening TradingView websocket",
27    );
28
29    let mut ws_request = endpoints
30        .websocket_url()
31        .as_str()
32        .into_client_request()
33        .map_err(Error::from)?;
34    ws_request.headers_mut().insert(
35        "Origin",
36        endpoints
37            .data_origin()
38            .as_str()
39            .parse()
40            .map_err(|_| Error::Protocol("failed to encode websocket origin header"))?,
41    );
42    ws_request.headers_mut().insert(
43        "User-Agent",
44        user_agent
45            .parse()
46            .map_err(|_| Error::Protocol("failed to encode websocket user agent header"))?,
47    );
48    if let Some(session_id) = session_id {
49        ws_request.headers_mut().insert(
50            "Cookie",
51            format!("sessionid={session_id}")
52                .parse()
53                .map_err(|_| Error::Protocol("failed to encode websocket cookie header"))?,
54        );
55    }
56
57    match connect_async(ws_request).await {
58        Ok((socket, _)) => {
59            #[cfg(feature = "tracing")]
60            debug!(
61                target: "tvdata_rs::transport",
62                url = endpoints.websocket_url().as_str(),
63                authenticated = session_id.is_some(),
64                "TradingView websocket connected",
65            );
66            Ok(socket)
67        }
68        Err(error) => {
69            #[cfg(feature = "tracing")]
70            warn!(
71                target: "tvdata_rs::transport",
72                url = endpoints.websocket_url().as_str(),
73                authenticated = session_id.is_some(),
74                error = %error,
75                "TradingView websocket connection failed",
76            );
77            Err(Error::from(error))
78        }
79    }
80}
81
/// Returns a fresh identifier of the form `<prefix>_<16 lowercase hex digits>`.
///
/// The numeric part comes from a counter shared by every caller in the
/// process. It starts at 1 and is incremented atomically on each call, so
/// concurrent callers never receive the same value (barring `u64`
/// wrap-around). Ids are sequential, not random, and restart with each process.
pub(crate) fn next_session_id(prefix: &str) -> String {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);

    // Atomicity alone guarantees distinct values. No other memory is
    // synchronized through this counter, so `Relaxed` ordering is enough.
    let sequence = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    format!("{prefix}_{sequence:016x}")
}
87
88pub(crate) async fn send_message(
89    socket: &mut TradingViewWebSocket,
90    method: &str,
91    params: Value,
92) -> Result<()> {
93    let payload = serde_json::to_string(&json!({ "m": method, "p": params }))?;
94    send_raw_frame(socket, payload).await
95}
96
97pub(crate) async fn send_raw_frame(
98    socket: &mut TradingViewWebSocket,
99    payload: String,
100) -> Result<()> {
101    let framed = format!("~m~{}~m~{payload}", payload.len());
102    socket.send(Message::Text(framed.into())).await?;
103    Ok(())
104}
105
106pub(crate) fn parse_framed_messages(frame: &str) -> Result<Vec<&str>> {
107    let mut rest = frame;
108    let mut payloads = Vec::new();
109
110    while !rest.is_empty() {
111        if let Some(next) = rest.strip_prefix("~m~") {
112            let Some((len, tail)) = next.split_once("~m~") else {
113                return Err(Error::Protocol("missing websocket frame length separator"));
114            };
115            let len: usize = len
116                .parse()
117                .map_err(|_| Error::Protocol("invalid websocket frame length"))?;
118            if tail.len() < len {
119                return Err(Error::Protocol(
120                    "declared websocket frame length exceeds payload",
121                ));
122            }
123            let (payload, remainder) = tail.split_at(len);
124            payloads.push(payload);
125            rest = remainder;
126            continue;
127        }
128
129        if let Some((_, remainder)) = rest.split_once("~m~") {
130            rest = remainder;
131            continue;
132        }
133
134        return Err(Error::Protocol("unexpected websocket frame prefix"));
135    }
136
137    Ok(payloads)
138}
139
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio_tungstenite::tungstenite::handshake::derive_accept_key;

    use crate::client::Endpoints;

    use super::*;

    /// Reads from `stream` until the end of the HTTP request head (`\r\n\r\n`)
    /// has arrived and returns everything read as a UTF-8 string.
    ///
    /// Panics if the peer closes the connection first or the bytes are not UTF-8.
    async fn read_upgrade_request(stream: &mut TcpStream) -> String {
        let mut received: Vec<u8> = Vec::new();
        let mut chunk = [0_u8; 1024];

        while !received.windows(4).any(|window| window == b"\r\n\r\n") {
            let count = stream.read(&mut chunk).await.unwrap();
            assert_ne!(count, 0, "client closed connection before websocket upgrade");
            received.extend_from_slice(&chunk[..count]);
        }

        String::from_utf8(received).unwrap()
    }

    /// Returns the trimmed value of the first header whose name matches `name`
    /// (ASCII case-insensitive) in a raw HTTP request, if there is one.
    fn header_value<'a>(request: &'a str, name: &str) -> Option<&'a str> {
        for line in request.lines() {
            let Some((key, value)) = line.split_once(':') else {
                continue;
            };
            if key.trim().eq_ignore_ascii_case(name) {
                return Some(value.trim());
            }
        }
        None
    }

    #[test]
    fn parses_concatenated_websocket_frames() {
        let frames = parse_framed_messages("~m~9~m~{\"m\":\"a\"}~m~9~m~{\"m\":\"b\"}").unwrap();

        assert_eq!(frames, vec![r#"{"m":"a"}"#, r#"{"m":"b"}"#]);
    }

    #[tokio::test]
    async fn connect_socket_includes_session_cookie_when_configured() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        // One-shot mock server. It captures the Cookie header of the upgrade
        // request, answers with a valid 101 handshake, and returns the captured
        // cookie as the task's output.
        let server = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let request = read_upgrade_request(&mut stream).await;
            let cookie = header_value(&request, "cookie").map(str::to_owned);

            let key = header_value(&request, "sec-websocket-key")
                .expect("websocket upgrade request must contain Sec-WebSocket-Key");
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
                 Connection: Upgrade\r\n\
                 Upgrade: websocket\r\n\
                 Sec-WebSocket-Accept: {}\r\n\
                 \r\n",
                derive_accept_key(key.as_bytes())
            );
            stream.write_all(response.as_bytes()).await.unwrap();

            cookie
        });

        let endpoints = Endpoints::default()
            .with_websocket_url(format!("ws://{address}"))
            .unwrap();

        let _socket = connect_socket(&endpoints, "tvdata-rs/test", Some("abc123"))
            .await
            .unwrap();

        let cookie = server.await.expect("mock websocket server task panicked");
        assert_eq!(cookie.as_deref(), Some("sessionid=abc123"));
    }
}