Skip to main content

schwab_sdk/streamer/
connection.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use fastwebsockets::{FragmentCollectorRead, WebSocketWrite};
5use http::{
6    Method,
7    header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE},
8};
9use http_body_util::Empty;
10use hyper::{Request, Uri, body::Bytes};
11use hyper_util::rt::TokioIo;
12use rustls_platform_verifier::ConfigVerifierExt;
13use tokio::net::TcpStream;
14use tokio::sync::{Mutex, watch};
15use tokio_rustls::{TlsConnector, client::TlsStream, rustls};
16
17use crate::error::{Error, Result};
18use crate::secrets::CustomerId;
19use crate::streamer::events::{ConnectionEvent, DisconnectReason};
20use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
21use crate::streamer::request::{RequestPayload, StreamerRequest};
22use crate::streamer::response::{RawStreamerResponse, StreamerResponse};
23use crate::streamer::subscription::SubscribeRequest;
24use crate::streamer::{account_activity, admin, book, chart, level_one, screener};
25use crate::token::TokenProvider;
26use crate::user_preferences::StreamerInfo;
27
28type Upgraded = TokioIo<hyper::upgrade::Upgraded>;
29type WsReadHalf = FragmentCollectorRead<tokio::io::ReadHalf<Upgraded>>;
30type WsWriteHalf = WebSocketWrite<tokio::io::WriteHalf<Upgraded>>;
31type WebSocket = fastwebsockets::WebSocket<Upgraded>;
32
33/// Errors that surface from the streamer transport (TCP / TLS / WebSocket
34/// handshake plus any frame-level error after the socket is up).
35#[derive(Debug, thiserror::Error)]
36pub enum WebSocketError {
37    /// TCP connect failed.
38    #[error("failed to connect to server: {0}")]
39    Connect(#[source] std::io::Error),
40    /// WebSocket upgrade handshake failed.
41    #[error("failed to perform websocket handshake: {0}")]
42    Handshake(#[source] fastwebsockets::WebSocketError),
43    /// `streamerSocketUrl` host is not a valid DNS name.
44    #[error("invalid domain: {0}")]
45    InvalidDomain(#[source] rustls_pki_types::InvalidDnsNameError),
46    /// `streamerSocketUrl` did not include a host component.
47    #[error("host is required")]
48    MissingHost,
49    /// TLS handshake failed on top of the TCP socket.
50    #[error("failed to create TLS stream: {0}")]
51    TlsStream(#[source] std::io::Error),
52    /// Building the rustls client config failed.
53    #[error("failed to configure TLS: {0}")]
54    TlsConfig(#[source] rustls::Error),
55    /// Building the HTTP upgrade request failed.
56    #[error("failed to build upgrade request: {0}")]
57    BuildRequest(#[source] http::Error),
58    /// `streamerSocketUrl` used a scheme that is not permitted for the
59    /// current build. `wss://` is always accepted; `ws://` is accepted
60    /// only in debug builds, because a plaintext WebSocket would carry
61    /// the bearer token in the LOGIN frame in the clear. Any other
62    /// scheme (or a URL with no scheme at all) is always rejected.
63    #[error("unsupported websocket scheme: {0}")]
64    UnsupportedScheme(String),
65    /// Runtime frame error after the websocket is up: read/write/control
66    /// frame failures from `fastwebsockets`.
67    #[error("websocket runtime error: {0}")]
68    Runtime(#[from] fastwebsockets::WebSocketError),
69}
70
71impl WebSocketError {
72    /// Whether a fresh `connect` (and re-login) is worth attempting after
73    /// this error. Returns `false` for configuration-shaped failures that
74    /// will fail identically on retry (bad scheme, missing host, malformed
75    /// upgrade request, rustls config error) and `true` for transport- or
76    /// session-level failures (TCP connect, TLS handshake, WebSocket
77    /// handshake, post-handshake frame errors).
78    ///
79    /// Used by [`crate::Error::is_retryable`] to classify
80    /// [`crate::Error::WebSocket`].
81    pub fn is_retryable(&self) -> bool {
82        match self {
83            WebSocketError::Connect(_)
84            | WebSocketError::TlsStream(_)
85            | WebSocketError::Handshake(_)
86            | WebSocketError::Runtime(_) => true,
87            WebSocketError::InvalidDomain(_)
88            | WebSocketError::MissingHost
89            | WebSocketError::TlsConfig(_)
90            | WebSocketError::BuildRequest(_)
91            | WebSocketError::UnsupportedScheme(_) => false,
92        }
93    }
94}
95
96impl From<fastwebsockets::WebSocketError> for Error {
97    fn from(value: fastwebsockets::WebSocketError) -> Self {
98        Error::WebSocket(WebSocketError::Runtime(value))
99    }
100}
101
102struct SpawnExecutor;
103
104impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
105where
106    Fut: Future + Send + 'static,
107    Fut::Output: Send + 'static,
108{
109    fn execute(&self, fut: Fut) {
110        tokio::task::spawn(fut);
111    }
112}
113
114async fn connect_tls(uri: &Uri) -> std::result::Result<TlsStream<TcpStream>, WebSocketError> {
115    let host = uri.host().ok_or(WebSocketError::MissingHost)?;
116    let port = uri.port_u16().unwrap_or(443);
117    let addr = format!("{}:{}", host, port);
118
119    let socket = TcpStream::connect(addr)
120        .await
121        .map_err(WebSocketError::Connect)?;
122
123    let domain = rustls_pki_types::ServerName::try_from(host.to_string())
124        .map_err(WebSocketError::InvalidDomain)?;
125    let config =
126        rustls::ClientConfig::with_platform_verifier().map_err(WebSocketError::TlsConfig)?;
127    let connector = TlsConnector::from(Arc::new(config));
128    connector
129        .connect(domain, socket)
130        .await
131        .map_err(WebSocketError::TlsStream)
132}
133
134async fn connect_tcp(uri: &Uri) -> std::result::Result<TcpStream, WebSocketError> {
135    let host = uri.host().ok_or(WebSocketError::MissingHost)?;
136    let port = uri.port_u16().unwrap_or(80);
137    TcpStream::connect(format!("{}:{}", host, port))
138        .await
139        .map_err(WebSocketError::Connect)
140}
141
142/// Which transport to use for a given streamer URL scheme.
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
144enum WsTransport {
145    /// TLS handshake on top of TCP (`wss://`).
146    Tls,
147    /// Plain TCP (`ws://`). Reachable only in debug builds; the
148    /// streamer LOGIN frame would otherwise put the bearer on the wire
149    /// in cleartext.
150    Plain,
151}
152
153/// Map a URI scheme to the [`WsTransport`] to use. `allow_insecure`
154/// gates `ws://`; release builds set it to `false`.
155///
156/// Extracted from [`connect_websocket`] so both modes are unit-testable
157/// from a single test binary without rebuilding in release mode.
158fn check_websocket_scheme(
159    scheme: Option<&str>,
160    allow_insecure: bool,
161) -> std::result::Result<WsTransport, WebSocketError> {
162    match scheme {
163        Some("wss") => Ok(WsTransport::Tls),
164        Some("ws") if allow_insecure => Ok(WsTransport::Plain),
165        Some("ws") => Err(WebSocketError::UnsupportedScheme("ws".to_string())),
166        Some(other) => Err(WebSocketError::UnsupportedScheme(other.to_string())),
167        None => Err(WebSocketError::UnsupportedScheme(String::new())),
168    }
169}
170
171async fn connect_websocket(uri: &Uri) -> std::result::Result<WebSocket, WebSocketError> {
172    let transport = check_websocket_scheme(uri.scheme_str(), cfg!(debug_assertions))?;
173
174    let req = Request::builder()
175        .method(Method::GET)
176        .uri(uri)
177        .header(HOST, uri.host().ok_or(WebSocketError::MissingHost)?)
178        .header(UPGRADE, "websocket")
179        .header(CONNECTION, "upgrade")
180        .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
181        .header(SEC_WEBSOCKET_VERSION, "13")
182        .body(Empty::<Bytes>::new())
183        .map_err(WebSocketError::BuildRequest)?;
184
185    match transport {
186        WsTransport::Tls => {
187            let stream = connect_tls(uri).await?;
188            let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
189                .await
190                .map_err(WebSocketError::Handshake)?;
191            Ok(ws)
192        }
193        WsTransport::Plain => {
194            let stream = connect_tcp(uri).await?;
195            let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, stream)
196                .await
197                .map_err(WebSocketError::Handshake)?;
198            Ok(ws)
199        }
200    }
201}
202
203/// Open the streamer websocket using the connection details from
204/// `/userPreference` and return the read and write halves of the session.
205/// Call [`WriteHalf::login`] before any other command.
206///
207/// `token_provider` is the [`TokenProvider`] used to fetch the bearer for
208/// the LOGIN frame. It is consulted at LOGIN-frame construction so a token
209/// rotated in the provider after `connect` returns is the one carried on
210/// the wire when `login` is called.
211///
212/// Every field on `streamer_info` is `Option` per the spec; this function
213/// validates that the fields needed to log in and route subscribe frames
214/// (socket URL, customer id, correlation id, channel, function id) are
215/// all present, returning [`Error::InvalidPreference`] for the first
216/// missing one.
217pub async fn connect(
218    streamer_info: StreamerInfo,
219    token_provider: Arc<dyn TokenProvider + Send + Sync>,
220) -> Result<(ReadHalf, WriteHalf)> {
221    let validated = ValidatedStreamerInfo::try_from(streamer_info)?;
222    let websocket = connect_websocket(&validated.socket_url).await?;
223    Ok(split(websocket, validated, token_provider))
224}
225
226/// `StreamerInfo` after the per-field optionality has been resolved.
227/// Constructing one of these is the only way to reach [`split`].
228#[derive(Debug)]
229struct ValidatedStreamerInfo {
230    socket_url: Uri,
231    customer_id: CustomerId,
232    correlation_id: String,
233    channel: String,
234    function_id: String,
235}
236
237impl TryFrom<StreamerInfo> for ValidatedStreamerInfo {
238    type Error = Error;
239
240    fn try_from(info: StreamerInfo) -> Result<Self> {
241        fn required<T>(field: &'static str, value: Option<T>) -> Result<T> {
242            value.ok_or(Error::InvalidPreference {
243                field,
244                reason: "missing".to_string(),
245            })
246        }
247
248        let socket_url = required("streamerSocketUrl", info.streamer_socket_url)?
249            .parse::<Uri>()
250            .map_err(|e| Error::InvalidPreference {
251                field: "streamerSocketUrl",
252                reason: e.to_string(),
253            })?;
254
255        Ok(Self {
256            socket_url,
257            customer_id: required("schwabClientCustomerId", info.schwab_client_customer_id)?,
258            correlation_id: required("schwabClientCorrelId", info.schwab_client_correlation_id)?,
259            channel: required("schwabClientChannel", info.schwab_client_channel)?,
260            function_id: required("schwabClientFunctionId", info.schwab_client_function_id)?,
261        })
262    }
263}
264
265/// Split a connected [`WebSocket`] into the [`ReadHalf`] and [`WriteHalf`]
266/// the streamer surface exposes.
267///
268/// The websocket's write half is owned by an `Arc<Mutex<_>>` shared by both
269/// halves: the writer locks it for `login`/`logout`/`send`, the reader locks
270/// it inside `read_frame`'s control-frame callback to reply to pings and
271/// close frames. No background task is spawned; all I/O happens inline on
272/// the caller's own stack inside `recv()` / `send()`.
273fn split(
274    websocket: WebSocket,
275    streamer_info: ValidatedStreamerInfo,
276    token_provider: Arc<dyn TokenProvider + Send + Sync>,
277) -> (ReadHalf, WriteHalf) {
278    let (read_half, write_half) = websocket.split(tokio::io::split);
279    let write_half = Arc::new(Mutex::new(write_half));
280    let (events_tx, _) = watch::channel(ConnectionEvent::Connected);
281
282    let reader = ReadHalf {
283        read_half: FragmentCollectorRead::new(read_half),
284        write_half: write_half.clone(),
285        events_tx,
286    };
287
288    let writer = WriteHalf {
289        write_half,
290        customer_id: streamer_info.customer_id,
291        correlation_id: streamer_info.correlation_id,
292        channel: streamer_info.channel,
293        function_id: streamer_info.function_id,
294        request_id: Arc::new(AtomicU64::new(0)),
295        token_provider,
296    };
297
298    (reader, writer)
299}
300
301/// Lock the shared write half and write a single frame. Used both by the
302/// reader (to reply to ping/close control frames) and the writer (to send
303/// requests). Lifting this out of the closure that `read_frame` consumes
304/// makes the future's lifetime relation to `frame` explicit, which the
305/// closure form (with an `async move` block) cannot express on stable Rust.
306async fn write_one(
307    write_half: Arc<Mutex<WsWriteHalf>>,
308    frame: fastwebsockets::Frame<'_>,
309) -> std::result::Result<(), fastwebsockets::WebSocketError> {
310    write_half.lock().await.write_frame(frame).await
311}
312
313/// Read half of the streamer session. Yields one
314/// [`StreamerResponse`] per [`Self::recv`] call. Cloneable through
315/// [`Self::events`] for connection-state observation only; the read half
316/// itself is single-consumer.
317pub struct ReadHalf {
318    read_half: WsReadHalf,
319    write_half: Arc<Mutex<WsWriteHalf>>,
320    events_tx: watch::Sender<ConnectionEvent>,
321}
322
323impl std::fmt::Debug for ReadHalf {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        f.debug_struct("ReadHalf").finish_non_exhaustive()
326    }
327}
328
329impl ReadHalf {
330    /// Receive the next streamer frame.
331    ///
332    /// Blocks until a text frame arrives, then parses it into a
333    /// [`StreamerResponse`]. Control frames (ping/pong/close) are handled
334    /// inline, so this method only returns on real protocol traffic.
335    ///
336    /// Errors:
337    /// - [`Error::WebSocket`](crate::Error::WebSocket) on transport
338    ///   failure (the [`ConnectionEvent::Disconnected`] event also fires
339    ///   on the watch channel returned by [`Self::events`]).
340    /// - [`Error::Codec`](crate::Error::Codec) on a malformed frame.
341    pub async fn recv(&mut self) -> Result<StreamerResponse> {
342        let write_half = self.write_half.clone();
343        let mut send_fn = move |frame| write_one(write_half.clone(), frame);
344        loop {
345            let frame = match self.read_half.read_frame(&mut send_fn).await {
346                Ok(f) => f,
347                Err(e) => {
348                    self.events_tx.send_replace(ConnectionEvent::Disconnected(
349                        DisconnectReason::Transport(e.to_string()),
350                    ));
351                    return Err(e.into());
352                }
353            };
354            if frame.opcode == fastwebsockets::OpCode::Text {
355                let raw_response: RawStreamerResponse = match serde_json::from_slice(&frame.payload)
356                {
357                    Ok(r) => r,
358                    Err(e) => {
359                        self.events_tx.send_replace(ConnectionEvent::StreamError {
360                            message: e.to_string(),
361                        });
362                        return Err(Error::Codec {
363                            context: "streamer response frame".to_string(),
364                            reason: e.to_string(),
365                        });
366                    }
367                };
368                let response = StreamerResponse::try_from(raw_response)?;
369                classify_and_emit(&self.events_tx, &response);
370                return Ok(response);
371            }
372        }
373    }
374
375    /// Subscribe to connection-state updates for this session. Receivers
376    /// initially observe the current state (typically `Connected` or, after
377    /// the first login response, `LoggedIn`).
378    ///
379    /// # Examples
380    ///
381    /// Drive a reconnect decision off the state stream. The reconnect loop
382    /// itself lives in consumer code; this side only surfaces the signal.
383    ///
384    /// ```no_run
385    /// use schwab_sdk::streamer::{ConnectionEvent, ReadHalf};
386    ///
387    /// # async fn run(read: &ReadHalf) {
388    /// let mut events = read.events();
389    /// while events.changed().await.is_ok() {
390    ///     match &*events.borrow_and_update() {
391    ///         ConnectionEvent::LoggedIn => println!("session ready"),
392    ///         ConnectionEvent::Disconnected(reason) => {
393    ///             println!("disconnected: {reason:?}");
394    ///             break;
395    ///         }
396    ///         other => println!("state: {other:?}"),
397    ///     }
398    /// }
399    /// # }
400    /// ```
401    pub fn events(&self) -> watch::Receiver<ConnectionEvent> {
402        self.events_tx.subscribe()
403    }
404}
405
406/// Classify a parsed `StreamerResponse` and emit any state changes through
407/// `events_tx`. Errors are not emitted here; the caller handles them.
408fn classify_and_emit(events_tx: &watch::Sender<ConnectionEvent>, response: &StreamerResponse) {
409    let StreamerResponse::Response(responses) = response else {
410        return;
411    };
412    for r in responses {
413        let is_login = r.service == Service::Admin && r.command == StreamerCommand::Login;
414        match r.content.code {
415            ResponseCode::Ok if is_login => {
416                events_tx.send_replace(ConnectionEvent::LoggedIn);
417            }
418            ResponseCode::LoginDenied => {
419                events_tx.send_replace(ConnectionEvent::Disconnected(
420                    DisconnectReason::LoginDenied(r.content.message.clone()),
421                ));
422            }
423            ResponseCode::CloseConnection => {
424                events_tx.send_replace(ConnectionEvent::Disconnected(
425                    DisconnectReason::ServerClose(r.content.message.clone()),
426                ));
427            }
428            ResponseCode::StopStreaming => {
429                events_tx.send_replace(ConnectionEvent::Disconnected(
430                    DisconnectReason::StopStreaming(r.content.message.clone()),
431                ));
432            }
433            _ => {}
434        }
435    }
436}
437
438/// Write half of the streamer session. Sends login/logout/subscribe
439/// frames. Cloneable: all clones share the same underlying socket,
440/// monotonic request-id counter, and [`TokenProvider`], so they can be
441/// moved into independent tasks safely.
442#[derive(Clone)]
443pub struct WriteHalf {
444    write_half: Arc<Mutex<WsWriteHalf>>,
445    customer_id: CustomerId,
446    correlation_id: String,
447    channel: String,
448    function_id: String,
449    request_id: Arc<AtomicU64>,
450    token_provider: Arc<dyn TokenProvider + Send + Sync>,
451}
452
453impl std::fmt::Debug for WriteHalf {
454    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        f.debug_struct("WriteHalf")
456            .field("channel", &self.channel)
457            .field("function_id", &self.function_id)
458            .finish_non_exhaustive()
459    }
460}
461
462impl WriteHalf {
463    /// Send the streamer LOGIN frame establishing the session. Must be
464    /// called before any subscribe/add/unsubscribe/view request.
465    /// Returns when the frame has been handed to the socket; the LOGIN
466    /// ack arrives later on the read half as a `response` frame.
467    ///
468    /// The bearer carried by the frame is fetched from the
469    /// [`TokenProvider`] supplied to [`connect`] at the moment `login`
470    /// is called - calling `login` again after the provider observes a
471    /// rotated token will re-LOGIN with the new value.
472    /// [`Error::TokenProvider`] surfaces if the provider fails before
473    /// any frame is written.
474    pub async fn login(&self) -> Result<()> {
475        let auth_token = self.token_provider.access_token().await?;
476        let request = admin::Login {
477            authorization: auth_token,
478            schwab_client_channel: self.channel.clone(),
479            schwab_client_function_id: self.function_id.clone(),
480        };
481        self.send(request).await
482    }
483
484    /// Send the streamer LOGOUT frame.
485    pub async fn logout(&self) -> Result<()> {
486        self.send(admin::Logout).await
487    }
488
489    /// LEVELONE_EQUITIES subscription entry point.
490    pub fn equities(&self) -> SubscribeRequest<'_, level_one::equities::Field> {
491        SubscribeRequest::new(self)
492    }
493
494    /// LEVELONE_OPTIONS subscription entry point.
495    pub fn options(&self) -> SubscribeRequest<'_, level_one::options::Field> {
496        SubscribeRequest::new(self)
497    }
498
499    /// LEVELONE_FUTURES subscription entry point.
500    pub fn futures(&self) -> SubscribeRequest<'_, level_one::futures::Field> {
501        SubscribeRequest::new(self)
502    }
503
504    /// LEVELONE_FUTURES_OPTIONS subscription entry point.
505    pub fn futures_options(&self) -> SubscribeRequest<'_, level_one::futures_options::Field> {
506        SubscribeRequest::new(self)
507    }
508
509    /// LEVELONE_FOREX subscription entry point.
510    pub fn forex(&self) -> SubscribeRequest<'_, level_one::forex::Field> {
511        SubscribeRequest::new(self)
512    }
513
514    /// NYSE_BOOK subscription entry point.
515    pub fn nyse_book(&self) -> SubscribeRequest<'_, book::nyse::Field> {
516        SubscribeRequest::new(self)
517    }
518
519    /// NASDAQ_BOOK subscription entry point.
520    pub fn nasdaq_book(&self) -> SubscribeRequest<'_, book::nasdaq::Field> {
521        SubscribeRequest::new(self)
522    }
523
524    /// OPTIONS_BOOK subscription entry point.
525    pub fn options_book(&self) -> SubscribeRequest<'_, book::options::Field> {
526        SubscribeRequest::new(self)
527    }
528
529    /// CHART_EQUITY subscription entry point.
530    pub fn chart_equity(&self) -> SubscribeRequest<'_, chart::equity::Field> {
531        SubscribeRequest::new(self)
532    }
533
534    /// CHART_FUTURES subscription entry point.
535    pub fn chart_futures(&self) -> SubscribeRequest<'_, chart::futures::Field> {
536        SubscribeRequest::new(self)
537    }
538
539    /// SCREENER_EQUITY subscription entry point.
540    pub fn screener_equity(&self) -> SubscribeRequest<'_, screener::equity::Field> {
541        SubscribeRequest::new(self)
542    }
543
544    /// SCREENER_OPTION subscription entry point.
545    pub fn screener_option(&self) -> SubscribeRequest<'_, screener::option::Field> {
546        SubscribeRequest::new(self)
547    }
548
549    /// ACCT_ACTIVITY subscription entry point.
550    pub fn account_activity(&self) -> SubscribeRequest<'_, account_activity::Field> {
551        SubscribeRequest::new(self)
552    }
553
554    /// Serialize a built [`StreamerRequest`] and write it as one frame.
555    /// Crate-internal: external callers reach this only through the typed
556    /// service accessors above (and through [`Self::login`] /
557    /// [`Self::logout`]).
558    pub(crate) async fn send<T: Into<StreamerRequest>>(&self, request: T) -> Result<()> {
559        let request: StreamerRequest = request.into();
560        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
561        let request = RequestPayload {
562            request_id,
563            service: request.service,
564            command: request.command,
565            parameters: request.parameters,
566            schwab_client_customer_id: self.customer_id.clone(),
567            schwab_client_correlation_id: self.correlation_id.clone(),
568        };
569
570        let serialized = serde_json::to_string(&request).map_err(|e| Error::Codec {
571            context: "streamer request envelope".to_string(),
572            reason: e.to_string(),
573        })?;
574        write_one(
575            self.write_half.clone(),
576            fastwebsockets::Frame::text(fastwebsockets::Payload::Borrowed(serialized.as_bytes())),
577        )
578        .await?;
579        Ok(())
580    }
581}
582
583#[cfg(test)]
584mod tests {
585    use super::*;
586    use crate::streamer::events::{ConnectionEvent, DisconnectReason};
587    use crate::streamer::protocol::{ResponseCode, Service, StreamerCommand};
588    use crate::streamer::response::{ResponseContent, ResponsePayload};
589
590    fn response(code: ResponseCode, command: StreamerCommand, msg: &str) -> StreamerResponse {
591        StreamerResponse::Response(vec![ResponsePayload {
592            request_id: 1,
593            service: Service::Admin,
594            timestamp: 1,
595            command,
596            schwab_client_correlation_id: "x".into(),
597            content: ResponseContent {
598                code,
599                message: msg.into(),
600            },
601        }])
602    }
603
604    fn full_streamer_info() -> StreamerInfo {
605        StreamerInfo {
606            streamer_socket_url: Some("wss://streamer-api.schwab.com/ws".into()),
607            schwab_client_customer_id: Some(CustomerId::from("CUSTID")),
608            schwab_client_correlation_id: Some("abc-123".into()),
609            schwab_client_channel: Some("N9".into()),
610            schwab_client_function_id: Some("APIAPP".into()),
611        }
612    }
613
614    #[test]
615    fn validates_complete_streamer_info() {
616        let validated =
617            ValidatedStreamerInfo::try_from(full_streamer_info()).expect("complete info validates");
618        assert_eq!(validated.socket_url, "wss://streamer-api.schwab.com/ws");
619        assert_eq!(validated.correlation_id, "abc-123");
620        assert_eq!(validated.channel, "N9");
621        assert_eq!(validated.function_id, "APIAPP");
622    }
623
624    #[test]
625    fn missing_socket_url_reports_field() {
626        let mut info = full_streamer_info();
627        info.streamer_socket_url = None;
628        match ValidatedStreamerInfo::try_from(info) {
629            Err(Error::InvalidPreference { field, .. }) => {
630                assert_eq!(field, "streamerSocketUrl");
631            }
632            other => panic!("expected InvalidPreference, got {other:?}"),
633        }
634    }
635
636    #[test]
637    fn missing_customer_id_reports_field() {
638        let mut info = full_streamer_info();
639        info.schwab_client_customer_id = None;
640        match ValidatedStreamerInfo::try_from(info) {
641            Err(Error::InvalidPreference { field, .. }) => {
642                assert_eq!(field, "schwabClientCustomerId");
643            }
644            other => panic!("expected InvalidPreference, got {other:?}"),
645        }
646    }
647
648    #[test]
649    fn missing_correlation_id_reports_field() {
650        let mut info = full_streamer_info();
651        info.schwab_client_correlation_id = None;
652        match ValidatedStreamerInfo::try_from(info) {
653            Err(Error::InvalidPreference { field, .. }) => {
654                assert_eq!(field, "schwabClientCorrelId");
655            }
656            other => panic!("expected InvalidPreference, got {other:?}"),
657        }
658    }
659
660    #[test]
661    fn missing_channel_reports_field() {
662        let mut info = full_streamer_info();
663        info.schwab_client_channel = None;
664        match ValidatedStreamerInfo::try_from(info) {
665            Err(Error::InvalidPreference { field, .. }) => {
666                assert_eq!(field, "schwabClientChannel");
667            }
668            other => panic!("expected InvalidPreference, got {other:?}"),
669        }
670    }
671
672    #[test]
673    fn missing_function_id_reports_field() {
674        let mut info = full_streamer_info();
675        info.schwab_client_function_id = None;
676        match ValidatedStreamerInfo::try_from(info) {
677            Err(Error::InvalidPreference { field, .. }) => {
678                assert_eq!(field, "schwabClientFunctionId");
679            }
680            other => panic!("expected InvalidPreference, got {other:?}"),
681        }
682    }
683
684    #[test]
685    fn login_ok_emits_logged_in() {
686        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
687        classify_and_emit(&tx, &response(ResponseCode::Ok, StreamerCommand::Login, ""));
688        assert!(rx.has_changed().unwrap());
689        assert_eq!(*rx.borrow_and_update(), ConnectionEvent::LoggedIn);
690    }
691
692    #[test]
693    fn login_denied_emits_disconnected() {
694        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
695        classify_and_emit(
696            &tx,
697            &response(
698                ResponseCode::LoginDenied,
699                StreamerCommand::Login,
700                "token expired",
701            ),
702        );
703        match rx.borrow_and_update().clone() {
704            ConnectionEvent::Disconnected(DisconnectReason::LoginDenied(msg)) => {
705                assert!(msg.contains("token expired"), "msg = {msg}");
706            }
707            other => panic!("expected Disconnected(LoginDenied), got {other:?}"),
708        }
709    }
710
711    #[test]
712    fn close_connection_emits_disconnected_server_close() {
713        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
714        classify_and_emit(
715            &tx,
716            &response(
717                ResponseCode::CloseConnection,
718                StreamerCommand::Subs,
719                "max connections",
720            ),
721        );
722        assert!(matches!(
723            *rx.borrow_and_update(),
724            ConnectionEvent::Disconnected(DisconnectReason::ServerClose(_))
725        ));
726    }
727
728    #[test]
729    fn stop_streaming_emits_disconnected_stop_streaming() {
730        let (tx, mut rx) = watch::channel(ConnectionEvent::Connected);
731        classify_and_emit(
732            &tx,
733            &response(
734                ResponseCode::StopStreaming,
735                StreamerCommand::Subs,
736                "inactivity",
737            ),
738        );
739        assert!(matches!(
740            *rx.borrow_and_update(),
741            ConnectionEvent::Disconnected(DisconnectReason::StopStreaming(_))
742        ));
743    }
744
745    #[test]
746    fn non_admin_ok_response_does_not_emit() {
747        let (tx, rx) = watch::channel(ConnectionEvent::Connected);
748        // SUBS success on LEVELONE_EQUITIES should not flip to LoggedIn.
749        let r = StreamerResponse::Response(vec![ResponsePayload {
750            request_id: 1,
751            service: Service::LevelOneEquities,
752            timestamp: 1,
753            command: StreamerCommand::Subs,
754            schwab_client_correlation_id: "x".into(),
755            content: ResponseContent {
756                code: ResponseCode::Ok,
757                message: "".into(),
758            },
759        }]);
760        classify_and_emit(&tx, &r);
761        // No change observed.
762        assert!(!rx.has_changed().unwrap());
763    }
764
765    #[test]
766    fn data_payload_does_not_emit() {
767        let (tx, rx) = watch::channel(ConnectionEvent::Connected);
768        let r = StreamerResponse::Notify(vec![]);
769        classify_and_emit(&tx, &r);
770        assert!(!rx.has_changed().unwrap());
771    }
772
773    #[test]
774    fn wss_is_accepted_in_both_modes() {
775        assert_eq!(
776            check_websocket_scheme(Some("wss"), false).unwrap(),
777            WsTransport::Tls
778        );
779        assert_eq!(
780            check_websocket_scheme(Some("wss"), true).unwrap(),
781            WsTransport::Tls
782        );
783    }
784
785    #[test]
786    fn ws_is_rejected_when_insecure_disallowed() {
787        match check_websocket_scheme(Some("ws"), false) {
788            Err(WebSocketError::UnsupportedScheme(scheme)) => assert_eq!(scheme, "ws"),
789            other => panic!("expected UnsupportedScheme(ws), got {other:?}"),
790        }
791    }
792
793    #[test]
794    fn ws_is_accepted_when_insecure_permitted() {
795        assert_eq!(
796            check_websocket_scheme(Some("ws"), true).unwrap(),
797            WsTransport::Plain
798        );
799    }
800
801    #[test]
802    fn other_schemes_are_always_rejected() {
803        for scheme in ["http", "https", "ftp", "file", ""] {
804            assert!(
805                matches!(
806                    check_websocket_scheme(Some(scheme), true).unwrap_err(),
807                    WebSocketError::UnsupportedScheme(_)
808                ),
809                "scheme {scheme:?} should be rejected with insecure mode on"
810            );
811            assert!(
812                matches!(
813                    check_websocket_scheme(Some(scheme), false).unwrap_err(),
814                    WebSocketError::UnsupportedScheme(_)
815                ),
816                "scheme {scheme:?} should be rejected with insecure mode off"
817            );
818        }
819    }
820
821    #[test]
822    fn no_scheme_is_rejected() {
823        assert!(matches!(
824            check_websocket_scheme(None, true).unwrap_err(),
825            WebSocketError::UnsupportedScheme(s) if s.is_empty()
826        ));
827        assert!(matches!(
828            check_websocket_scheme(None, false).unwrap_err(),
829            WebSocketError::UnsupportedScheme(s) if s.is_empty()
830        ));
831    }
832
833    #[test]
834    fn case_sensitive_scheme_match() {
835        assert!(check_websocket_scheme(Some("Wss"), false).is_err(),);
836        assert!(check_websocket_scheme(Some("WSS"), false).is_err(),);
837    }
838
839    #[test]
840    fn is_retryable_classifies_transport_failures_as_retryable() {
841        // TCP / TLS / handshake / runtime errors all warrant a reconnect.
842        assert!(WebSocketError::Connect(std::io::Error::other("x")).is_retryable());
843        assert!(WebSocketError::TlsStream(std::io::Error::other("x")).is_retryable());
844        assert!(
845            WebSocketError::Handshake(fastwebsockets::WebSocketError::ConnectionClosed)
846                .is_retryable()
847        );
848        assert!(
849            WebSocketError::Runtime(fastwebsockets::WebSocketError::ConnectionClosed)
850                .is_retryable()
851        );
852    }
853
854    #[test]
855    fn is_retryable_classifies_config_failures_as_terminal() {
856        // These will fail identically on retry; callers must not loop.
857        assert!(!WebSocketError::MissingHost.is_retryable());
858        assert!(!WebSocketError::UnsupportedScheme("ws".to_string()).is_retryable());
859        assert!(
860            !WebSocketError::InvalidDomain(
861                rustls_pki_types::ServerName::try_from("not a dns name").unwrap_err()
862            )
863            .is_retryable()
864        );
865        // `BuildRequest` and `TlsConfig` carry foreign error types that
866        // are awkward to fabricate in a unit test; the exhaustive match
867        // in `is_retryable` keeps them classified alongside the others
868        // here, and the surrounding `match` would fail to compile if a
869        // new variant were added without an explicit decision.
870    }
871
872    #[test]
873    fn error_is_retryable_delegates_to_websocket_error() {
874        // The parent `Error::is_retryable` used to blanket-return `true`
875        // for every `Error::WebSocket`; verify the per-variant path now
876        // surfaces a terminal config error as terminal.
877        let terminal = Error::WebSocket(WebSocketError::UnsupportedScheme("ws".to_string()));
878        assert!(!terminal.is_retryable());
879        let transient = Error::WebSocket(WebSocketError::Connect(std::io::Error::other(
880            "conn refused",
881        )));
882        assert!(transient.is_retryable());
883    }
884}