1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
/// Transport underneath the websocket: a tokio TCP stream wrapped in
/// `MaybeTlsStream`.
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;

/// Websocket connection that queues outgoing commands and, when polled as a
/// `Stream`, sends them and yields the messages received from the peer.
///
/// `T` is the event type that incoming messages are deserialized into.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Commands submitted via `submit_command` that have not been handed to
    /// the websocket yet; they are sent in FIFO order.
    pending_commands: VecDeque<MethodCall>,
    /// The underlying websocket used for both sending and receiving.
    ws: WebSocketStream<ConnectStream>,
    /// Counter the next `CallId` is created from; incremented per command.
    next_id: usize,
    /// `true` while a sent command still has to be flushed to the socket.
    needs_flush: bool,
    /// The command most recently handed to the websocket sink, kept until
    /// the sink reports ready again.
    pending_flush: Option<MethodCall>,
    /// Ties the connection to the event type `T` without storing a value.
    _marker: PhantomData<T>,
}
36
37lazy_static::lazy_static! {
38 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40 Ok(disable_nagle) => disable_nagle == "true",
41 _ => true
42 };
43 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45 Ok(d) => d == "true",
46 _ => false
47 };
48}
49
50impl<T: EventMessage + Unpin> Connection<T> {
51 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
52 let mut config = WebSocketConfig::default();
53
54 if *WEBSOCKET_DEFAULTS == false {
55 config.max_message_size = None;
56 config.max_frame_size = None;
57 }
58
59 let (ws, _) = tokio_tungstenite::connect_async_with_config(
60 debug_ws_url.as_ref(),
61 Some(config),
62 *DISABLE_NAGLE,
63 )
64 .await?;
65
66 Ok(Self {
67 pending_commands: Default::default(),
68 ws,
69 next_id: 0,
70 needs_flush: false,
71 pending_flush: None,
72 _marker: Default::default(),
73 })
74 }
75}
76
77impl<T: EventMessage> Connection<T> {
78 fn next_call_id(&mut self) -> CallId {
79 let id = CallId::new(self.next_id);
80 self.next_id = self.next_id.wrapping_add(1);
81 id
82 }
83
84 pub fn submit_command(
87 &mut self,
88 method: MethodId,
89 session_id: Option<SessionId>,
90 params: serde_json::Value,
91 ) -> serde_json::Result<CallId> {
92 let id = self.next_call_id();
93 let call = MethodCall {
94 id,
95 method,
96 session_id: session_id.map(Into::into),
97 params,
98 };
99 self.pending_commands.push_back(call);
100 Ok(id)
101 }
102
103 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
106 if self.needs_flush {
107 if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
108 self.needs_flush = false;
109 }
110 }
111 if self.pending_flush.is_none() && !self.needs_flush {
112 if let Some(cmd) = self.pending_commands.pop_front() {
113 tracing::trace!("Sending {:?}", cmd);
114 let msg = serde_json::to_string(&cmd)?;
115 self.ws.start_send_unpin(msg.into())?;
116 self.pending_flush = Some(cmd);
117 }
118 }
119 Ok(())
120 }
121}
122
123impl<T: EventMessage + Unpin> Stream for Connection<T> {
124 type Item = Result<Message<T>>;
125
126 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 let pin = self.get_mut();
128
129 loop {
130 if let Err(err) = pin.start_send_next(cx) {
132 return Poll::Ready(Some(Err(err)));
133 }
134
135 if let Some(call) = pin.pending_flush.take() {
137 if pin.ws.poll_ready_unpin(cx).is_ready() {
138 pin.needs_flush = true;
139 continue;
141 } else {
142 pin.pending_flush = Some(call);
143 }
144 }
145
146 break;
147 }
148
149 match ready!(pin.ws.poll_next_unpin(cx)) {
151 Some(Ok(WsMessage::Text(text))) => {
152 let ready = match crate::serde_json::from_str::<Message<T>>(&text) {
153 Ok(msg) => {
154 tracing::trace!("Received {:?}", msg);
155 Ok(msg)
156 }
157 Err(err) => {
158 tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg = text.to_string(), "Failed to parse raw WS message {err}");
159 Err(err.into())
160 }
161 };
162
163 Poll::Ready(Some(ready))
164 }
165 Some(Ok(WsMessage::Binary(mut text))) => {
166 let ready = match crate::serde_json::from_slice::<Message<T>>(&mut text) {
167 Ok(msg) => {
168 tracing::trace!("Received {:?}", msg);
169 Ok(msg)
170 }
171 Err(err) => {
172 tracing::error!(target: "chromiumoxide::conn::raw_ws::parse_errors", "Failed to parse raw WS message {err}");
173 Err(err.into())
174 }
175 };
176
177 Poll::Ready(Some(ready))
178 }
179 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
180 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
182 cx.waker().wake_by_ref();
183 Poll::Pending
184 }
185 Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
186 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
187 None => {
188 Poll::Ready(None)
190 }
191 }
192 }
193}