1use crate::command::{WebDriverCommand, WebDriverMessage};
6use crate::error::{ErrorStatus, WebDriverError, WebDriverResult};
7use crate::httpapi::{
8 standard_routes, Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute,
9};
10use crate::response::{CloseWindowResponse, WebDriverResponse};
11use crate::Parameters;
12use bytes::Bytes;
13use http::{Method, StatusCode};
14use std::marker::PhantomData;
15use std::net::{SocketAddr, TcpListener as StdTcpListener};
16use std::sync::mpsc::{channel, Receiver, Sender};
17use std::sync::{Arc, Mutex};
18use std::thread;
19use tokio::net::TcpListener;
20use tokio_stream::wrappers::TcpListenerStream;
21use url::{Host, Url};
22use warp::{Buf, Filter, Rejection};
23
/// Messages sent from the HTTP server thread to the dispatcher thread.
// `Quit` is matched in `Dispatcher::run` but is not constructed anywhere in
// this file, which is why dead-code warnings are silenced here.
#[allow(dead_code)]
enum DispatchMessage<U: WebDriverExtensionRoute> {
    /// A parsed WebDriver command, together with the channel on which the
    /// dispatcher sends the handler's result back to the waiting request.
    HandleWebDriver(
        WebDriverMessage<U>,
        Sender<WebDriverResult<WebDriverResponse>>,
    ),
    /// Stops the dispatcher loop.
    Quit,
}
33
/// How the active session ended, reported to
/// [`WebDriverHandler::teardown_session`].
#[derive(Clone, Debug, PartialEq)]
pub enum SessionTeardownKind {
    /// A `DeleteSession` command was handled successfully. This is either the
    /// client's own command, or the one the dispatcher issues itself while
    /// tearing the session down.
    Deleted,
    /// The session is ending without a successful `DeleteSession`. For
    /// example, the dispatcher's own `DeleteSession` attempt returned an
    /// error, or there was no active session to delete.
    NotDeleted,
}
43
/// The WebDriver session currently tracked by the dispatcher.
#[derive(Clone, Debug, PartialEq)]
pub struct Session {
    /// Session id, taken from the handler's `NewSession` response.
    pub id: String,
}

impl Session {
    /// Creates a session record for the given id.
    fn new(id: String) -> Self {
        Self { id }
    }
}
54
/// Backend that executes WebDriver commands on behalf of the server.
///
/// The dispatcher owns the handler on its own thread and calls it for one
/// command at a time, after the command has passed session validation.
pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send {
    /// Executes `msg` and returns the response that will be serialized for
    /// the client.
    ///
    /// `session` is the session the dispatcher currently considers active, or
    /// `None` if there is none (for example while handling `NewSession`).
    fn handle_command(
        &mut self,
        session: &Option<Session>,
        msg: WebDriverMessage<U>,
    ) -> WebDriverResult<WebDriverResponse>;
    /// Called when the dispatcher discards the active session. `kind` says
    /// whether a `DeleteSession` command was handled successfully.
    fn teardown_session(&mut self, kind: SessionTeardownKind);
}
63
/// Owns the command handler and the active session.
///
/// It runs on the "webdriver dispatcher" thread and consumes
/// [`DispatchMessage`]s sent by the HTTP server thread.
#[derive(Debug)]
struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> {
    /// Backend that executes commands.
    handler: T,
    /// The currently active session, if any.
    session: Option<Session>,
    /// Marks the extension-route type `U` without storing a value of it.
    extension_type: PhantomData<U>,
}
70
71impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> {
72 fn new(handler: T) -> Dispatcher<T, U> {
73 Dispatcher {
74 handler,
75 session: None,
76 extension_type: PhantomData,
77 }
78 }
79
80 fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) {
81 loop {
82 match msg_chan.recv() {
83 Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => {
84 let resp = match self.check_session(&msg) {
85 Ok(_) => self.handler.handle_command(&self.session, msg),
86 Err(e) => Err(e),
87 };
88
89 match resp {
90 Ok(WebDriverResponse::NewSession(ref new_session)) => {
91 self.session = Some(Session::new(new_session.session_id.clone()));
92 }
93 Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => {
94 if handles.is_empty() {
95 debug!("Last window was closed, deleting session");
96 self.teardown_session(SessionTeardownKind::NotDeleted);
99 }
100 }
101 Ok(WebDriverResponse::DeleteSession) => {
102 self.teardown_session(SessionTeardownKind::Deleted);
103 }
104 Err(ref x) if x.delete_session => {
105 self.teardown_session(SessionTeardownKind::NotDeleted)
107 }
108 _ => {}
109 }
110
111 if resp_chan.send(resp).is_err() {
112 error!("Sending response to the main thread failed");
113 };
114 }
115 Ok(DispatchMessage::Quit) => break,
116 Err(e) => panic!("Error receiving message in handler: {:?}", e),
117 }
118 }
119 }
120
121 fn teardown_session(&mut self, kind: SessionTeardownKind) {
122 debug!("Teardown session");
123 let final_kind = match kind {
124 SessionTeardownKind::NotDeleted if self.session.is_some() => {
125 let delete_session = WebDriverMessage {
126 session_id: Some(
127 self.session
128 .as_ref()
129 .expect("Failed to get session")
130 .id
131 .clone(),
132 ),
133 command: WebDriverCommand::DeleteSession,
134 };
135 match self.handler.handle_command(&self.session, delete_session) {
136 Ok(_) => SessionTeardownKind::Deleted,
137 Err(_) => SessionTeardownKind::NotDeleted,
138 }
139 }
140 _ => kind,
141 };
142 self.handler.teardown_session(final_kind);
143 self.session = None;
144 }
145
146 fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> {
147 match msg.session_id {
148 Some(ref msg_session_id) => match self.session {
149 Some(ref existing_session) => {
150 if existing_session.id != *msg_session_id {
151 Err(WebDriverError::new(
152 ErrorStatus::InvalidSessionId,
153 format!("Got unexpected session id {}", msg_session_id),
154 ))
155 } else {
156 Ok(())
157 }
158 }
159 None => Ok(()),
160 },
161 None => {
162 match self.session {
163 Some(_) => {
164 match msg.command {
165 WebDriverCommand::Status => Ok(()),
166 WebDriverCommand::NewSession(_) => Err(WebDriverError::new(
167 ErrorStatus::SessionNotCreated,
168 "Session is already started",
169 )),
170 _ => {
171 error!("Got a message with no session id");
173 Err(WebDriverError::new(
174 ErrorStatus::UnknownError,
175 "Got a command with no session?!",
176 ))
177 }
178 }
179 }
180 None => match msg.command {
181 WebDriverCommand::NewSession(_) => Ok(()),
182 WebDriverCommand::Status => Ok(()),
183 _ => Err(WebDriverError::new(
184 ErrorStatus::InvalidSessionId,
185 "Tried to run a command before creating a session",
186 )),
187 },
188 }
189 }
190 }
191 }
192}
193
/// Handle to a running WebDriver server.
///
/// Dropping the handle blocks until the server thread has finished.
pub struct Listener {
    /// Join handle of the server thread; taken (left `None`) when joined.
    guard: Option<thread::JoinHandle<()>>,
    /// The address the server socket is actually bound to.
    pub socket: SocketAddr,
}

impl Drop for Listener {
    fn drop(&mut self) {
        if let Some(server_thread) = self.guard.take() {
            // A panic inside the server thread is intentionally ignored here.
            let _ = server_thread.join();
        }
    }
}
204
205pub fn start<T, U>(
206 mut address: SocketAddr,
207 allow_hosts: Vec<Host>,
208 allow_origins: Vec<Url>,
209 handler: T,
210 extension_routes: Vec<(Method, &'static str, U)>,
211) -> ::std::io::Result<Listener>
212where
213 T: 'static + WebDriverHandler<U>,
214 U: 'static + WebDriverExtensionRoute + Send + Sync,
215{
216 let listener = StdTcpListener::bind(address)?;
217 listener.set_nonblocking(true)?;
218 let addr = listener.local_addr()?;
219 if address.port() == 0 {
220 address.set_port(addr.port())
223 }
224 let (msg_send, msg_recv) = channel();
225
226 let builder = thread::Builder::new().name("webdriver server".to_string());
227 let handle = builder.spawn(move || {
228 let rt = tokio::runtime::Builder::new_current_thread()
229 .enable_io()
230 .build()
231 .unwrap();
232 let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() });
233 let wroutes = build_warp_routes(
234 address,
235 allow_hosts,
236 allow_origins,
237 &extension_routes,
238 msg_send.clone(),
239 );
240 let fut = warp::serve(wroutes).run_incoming(TcpListenerStream::new(listener));
241 rt.block_on(fut);
242 })?;
243
244 let builder = thread::Builder::new().name("webdriver dispatcher".to_string());
245 builder.spawn(move || {
246 let mut dispatcher = Dispatcher::new(handler);
247 dispatcher.run(&msg_recv);
248 })?;
249
250 Ok(Listener {
251 guard: Some(handle),
252 socket: addr,
253 })
254}
255
256fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>(
257 address: SocketAddr,
258 allow_hosts: Vec<Host>,
259 allow_origins: Vec<Url>,
260 ext_routes: &[(Method, &'static str, U)],
261 chan: Sender<DispatchMessage<U>>,
262) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone {
263 let chan = Arc::new(Mutex::new(chan));
264 let mut std_routes = standard_routes::<U>();
265
266 let (method, path, res) = std_routes.pop().unwrap();
267 trace!("Build standard route for {path}");
268 let mut wroutes = build_route(
269 address,
270 allow_hosts.clone(),
271 allow_origins.clone(),
272 method,
273 path,
274 res,
275 chan.clone(),
276 );
277
278 for (method, path, res) in std_routes {
279 trace!("Build standard route for {path}");
280 wroutes = wroutes
281 .or(build_route(
282 address,
283 allow_hosts.clone(),
284 allow_origins.clone(),
285 method,
286 path,
287 res.clone(),
288 chan.clone(),
289 ))
290 .unify()
291 .boxed()
292 }
293
294 for (method, path, res) in ext_routes {
295 trace!("Build vendor route for {path}");
296 wroutes = wroutes
297 .or(build_route(
298 address,
299 allow_hosts.clone(),
300 allow_origins.clone(),
301 method.clone(),
302 path,
303 Route::Extension(res.clone()),
304 chan.clone(),
305 ))
306 .unify()
307 .boxed()
308 }
309
310 wroutes
311}
312
313fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool {
314 let header_host_url = match Url::parse(&format!("http://{}", &host_header)) {
317 Ok(x) => x,
318 Err(_) => {
319 return false;
320 }
321 };
322
323 let host = match header_host_url.host() {
324 Some(host) => host.to_owned(),
325 None => {
326 return false;
330 }
331 };
332 let port = match header_host_url.port_or_known_default() {
333 Some(port) => port,
334 None => {
335 return false;
339 }
340 };
341
342 let host_matches = match host {
343 Host::Domain(_) => allow_hosts.contains(&host),
344 Host::Ipv4(_) | Host::Ipv6(_) => true,
345 };
346 let port_matches = server_address.port() == port;
347 host_matches && port_matches
348}
349
350fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool {
351 allow_origins.contains(&origin_url)
353}
354
355fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>(
356 server_address: SocketAddr,
357 allow_hosts: Vec<Host>,
358 allow_origins: Vec<Url>,
359 method: Method,
360 path: &'static str,
361 route: Route<U>,
362 chan: Arc<Mutex<Sender<DispatchMessage<U>>>>,
363) -> warp::filters::BoxedFilter<(impl warp::Reply,)> {
364 let mut subroute = match method {
367 Method::GET => warp::get().boxed(),
368 Method::POST => warp::post().boxed(),
369 Method::DELETE => warp::delete().boxed(),
370 Method::OPTIONS => warp::options().boxed(),
371 Method::PUT => warp::put().boxed(),
372 _ => panic!("Unsupported method"),
373 }
374 .or(warp::head())
375 .unify()
376 .map(Parameters::new)
377 .boxed();
378
379 for part in path.split('/') {
383 if part.is_empty() {
384 continue;
385 } else if part.starts_with('{') {
386 assert!(part.ends_with('}'));
387
388 subroute = subroute
389 .and(warp::path::param())
390 .map(move |mut params: Parameters, param: String| {
391 let name = &part[1..part.len() - 1];
392 params.insert(name.to_string(), param);
393 params
394 })
395 .boxed();
396 } else {
397 subroute = subroute.and(warp::path(part)).boxed();
398 }
399 }
400
401 subroute
403 .and(warp::path::end())
404 .and(warp::path::full())
405 .and(warp::method())
406 .and(warp::header::optional::<String>("origin"))
407 .and(warp::header::optional::<String>("host"))
408 .and(warp::header::optional::<String>("content-type"))
409 .and(warp::body::bytes())
410 .map(
411 move |params,
412 full_path: warp::path::FullPath,
413 method,
414 origin_header: Option<String>,
415 host_header: Option<String>,
416 content_type_header: Option<String>,
417 body: Bytes| {
418 if method == Method::HEAD {
419 return warp::reply::with_status("".into(), StatusCode::OK);
420 }
421 if let Some(host) = host_header {
422 if !is_host_allowed(&server_address, &allow_hosts, &host) {
423 warn!(
424 "Rejected request with Host header {}, allowed values are [{}]",
425 host,
426 allow_hosts
427 .iter()
428 .map(|x| format!("{}:{}", x, server_address.port()))
429 .collect::<Vec<_>>()
430 .join(",")
431 );
432 let err = WebDriverError::new(
433 ErrorStatus::UnknownError,
434 format!("Invalid Host header {}", host),
435 );
436 return warp::reply::with_status(
437 serde_json::to_string(&err).unwrap(),
438 StatusCode::INTERNAL_SERVER_ERROR,
439 );
440 };
441 } else {
442 warn!("Rejected request with missing Host header");
443 let err = WebDriverError::new(
444 ErrorStatus::UnknownError,
445 "Missing Host header".to_string(),
446 );
447 return warp::reply::with_status(
448 serde_json::to_string(&err).unwrap(),
449 StatusCode::INTERNAL_SERVER_ERROR,
450 );
451 }
452 if let Some(origin) = origin_header {
453 let make_err = || {
454 warn!(
455 "Rejected request with Origin header {}, allowed values are [{}]",
456 origin,
457 allow_origins
458 .iter()
459 .map(|x| x.to_string())
460 .collect::<Vec<_>>()
461 .join(",")
462 );
463 WebDriverError::new(
464 ErrorStatus::UnknownError,
465 format!("Invalid Origin header {}", origin),
466 )
467 };
468 let origin_url = match Url::parse(&origin) {
469 Ok(url) => url,
470 Err(_) => {
471 return warp::reply::with_status(
472 serde_json::to_string(&make_err()).unwrap(),
473 StatusCode::INTERNAL_SERVER_ERROR,
474 );
475 }
476 };
477 if !is_origin_allowed(&allow_origins, origin_url) {
478 return warp::reply::with_status(
479 serde_json::to_string(&make_err()).unwrap(),
480 StatusCode::INTERNAL_SERVER_ERROR,
481 );
482 }
483 }
484 if method == Method::POST {
485 let content_type = content_type_header
488 .as_ref()
489 .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x))
490 .map(|x| x.trim())
491 .map(|x| x.to_lowercase());
492 match content_type.as_ref().map(|x| x.as_ref()) {
493 Some("application/x-www-form-urlencoded")
494 | Some("multipart/form-data")
495 | Some("text/plain") => {
496 warn!(
497 "Rejected POST request with disallowed content type {}",
498 content_type.unwrap_or_else(|| "".into())
499 );
500 let err = WebDriverError::new(
501 ErrorStatus::UnknownError,
502 "Invalid Content-Type",
503 );
504 return warp::reply::with_status(
505 serde_json::to_string(&err).unwrap(),
506 StatusCode::INTERNAL_SERVER_ERROR,
507 );
508 }
509 Some(_) | None => {}
510 }
511 }
512 let body = String::from_utf8(body.chunk().to_vec());
513 if body.is_err() {
514 let err = WebDriverError::new(
515 ErrorStatus::UnknownError,
516 "Request body wasn't valid UTF-8",
517 );
518 return warp::reply::with_status(
519 serde_json::to_string(&err).unwrap(),
520 StatusCode::INTERNAL_SERVER_ERROR,
521 );
522 }
523 let body = body.unwrap();
524
525 debug!("-> {} {} {}", method, full_path.as_str(), body);
526 let msg_result = WebDriverMessage::from_http(
527 route.clone(),
528 ¶ms,
529 &body,
530 method == Method::POST,
531 );
532
533 let (status, resp_body) = match msg_result {
534 Ok(message) => {
535 let (send_res, recv_res) = channel();
536 match chan.lock() {
537 Ok(ref c) => {
538 let res =
539 c.send(DispatchMessage::HandleWebDriver(message, send_res));
540 match res {
541 Ok(x) => x,
542 Err(e) => panic!("Error: {:?}", e),
543 }
544 }
545 Err(e) => panic!("Error reading response: {:?}", e),
546 }
547
548 match recv_res.recv() {
549 Ok(data) => match data {
550 Ok(response) => {
551 (StatusCode::OK, serde_json::to_string(&response).unwrap())
552 }
553 Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
554 },
555 Err(e) => panic!("Error reading response: {:?}", e),
556 }
557 }
558 Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
559 };
560
561 debug!("<- {} {}", status, resp_body);
562 warp::reply::with_status(resp_body, status)
563 },
564 )
565 .with(warp::reply::with::header(
566 http::header::CONTENT_TYPE,
567 "application/json; charset=utf-8",
568 ))
569 .with(warp::reply::with::header(
570 http::header::CACHE_CONTROL,
571 "no-cache",
572 ))
573 .boxed()
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use std::str::FromStr;

    /// Builds a socket address from an IP literal and a port.
    fn socket(ip: &str, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::from_str(ip).unwrap(), port)
    }

    /// Builds a `Host::Domain` for `name`.
    fn domain(name: &str) -> Host {
        Host::Domain(name.to_string())
    }

    /// Parses a URL literal used as test input.
    fn url(spec: &str) -> Url {
        Url::parse(spec).unwrap()
    }

    #[test]
    fn test_host_allowed() {
        let addr_80 = socket("127.0.0.1", 80);
        let addr_8000 = socket("127.0.0.1", 8000);
        let addr_v6_80 = socket("::1", 80);
        let addr_v6_8000 = socket("::1", 8000);

        let localhost = domain("localhost");
        let example = domain("example.test");
        let sub_localhost = domain("subdomain.localhost");

        // (server address, allow-list, Host header)
        let accepted = [
            (addr_80, vec![localhost.clone()], "localhost:80"),
            (addr_80, vec![example.clone()], "example.test:80"),
            (addr_80, vec![example.clone(), localhost.clone()], "example.test"),
            (addr_80, vec![sub_localhost.clone()], "subdomain.localhost"),
            // IP literals never need an allow-list entry.
            (addr_80, vec![], "127.0.0.1:80"),
            (addr_v6_80, vec![], "127.0.0.1"),
            (addr_80, vec![], "[::1]"),
            (addr_8000, vec![], "127.0.0.1:8000"),
            (addr_80, vec![sub_localhost.clone()], "[::1]"),
            (addr_v6_8000, vec![sub_localhost.clone()], "[::1]:8000"),
        ];
        for (addr, hosts, header) in accepted.iter() {
            assert!(
                is_host_allowed(addr, hosts, header),
                "{header} should be allowed for {addr}"
            );
        }

        let rejected = [
            // Domain missing from the allow-list.
            (addr_80, vec![example.clone()], "localhost"),
            (addr_80, vec![], "localhost:80"),
            // Port mismatch, explicit or implied.
            (addr_80, vec![localhost.clone()], "localhost:8000"),
            (addr_8000, vec![localhost.clone()], "localhost"),
            (addr_v6_8000, vec![localhost.clone()], "[::1]"),
        ];
        for (addr, hosts, header) in rejected.iter() {
            assert!(
                !is_host_allowed(addr, hosts, header),
                "{header} should be rejected for {addr}"
            );
        }
    }

    #[test]
    fn test_origin_allowed() {
        // (allow-list, Origin header)
        let accepted = [
            (vec![url("http://localhost")], "http://localhost"),
            (vec![url("http://localhost")], "http://localhost:80"),
            (
                vec![url("https://test.example"), url("http://localhost")],
                "http://localhost",
            ),
            (
                vec![url("https://test.example"), url("http://localhost")],
                "https://test.example:443",
            ),
        ];
        for (allowed, origin) in accepted.iter() {
            assert!(
                is_origin_allowed(allowed, url(origin)),
                "{origin} should be allowed"
            );
        }

        let rejected = [
            (vec![], "http://localhost"),
            (vec![url("http://localhost")], "http://localhost:8000"),
            (vec![url("https://localhost")], "http://localhost"),
            (
                vec![url("https://example.test")],
                "http://subdomain.example.test",
            ),
        ];
        for (allowed, origin) in rejected.iter() {
            assert!(
                !is_origin_allowed(allowed, url(origin)),
                "{origin} should be rejected"
            );
        }
    }
}