webdriver/
server.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use crate::command::{WebDriverCommand, WebDriverMessage};
6use crate::error::{ErrorStatus, WebDriverError, WebDriverResult};
7use crate::httpapi::{
8    standard_routes, Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute,
9};
10use crate::response::{CloseWindowResponse, WebDriverResponse};
11use crate::Parameters;
12use bytes::Bytes;
13use http::{Method, StatusCode};
14use std::marker::PhantomData;
15use std::net::{SocketAddr, TcpListener as StdTcpListener};
16use std::sync::mpsc::{channel, Receiver, Sender};
17use std::sync::{Arc, Mutex};
18use std::thread;
19use tokio::net::TcpListener;
20use tokio_stream::wrappers::TcpListenerStream;
21use url::{Host, Url};
22use warp::{Buf, Filter, Rejection};
23
24// Silence warning about Quit being unused for now.
25#[allow(dead_code)]
26enum DispatchMessage<U: WebDriverExtensionRoute> {
27    HandleWebDriver(
28        WebDriverMessage<U>,
29        Sender<WebDriverResult<WebDriverResponse>>,
30    ),
31    Quit,
32}
33
/// Outcome of session teardown: whether a DeleteSession message was sent
/// and its response handled before the session was torn down.
#[derive(Clone, Debug, PartialEq)]
pub enum SessionTeardownKind {
    /// DeleteSession was sent and its response was handled.
    Deleted,
    /// DeleteSession was never sent, or no response to it was received.
    NotDeleted,
}
43
/// The WebDriver session currently tracked by the server.
#[derive(Clone, Debug, PartialEq)]
pub struct Session {
    /// Session id, taken from the `NewSession` response that created it.
    pub id: String,
}

impl Session {
    /// Wrap a session id.
    fn new(id: String) -> Self {
        Self { id }
    }
}
54
55pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send {
56    fn handle_command(
57        &mut self,
58        session: &Option<Session>,
59        msg: WebDriverMessage<U>,
60    ) -> WebDriverResult<WebDriverResponse>;
61    fn teardown_session(&mut self, kind: SessionTeardownKind);
62}
63
64#[derive(Debug)]
65struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> {
66    handler: T,
67    session: Option<Session>,
68    extension_type: PhantomData<U>,
69}
70
71impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> {
72    fn new(handler: T) -> Dispatcher<T, U> {
73        Dispatcher {
74            handler,
75            session: None,
76            extension_type: PhantomData,
77        }
78    }
79
80    fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) {
81        loop {
82            match msg_chan.recv() {
83                Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => {
84                    let resp = match self.check_session(&msg) {
85                        Ok(_) => self.handler.handle_command(&self.session, msg),
86                        Err(e) => Err(e),
87                    };
88
89                    match resp {
90                        Ok(WebDriverResponse::NewSession(ref new_session)) => {
91                            self.session = Some(Session::new(new_session.session_id.clone()));
92                        }
93                        Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => {
94                            if handles.is_empty() {
95                                debug!("Last window was closed, deleting session");
96                                // The teardown_session implementation is responsible for actually
97                                // sending the DeleteSession message in this case
98                                self.teardown_session(SessionTeardownKind::NotDeleted);
99                            }
100                        }
101                        Ok(WebDriverResponse::DeleteSession) => {
102                            self.teardown_session(SessionTeardownKind::Deleted);
103                        }
104                        Err(ref x) if x.delete_session => {
105                            // This includes the case where we failed during session creation
106                            self.teardown_session(SessionTeardownKind::NotDeleted)
107                        }
108                        _ => {}
109                    }
110
111                    if resp_chan.send(resp).is_err() {
112                        error!("Sending response to the main thread failed");
113                    };
114                }
115                Ok(DispatchMessage::Quit) => break,
116                Err(e) => panic!("Error receiving message in handler: {:?}", e),
117            }
118        }
119    }
120
121    fn teardown_session(&mut self, kind: SessionTeardownKind) {
122        debug!("Teardown session");
123        let final_kind = match kind {
124            SessionTeardownKind::NotDeleted if self.session.is_some() => {
125                let delete_session = WebDriverMessage {
126                    session_id: Some(
127                        self.session
128                            .as_ref()
129                            .expect("Failed to get session")
130                            .id
131                            .clone(),
132                    ),
133                    command: WebDriverCommand::DeleteSession,
134                };
135                match self.handler.handle_command(&self.session, delete_session) {
136                    Ok(_) => SessionTeardownKind::Deleted,
137                    Err(_) => SessionTeardownKind::NotDeleted,
138                }
139            }
140            _ => kind,
141        };
142        self.handler.teardown_session(final_kind);
143        self.session = None;
144    }
145
146    fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> {
147        match msg.session_id {
148            Some(ref msg_session_id) => match self.session {
149                Some(ref existing_session) => {
150                    if existing_session.id != *msg_session_id {
151                        Err(WebDriverError::new(
152                            ErrorStatus::InvalidSessionId,
153                            format!("Got unexpected session id {}", msg_session_id),
154                        ))
155                    } else {
156                        Ok(())
157                    }
158                }
159                None => Ok(()),
160            },
161            None => {
162                match self.session {
163                    Some(_) => {
164                        match msg.command {
165                            WebDriverCommand::Status => Ok(()),
166                            WebDriverCommand::NewSession(_) => Err(WebDriverError::new(
167                                ErrorStatus::SessionNotCreated,
168                                "Session is already started",
169                            )),
170                            _ => {
171                                //This should be impossible
172                                error!("Got a message with no session id");
173                                Err(WebDriverError::new(
174                                    ErrorStatus::UnknownError,
175                                    "Got a command with no session?!",
176                                ))
177                            }
178                        }
179                    }
180                    None => match msg.command {
181                        WebDriverCommand::NewSession(_) => Ok(()),
182                        WebDriverCommand::Status => Ok(()),
183                        _ => Err(WebDriverError::new(
184                            ErrorStatus::InvalidSessionId,
185                            "Tried to run a command before creating a session",
186                        )),
187                    },
188                }
189            }
190        }
191    }
192}
193
/// Handle to a running WebDriver HTTP server.
///
/// Dropping it blocks until the server thread has exited.
pub struct Listener {
    /// Server thread; taken (and joined) on drop.
    guard: Option<thread::JoinHandle<()>>,
    /// Address the server is bound to, including the OS-assigned port when
    /// port 0 was requested.
    pub socket: SocketAddr,
}

impl Drop for Listener {
    fn drop(&mut self) {
        if let Some(server_thread) = self.guard.take() {
            // A panic in the server thread is deliberately ignored here.
            let _ = server_thread.join();
        }
    }
}
204
205pub fn start<T, U>(
206    mut address: SocketAddr,
207    allow_hosts: Vec<Host>,
208    allow_origins: Vec<Url>,
209    handler: T,
210    extension_routes: Vec<(Method, &'static str, U)>,
211) -> ::std::io::Result<Listener>
212where
213    T: 'static + WebDriverHandler<U>,
214    U: 'static + WebDriverExtensionRoute + Send + Sync,
215{
216    let listener = StdTcpListener::bind(address)?;
217    listener.set_nonblocking(true)?;
218    let addr = listener.local_addr()?;
219    if address.port() == 0 {
220        // If we passed in 0 as the port number the OS will assign an unused port;
221        // we want to update the address to the actual used port
222        address.set_port(addr.port())
223    }
224    let (msg_send, msg_recv) = channel();
225
226    let builder = thread::Builder::new().name("webdriver server".to_string());
227    let handle = builder.spawn(move || {
228        let rt = tokio::runtime::Builder::new_current_thread()
229            .enable_io()
230            .build()
231            .unwrap();
232        let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() });
233        let wroutes = build_warp_routes(
234            address,
235            allow_hosts,
236            allow_origins,
237            &extension_routes,
238            msg_send.clone(),
239        );
240        let fut = warp::serve(wroutes).run_incoming(TcpListenerStream::new(listener));
241        rt.block_on(fut);
242    })?;
243
244    let builder = thread::Builder::new().name("webdriver dispatcher".to_string());
245    builder.spawn(move || {
246        let mut dispatcher = Dispatcher::new(handler);
247        dispatcher.run(&msg_recv);
248    })?;
249
250    Ok(Listener {
251        guard: Some(handle),
252        socket: addr,
253    })
254}
255
256fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>(
257    address: SocketAddr,
258    allow_hosts: Vec<Host>,
259    allow_origins: Vec<Url>,
260    ext_routes: &[(Method, &'static str, U)],
261    chan: Sender<DispatchMessage<U>>,
262) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone {
263    let chan = Arc::new(Mutex::new(chan));
264    let mut std_routes = standard_routes::<U>();
265
266    let (method, path, res) = std_routes.pop().unwrap();
267    trace!("Build standard route for {path}");
268    let mut wroutes = build_route(
269        address,
270        allow_hosts.clone(),
271        allow_origins.clone(),
272        method,
273        path,
274        res,
275        chan.clone(),
276    );
277
278    for (method, path, res) in std_routes {
279        trace!("Build standard route for {path}");
280        wroutes = wroutes
281            .or(build_route(
282                address,
283                allow_hosts.clone(),
284                allow_origins.clone(),
285                method,
286                path,
287                res.clone(),
288                chan.clone(),
289            ))
290            .unify()
291            .boxed()
292    }
293
294    for (method, path, res) in ext_routes {
295        trace!("Build vendor route for {path}");
296        wroutes = wroutes
297            .or(build_route(
298                address,
299                allow_hosts.clone(),
300                allow_origins.clone(),
301                method.clone(),
302                path,
303                Route::Extension(res.clone()),
304                chan.clone(),
305            ))
306            .unify()
307            .boxed()
308    }
309
310    wroutes
311}
312
313fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool {
314    // Validate that the Host header value has a hostname in allow_hosts and
315    // the port matches the server configuration
316    let header_host_url = match Url::parse(&format!("http://{}", &host_header)) {
317        Ok(x) => x,
318        Err(_) => {
319            return false;
320        }
321    };
322
323    let host = match header_host_url.host() {
324        Some(host) => host.to_owned(),
325        None => {
326            // This shouldn't be possible since http URL always have a
327            // host, but conservatively return false here, which will cause
328            // an error response
329            return false;
330        }
331    };
332    let port = match header_host_url.port_or_known_default() {
333        Some(port) => port,
334        None => {
335            // This shouldn't be possible since http URL always have a
336            // default port, but conservatively return false here, which will cause
337            // an error response
338            return false;
339        }
340    };
341
342    let host_matches = match host {
343        Host::Domain(_) => allow_hosts.contains(&host),
344        Host::Ipv4(_) | Host::Ipv6(_) => true,
345    };
346    let port_matches = server_address.port() == port;
347    host_matches && port_matches
348}
349
350fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool {
351    // Validate that the Origin header value is in allow_origins
352    allow_origins.contains(&origin_url)
353}
354
355fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>(
356    server_address: SocketAddr,
357    allow_hosts: Vec<Host>,
358    allow_origins: Vec<Url>,
359    method: Method,
360    path: &'static str,
361    route: Route<U>,
362    chan: Arc<Mutex<Sender<DispatchMessage<U>>>>,
363) -> warp::filters::BoxedFilter<(impl warp::Reply,)> {
364    // Create an empty filter based on the provided method and append an empty hashmap to it. The
365    // hashmap will be used to store path parameters.
366    let mut subroute = match method {
367        Method::GET => warp::get().boxed(),
368        Method::POST => warp::post().boxed(),
369        Method::DELETE => warp::delete().boxed(),
370        Method::OPTIONS => warp::options().boxed(),
371        Method::PUT => warp::put().boxed(),
372        _ => panic!("Unsupported method"),
373    }
374    .or(warp::head())
375    .unify()
376    .map(Parameters::new)
377    .boxed();
378
379    // For each part of the path, if it's a normal part, just append it to the current filter,
380    // otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert
381    // it into the hashmap created earlier.
382    for part in path.split('/') {
383        if part.is_empty() {
384            continue;
385        } else if part.starts_with('{') {
386            assert!(part.ends_with('}'));
387
388            subroute = subroute
389                .and(warp::path::param())
390                .map(move |mut params: Parameters, param: String| {
391                    let name = &part[1..part.len() - 1];
392                    params.insert(name.to_string(), param);
393                    params
394                })
395                .boxed();
396        } else {
397            subroute = subroute.and(warp::path(part)).boxed();
398        }
399    }
400
401    // Finally, tell warp that the path is complete
402    subroute
403        .and(warp::path::end())
404        .and(warp::path::full())
405        .and(warp::method())
406        .and(warp::header::optional::<String>("origin"))
407        .and(warp::header::optional::<String>("host"))
408        .and(warp::header::optional::<String>("content-type"))
409        .and(warp::body::bytes())
410        .map(
411            move |params,
412                  full_path: warp::path::FullPath,
413                  method,
414                  origin_header: Option<String>,
415                  host_header: Option<String>,
416                  content_type_header: Option<String>,
417                  body: Bytes| {
418                if method == Method::HEAD {
419                    return warp::reply::with_status("".into(), StatusCode::OK);
420                }
421                if let Some(host) = host_header {
422                    if !is_host_allowed(&server_address, &allow_hosts, &host) {
423                        warn!(
424                            "Rejected request with Host header {}, allowed values are [{}]",
425                            host,
426                            allow_hosts
427                                .iter()
428                                .map(|x| format!("{}:{}", x, server_address.port()))
429                                .collect::<Vec<_>>()
430                                .join(",")
431                        );
432                        let err = WebDriverError::new(
433                            ErrorStatus::UnknownError,
434                            format!("Invalid Host header {}", host),
435                        );
436                        return warp::reply::with_status(
437                            serde_json::to_string(&err).unwrap(),
438                            StatusCode::INTERNAL_SERVER_ERROR,
439                        );
440                    };
441                } else {
442                    warn!("Rejected request with missing Host header");
443                    let err = WebDriverError::new(
444                        ErrorStatus::UnknownError,
445                        "Missing Host header".to_string(),
446                    );
447                    return warp::reply::with_status(
448                        serde_json::to_string(&err).unwrap(),
449                        StatusCode::INTERNAL_SERVER_ERROR,
450                    );
451                }
452                if let Some(origin) = origin_header {
453                    let make_err = || {
454                        warn!(
455                            "Rejected request with Origin header {}, allowed values are [{}]",
456                            origin,
457                            allow_origins
458                                .iter()
459                                .map(|x| x.to_string())
460                                .collect::<Vec<_>>()
461                                .join(",")
462                        );
463                        WebDriverError::new(
464                            ErrorStatus::UnknownError,
465                            format!("Invalid Origin header {}", origin),
466                        )
467                    };
468                    let origin_url = match Url::parse(&origin) {
469                        Ok(url) => url,
470                        Err(_) => {
471                            return warp::reply::with_status(
472                                serde_json::to_string(&make_err()).unwrap(),
473                                StatusCode::INTERNAL_SERVER_ERROR,
474                            );
475                        }
476                    };
477                    if !is_origin_allowed(&allow_origins, origin_url) {
478                        return warp::reply::with_status(
479                            serde_json::to_string(&make_err()).unwrap(),
480                            StatusCode::INTERNAL_SERVER_ERROR,
481                        );
482                    }
483                }
484                if method == Method::POST {
485                    // Disallow CORS-safelisted request headers
486                    // c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header
487                    let content_type = content_type_header
488                        .as_ref()
489                        .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x))
490                        .map(|x| x.trim())
491                        .map(|x| x.to_lowercase());
492                    match content_type.as_ref().map(|x| x.as_ref()) {
493                        Some("application/x-www-form-urlencoded")
494                        | Some("multipart/form-data")
495                        | Some("text/plain") => {
496                            warn!(
497                                "Rejected POST request with disallowed content type {}",
498                                content_type.unwrap_or_else(|| "".into())
499                            );
500                            let err = WebDriverError::new(
501                                ErrorStatus::UnknownError,
502                                "Invalid Content-Type",
503                            );
504                            return warp::reply::with_status(
505                                serde_json::to_string(&err).unwrap(),
506                                StatusCode::INTERNAL_SERVER_ERROR,
507                            );
508                        }
509                        Some(_) | None => {}
510                    }
511                }
512                let body = String::from_utf8(body.chunk().to_vec());
513                if body.is_err() {
514                    let err = WebDriverError::new(
515                        ErrorStatus::UnknownError,
516                        "Request body wasn't valid UTF-8",
517                    );
518                    return warp::reply::with_status(
519                        serde_json::to_string(&err).unwrap(),
520                        StatusCode::INTERNAL_SERVER_ERROR,
521                    );
522                }
523                let body = body.unwrap();
524
525                debug!("-> {} {} {}", method, full_path.as_str(), body);
526                let msg_result = WebDriverMessage::from_http(
527                    route.clone(),
528                    &params,
529                    &body,
530                    method == Method::POST,
531                );
532
533                let (status, resp_body) = match msg_result {
534                    Ok(message) => {
535                        let (send_res, recv_res) = channel();
536                        match chan.lock() {
537                            Ok(ref c) => {
538                                let res =
539                                    c.send(DispatchMessage::HandleWebDriver(message, send_res));
540                                match res {
541                                    Ok(x) => x,
542                                    Err(e) => panic!("Error: {:?}", e),
543                                }
544                            }
545                            Err(e) => panic!("Error reading response: {:?}", e),
546                        }
547
548                        match recv_res.recv() {
549                            Ok(data) => match data {
550                                Ok(response) => {
551                                    (StatusCode::OK, serde_json::to_string(&response).unwrap())
552                                }
553                                Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
554                            },
555                            Err(e) => panic!("Error reading response: {:?}", e),
556                        }
557                    }
558                    Err(e) => (e.http_status(), serde_json::to_string(&e).unwrap()),
559                };
560
561                debug!("<- {} {}", status, resp_body);
562                warp::reply::with_status(resp_body, status)
563            },
564        )
565        .with(warp::reply::with::header(
566            http::header::CONTENT_TYPE,
567            "application/json; charset=utf-8",
568        ))
569        .with(warp::reply::with::header(
570            http::header::CACHE_CONTROL,
571            "no-cache",
572        ))
573        .boxed()
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;
    use std::str::FromStr;

    #[test]
    fn test_host_allowed() {
        let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let addr_8000 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 8000);
        let addr_v6_80 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 80);
        let addr_v6_8000 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 8000);

        // Domain names are matched literally against allow_hosts (no DNS
        // resolution happens); only the port is compared with the server.
        let localhost_host = Host::Domain("localhost".to_string());
        let test_host = Host::Domain("example.test".to_string());
        let subdomain_localhost_host = Host::Domain("subdomain.localhost".to_string());

        assert!(is_host_allowed(
            &addr_80,
            &[localhost_host.clone()],
            "localhost:80"
        ));
        assert!(is_host_allowed(
            &addr_80,
            &[test_host.clone()],
            "example.test:80"
        ));
        assert!(is_host_allowed(
            &addr_80,
            &[test_host.clone(), localhost_host.clone()],
            "example.test"
        ));
        assert!(is_host_allowed(
            &addr_80,
            &[subdomain_localhost_host.clone()],
            "subdomain.localhost"
        ));

        // ip address cases: IP literals are accepted without consulting
        // allow_hosts and without comparing against the server's own IP
        // (e.g. an IPv4 literal is accepted by an IPv6 server).
        assert!(is_host_allowed(&addr_80, &[], "127.0.0.1:80"));
        assert!(is_host_allowed(&addr_v6_80, &[], "127.0.0.1"));
        assert!(is_host_allowed(&addr_80, &[], "[::1]"));
        assert!(is_host_allowed(&addr_8000, &[], "127.0.0.1:8000"));
        assert!(is_host_allowed(
            &addr_80,
            &[subdomain_localhost_host.clone()],
            "[::1]"
        ));
        assert!(is_host_allowed(
            &addr_v6_8000,
            &[subdomain_localhost_host.clone()],
            "[::1]:8000"
        ));

        // Mismatch cases

        assert!(!is_host_allowed(&addr_80, &[test_host], "localhost"));

        assert!(!is_host_allowed(&addr_80, &[], "localhost:80"));

        // Port mismatch cases

        assert!(!is_host_allowed(
            &addr_80,
            &[localhost_host.clone()],
            "localhost:8000"
        ));
        assert!(!is_host_allowed(
            &addr_8000,
            &[localhost_host.clone()],
            "localhost"
        ));
        assert!(!is_host_allowed(
            &addr_v6_8000,
            &[localhost_host.clone()],
            "[::1]"
        ));
    }

    #[test]
    fn test_host_malformed() {
        let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let localhost_host = Host::Domain("localhost".to_string());

        // Header values that do not parse as an http authority are rejected.
        assert!(!is_host_allowed(&addr_80, &[localhost_host.clone()], ""));
        assert!(!is_host_allowed(
            &addr_80,
            &[localhost_host],
            "localhost:notaport"
        ));
    }

    #[test]
    fn test_origin_allowed() {
        assert!(is_origin_allowed(
            &[Url::parse("http://localhost").unwrap()],
            Url::parse("http://localhost").unwrap()
        ));
        assert!(is_origin_allowed(
            &[Url::parse("http://localhost").unwrap()],
            Url::parse("http://localhost:80").unwrap()
        ));
        assert!(is_origin_allowed(
            &[
                Url::parse("https://test.example").unwrap(),
                Url::parse("http://localhost").unwrap()
            ],
            Url::parse("http://localhost").unwrap()
        ));
        assert!(is_origin_allowed(
            &[
                Url::parse("https://test.example").unwrap(),
                Url::parse("http://localhost").unwrap()
            ],
            Url::parse("https://test.example:443").unwrap()
        ));
        // Mismatch cases
        assert!(!is_origin_allowed(
            &[],
            Url::parse("http://localhost").unwrap()
        ));
        assert!(!is_origin_allowed(
            &[Url::parse("http://localhost").unwrap()],
            Url::parse("http://localhost:8000").unwrap()
        ));
        assert!(!is_origin_allowed(
            &[Url::parse("https://localhost").unwrap()],
            Url::parse("http://localhost").unwrap()
        ));
        assert!(!is_origin_allowed(
            &[Url::parse("https://example.test").unwrap()],
            Url::parse("http://subdomain.example.test").unwrap()
        ));
    }
}