chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio::io::AsyncWriteExt;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
/// Transport underneath the websocket: a tokio TCP stream, possibly wrapped
/// in TLS (`MaybeTlsStream`).
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
/// Exchanges CDP messages with the browser over a websocket.
///
/// Outgoing commands are queued with [`Connection::submit_command`] and are
/// only written to the socket while the connection is polled as a
/// [`Stream`]. The stream yields the decoded incoming messages.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands waiting to be handed to the websocket sink.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket of the chromium instance.
    ws: WebSocketStream<ConnectStream>,
    /// Counter used to assign the next command's `CallId`; wraps on overflow.
    next_id: usize,
    /// Set once a sent frame has been accepted by `poll_ready` and still has
    /// to be flushed; cleared when `poll_flush` completes successfully.
    needs_flush: bool,
    /// The command that has been passed to `start_send` and is waiting for the
    /// sink's `poll_ready` to report ready.
    pending_flush: Option<MethodCall>,
    /// Ties the connection to the event type `T` it decodes, without storing a `T`.
    _marker: PhantomData<T>,
}
37
38lazy_static::lazy_static! {
39    /// Nagle's algorithm disabled?
40    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41        Ok(disable_nagle) => disable_nagle == "true",
42        _ => true
43    };
44    /// Websocket config defaults
45    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46        Ok(d) => d == "true",
47        _ => false
48    };
49}
50
51impl<T: EventMessage + Unpin> Connection<T> {
52    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
53        let mut config = WebSocketConfig::default();
54
55        if *WEBSOCKET_DEFAULTS == false {
56            config.max_message_size = None;
57            config.max_frame_size = None;
58        }
59
60        let (ws, _) = tokio_tungstenite::connect_async_with_config(
61            debug_ws_url.as_ref(),
62            Some(config),
63            *DISABLE_NAGLE,
64        )
65        .await?;
66
67        Ok(Self {
68            pending_commands: Default::default(),
69            ws,
70            next_id: 0,
71            needs_flush: false,
72            pending_flush: None,
73            _marker: Default::default(),
74        })
75    }
76}
77
78impl<T: EventMessage> Connection<T> {
79    fn next_call_id(&mut self) -> CallId {
80        let id = CallId::new(self.next_id);
81        self.next_id = self.next_id.wrapping_add(1);
82        id
83    }
84
85    /// Queue in the command to send over the socket and return the id for this
86    /// command
87    pub fn submit_command(
88        &mut self,
89        method: MethodId,
90        session_id: Option<SessionId>,
91        params: serde_json::Value,
92    ) -> serde_json::Result<CallId> {
93        let id = self.next_call_id();
94        let call = MethodCall {
95            id,
96            method,
97            session_id: session_id.map(Into::into),
98            params,
99        };
100        self.pending_commands.push_back(call);
101        Ok(id)
102    }
103
104    /// flush any processed message and start sending the next over the conn
105    /// sink
106    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
107        if self.needs_flush {
108            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
109                self.needs_flush = false;
110            }
111        }
112        if self.pending_flush.is_none() && !self.needs_flush {
113            if let Some(cmd) = self.pending_commands.pop_front() {
114                tracing::trace!("Sending {:?}", cmd);
115                let msg = serde_json::to_string(&cmd)?;
116                self.ws.start_send_unpin(msg.into())?;
117                self.pending_flush = Some(cmd);
118            }
119        }
120        Ok(())
121    }
122}
123
impl<T: EventMessage + Unpin> Stream for Connection<T> {
    type Item = Result<Message<T>>;

    /// Drive the connection: first push queued commands into the websocket
    /// sink, then try to read one incoming frame and decode it as a CDP
    /// [`Message`].
    ///
    /// Yields `None` once a `Close` frame is received or the websocket stream
    /// ends. Decode failures, unexpected frame kinds and websocket errors are
    /// yielded as `Some(Err(_))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        // Write phase: keep handing queued commands to the sink for as long as
        // the sink makes progress without blocking.
        loop {
            // queue in the next message if not currently flushing
            if let Err(err) = pin.start_send_next(cx) {
                return Poll::Ready(Some(Err(err)));
            }

            // send the message
            if let Some(call) = pin.pending_flush.take() {
                // NOTE(review): `is_ready()` is also true for `Poll::Ready(Err(_))`,
                // so a `poll_ready` error is not surfaced here; it only shows up
                // (if at all) on the following flush in `start_send_next`.
                if pin.ws.poll_ready_unpin(cx).is_ready() {
                    pin.needs_flush = true;
                    // try another flush
                    continue;
                } else {
                    // Sink not ready yet: keep the command so it is retried on
                    // the next poll.
                    pin.pending_flush = Some(call);
                }
            }

            break;
        }

        // Read phase: `ready!` returns `Poll::Pending` early when no frame is
        // available yet.
        match ready!(pin.ws.poll_next_unpin(cx)) {
            Some(Ok(WsMessage::Text(text))) => {
                let ready = match crate::serde_json::from_str::<Message<T>>(&text) {
                    Ok(msg) => {
                        tracing::trace!("Received {:?}", msg);
                        Ok(msg)
                    }
                    Err(err) => {
                        // The raw text is attached to the log record to help
                        // diagnose unparseable messages.
                        tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg = text.to_string(), "Failed to parse raw WS message {err}");
                        Err(err.into())
                    }
                };

                Poll::Ready(Some(ready))
            }
            // Binary frames are decoded as JSON too; unlike the text branch,
            // the raw payload is not included in the parse-error log.
            Some(Ok(WsMessage::Binary(mut text))) => {
                let ready = match crate::serde_json::from_slice::<Message<T>>(&mut text) {
                    Ok(msg) => {
                        tracing::trace!("Received {:?}", msg);
                        Ok(msg)
                    }
                    Err(err) => {
                        tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", "Failed to parse raw WS message {err}");
                        Err(err.into())
                    }
                };

                Poll::Ready(Some(ready))
            }
            // The peer closed the websocket: end the stream.
            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
            // ignore ping and pong: they carry no CDP payload. Waking our own
            // task asks the executor to poll again right away, so the next frame
            // is read without waiting for new socket activity.
            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            // Any remaining message kind is not expected on a CDP connection.
            Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
            None => {
                // ws connection closed
                Poll::Ready(None)
            }
        }
    }
}