ssdp_client/
search.rs

1use crate::{Error, SearchTarget};
2
3use futures_core::stream::Stream;
4use genawaiter::sync::{Co, Gen};
5use std::{collections::HashMap, net::SocketAddr, time::Duration};
6use tokio::net::UdpSocket;
7
/// Multicast TTL applied to the search socket when the caller passes `None`.
const DEFAULT_SEARCH_TTL: u32 = 2;
/// Warning logged when a received datagram is too large for the receive buffer.
const INSUFFICIENT_BUFFER_MSG: &str = "buffer size too small, udp packets lost";
10
11#[derive(Debug)]
12/// Response given by ssdp control point
13pub struct SearchResponse {
14    location: String,
15    st: SearchTarget,
16    usn: String,
17    server: String,
18    extra_headers: HashMap<String, String>,
19}
20
21impl SearchResponse {
22    /// URL of the control point
23    pub fn location(&self) -> &str {
24        &self.location
25    }
26    /// search target returned by the control point
27    pub fn search_target(&self) -> &SearchTarget {
28        &self.st
29    }
30    /// Unique Service Name
31    pub fn usn(&self) -> &str {
32        &self.usn
33    }
34    /// Server (user agent)
35    pub fn server(&self) -> &str {
36        &self.server
37    }
38    /// Other Custom header
39    pub fn extra_header(&self, key: &str) -> Option<&str> {
40        self.extra_headers.get(key).map(|x| x.as_str())
41    }
42}
43
/// Address the search socket binds to: the unspecified IPv4 address with port 0,
/// so the OS picks the port.
#[cfg(not(windows))]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let unspecified = std::net::Ipv4Addr::UNSPECIFIED;
    Ok(SocketAddr::new(unspecified.into(), 0))
}
48
/// Determines the local address the search socket should bind to on Windows.
///
/// Windows 10 is multihomed, so the address used for the multicast send is not
/// guaranteed to be your local IP address; it can be any of the virtual interfaces.
/// To pin it down, a throwaway UDP socket is bound to `0.0.0.0:0` and connected to
/// `8.8.8.8:80`. Its resulting local address, including that socket's port, is
/// returned. This relies on the OS selecting the routing interface on `connect`.
/// Thanks to @dheijl for figuring this out <3
/// (<https://github.com/jakobhellermann/ssdp-client/issues/3#issuecomment-687098826>)
#[cfg(windows)]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let probe = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0))).await?;
    probe.connect(SocketAddr::from(([8, 8, 8, 8], 80))).await?;
    probe.local_addr()
}
61
62/// Search for SSDP control points within a network.
63/// Control Points will wait a random amount of time between 0 and mx seconds before responing to avoid flooding the requester with responses.
64/// Therefore, the timeout should be at least mx seconds.
65pub async fn search(
66    search_target: &SearchTarget,
67    timeout: Duration,
68    mx: usize,
69    ttl: Option<u32>,
70) -> Result<impl Stream<Item = Result<SearchResponse, Error>>, Error> {
71    let bind_addr: SocketAddr = get_bind_addr().await?;
72    let broadcast_address: SocketAddr = ([239, 255, 255, 250], 1900).into();
73
74    let socket = UdpSocket::bind(&bind_addr).await?;
75    socket
76        .set_multicast_ttl_v4(ttl.unwrap_or(DEFAULT_SEARCH_TTL))
77        .ok();
78
79    let msg = format!(
80        "M-SEARCH * HTTP/1.1\r
81Host:239.255.255.250:1900\r
82Man:\"ssdp:discover\"\r
83ST: {search_target}\r
84MX: {mx}\r\n\r\n"
85    );
86    socket.send_to(msg.as_bytes(), &broadcast_address).await?;
87
88    Ok(Gen::new(move |co| socket_stream(socket, timeout, co)))
89}
90
/// Unwraps a `Result` inside the stream-generator loop.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)`, the error is converted with
/// `.into()` and yielded to the consumer through `$co`. The enclosing loop then skips
/// to its next iteration via `continue`, so a single bad datagram does not end the stream.
///
/// Must be invoked inside a loop within an async context, because it uses `.await`.
macro_rules! yield_try {
    ( $co:expr => $result:expr ) => {
        match $result {
            Err(err) => {
                $co.yield_(Err(err.into())).await;
                continue;
            }
            Ok(value) => value,
        }
    };
}
102
103async fn socket_stream(
104    socket: UdpSocket,
105    timeout: Duration,
106    co: Co<Result<SearchResponse, Error>>,
107) {
108    loop {
109        let mut buf = [0u8; 2048];
110        let text = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await {
111            Err(_) => break,
112            Ok(res) => match res {
113                Ok(2024) => {
114                    log::warn!("{INSUFFICIENT_BUFFER_MSG}");
115                    continue;
116                }
117                Ok(read) => yield_try!(co => std::str::from_utf8(&buf[..read])),
118                Err(e) => {
119                    co.yield_(Err(e.into())).await;
120                    continue;
121                }
122            },
123        };
124
125        let headers = yield_try!(co => parse_headers(text));
126
127        let mut location = None;
128        let mut st = None;
129        let mut usn = None;
130        let mut server = None;
131        let mut extra_headers = HashMap::new();
132
133        for (header, value) in headers {
134            if header.eq_ignore_ascii_case("location") {
135                location = Some(value);
136            } else if header.eq_ignore_ascii_case("st") {
137                st = Some(value);
138            } else if header.eq_ignore_ascii_case("usn") {
139                usn = Some(value);
140            } else if header.eq_ignore_ascii_case("server") {
141                server = Some(value);
142            } else {
143                extra_headers.insert(header.to_owned(), value.to_owned());
144            }
145        }
146
147        let location = yield_try!(co => location
148            .ok_or(Error::MissingHeader("location")))
149        .to_string();
150        let st = yield_try!(co => yield_try!(co => st.ok_or(Error::MissingHeader("st"))).parse::<SearchTarget>());
151        let usn = yield_try!(co => usn.ok_or(Error::MissingHeader("urn"))).to_string();
152        let server = yield_try!(co => server.ok_or(Error::MissingHeader("server"))).to_string();
153
154        co.yield_(Ok(SearchResponse {
155            location,
156            st,
157            usn,
158            server,
159            extra_headers,
160        }))
161        .await;
162    }
163}
164
165fn parse_headers(response: &str) -> Result<impl Iterator<Item = (&str, &str)>, Error> {
166    let mut response = response.split("\r\n");
167    let status_code = response
168        .next()
169        .ok_or(Error::InvalidHTTP("http response is empty"))?
170        .trim_start_matches("HTTP/1.1 ")
171        .chars()
172        .take_while(|x| x.is_numeric())
173        .collect::<String>()
174        .parse::<u32>()
175        .map_err(|_| Error::InvalidHTTP("status code is not a number"))?;
176
177    if status_code != 200 {
178        return Err(Error::HTTPError(status_code));
179    }
180
181    let iter = response.filter_map(|l| {
182        let mut split = l.splitn(2, ':');
183        match (split.next(), split.next()) {
184            (Some(header), Some(value)) => Some((header, value.trim())),
185            _ => None,
186        }
187    });
188
189    Ok(iter)
190}