chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21/// Exchanges the messages with the websocket
22#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25    /// Queue of commands to send.
26    pending_commands: VecDeque<MethodCall>,
27    /// The websocket of the chromium instance
28    ws: WebSocketStream<ConnectStream>,
29    /// The identifier for a specific command
30    next_id: usize,
31    needs_flush: bool,
32    /// The message that is currently being proceessed
33    pending_flush: Option<MethodCall>,
34    _marker: PhantomData<T>,
35}
36
37lazy_static::lazy_static! {
38    /// Nagle's algorithm disabled?
39    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40        Ok(disable_nagle) => disable_nagle == "true",
41        _ => true
42    };
43    /// Websocket config defaults
44    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45        Ok(d) => d == "true",
46        _ => false
47    };
48}
49
50impl<T: EventMessage + Unpin> Connection<T> {
51    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
52        let mut config = WebSocketConfig::default();
53
54        if *WEBSOCKET_DEFAULTS == false {
55            config.max_message_size = None;
56            config.max_frame_size = None;
57        }
58
59        let (ws, _) = tokio_tungstenite::connect_async_with_config(
60            debug_ws_url.as_ref(),
61            Some(config),
62            *DISABLE_NAGLE,
63        )
64        .await?;
65
66        Ok(Self {
67            pending_commands: Default::default(),
68            ws,
69            next_id: 0,
70            needs_flush: false,
71            pending_flush: None,
72            _marker: Default::default(),
73        })
74    }
75}
76
77impl<T: EventMessage> Connection<T> {
78    fn next_call_id(&mut self) -> CallId {
79        let id = CallId::new(self.next_id);
80        self.next_id = self.next_id.wrapping_add(1);
81        id
82    }
83
84    /// Queue in the command to send over the socket and return the id for this
85    /// command
86    pub fn submit_command(
87        &mut self,
88        method: MethodId,
89        session_id: Option<SessionId>,
90        params: serde_json::Value,
91    ) -> serde_json::Result<CallId> {
92        let id = self.next_call_id();
93        let call = MethodCall {
94            id,
95            method,
96            session_id: session_id.map(Into::into),
97            params,
98        };
99        self.pending_commands.push_back(call);
100        Ok(id)
101    }
102
103    /// flush any processed message and start sending the next over the conn
104    /// sink
105    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
106        if self.needs_flush {
107            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
108                self.needs_flush = false;
109            }
110        }
111        if self.pending_flush.is_none() && !self.needs_flush {
112            if let Some(cmd) = self.pending_commands.pop_front() {
113                tracing::trace!("Sending {:?}", cmd);
114                let msg = serde_json::to_string(&cmd)?;
115                self.ws.start_send_unpin(msg.into())?;
116                self.pending_flush = Some(cmd);
117            }
118        }
119        Ok(())
120    }
121}
122
123impl<T: EventMessage + Unpin> Stream for Connection<T> {
124    type Item = Result<Message<T>>;
125
126    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127        let pin = self.get_mut();
128
129        loop {
130            // queue in the next message if not currently flushing
131            if let Err(err) = pin.start_send_next(cx) {
132                return Poll::Ready(Some(Err(err)));
133            }
134
135            // send the message
136            if let Some(call) = pin.pending_flush.take() {
137                if pin.ws.poll_ready_unpin(cx).is_ready() {
138                    pin.needs_flush = true;
139                    // try another flush
140                    continue;
141                } else {
142                    pin.pending_flush = Some(call);
143                }
144            }
145
146            break;
147        }
148
149        // read from the ws
150        match ready!(pin.ws.poll_next_unpin(cx)) {
151            Some(Ok(WsMessage::Text(text))) => {
152                let ready = match crate::serde_json::from_str::<Message<T>>(&text) {
153                    Ok(msg) => {
154                        tracing::trace!("Received {:?}", msg);
155                        Ok(msg)
156                    }
157                    Err(err) => {
158                        tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg = text.to_string(), "Failed to parse raw WS message {err}");
159                        Err(err.into())
160                    }
161                };
162
163                Poll::Ready(Some(ready))
164            }
165            Some(Ok(WsMessage::Binary(mut text))) => {
166                let ready = match crate::serde_json::from_slice::<Message<T>>(&mut text) {
167                    Ok(msg) => {
168                        tracing::trace!("Received {:?}", msg);
169                        Ok(msg)
170                    }
171                    Err(err) => {
172                        tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", "Failed to parse raw WS message {err}");
173                        Err(err.into())
174                    }
175                };
176
177                Poll::Ready(Some(ready))
178            }
179            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
180            // ignore ping and pong
181            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
182                cx.waker().wake_by_ref();
183                Poll::Pending
184            }
185            Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
186            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
187            None => {
188                // ws connection closed
189                Poll::Ready(None)
190            }
191        }
192    }
193}