1use crate::{Error, SearchTarget};
2
3use futures_core::stream::Stream;
4use genawaiter::sync::{Co, Gen};
5use std::{collections::HashMap, net::SocketAddr, time::Duration};
6use tokio::net::UdpSocket;
7
/// Multicast TTL applied to the `M-SEARCH` request when the caller passes `ttl: None`.
const DEFAULT_SEARCH_TTL: u32 = 2;
/// Warning logged by the receive loop when a datagram appears not to have fit in the
/// receive buffer (and was therefore dropped).
const INSUFFICIENT_BUFFER_MSG: &str = "buffer size too small, udp packets lost";
10
11#[derive(Debug)]
12pub struct SearchResponse {
14 location: String,
15 st: SearchTarget,
16 usn: String,
17 server: String,
18 extra_headers: HashMap<String, String>,
19}
20
21impl SearchResponse {
22 pub fn location(&self) -> &str {
24 &self.location
25 }
26 pub fn search_target(&self) -> &SearchTarget {
28 &self.st
29 }
30 pub fn usn(&self) -> &str {
32 &self.usn
33 }
34 pub fn server(&self) -> &str {
36 &self.server
37 }
38 pub fn extra_header(&self, key: &str) -> Option<&str> {
40 self.extra_headers.get(key).map(|x| x.as_str())
41 }
42}
43
/// Local address the search socket binds to on non-Windows platforms.
///
/// Always `0.0.0.0:0` (any interface, OS-chosen port). This never fails; the `Result`
/// return type only mirrors the Windows variant's signature.
#[cfg(not(windows))]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let unspecified = SocketAddr::from(([0, 0, 0, 0], 0));
    Ok(unspecified)
}
48
/// Local address the search socket binds to on Windows.
///
/// A throw-away UDP socket is connected to a public address (8.8.8.8:80) so that the
/// OS routing table picks the outbound interface; connecting a UDP socket only records
/// the default peer and sends no traffic. Only the interface IP is kept: the port is
/// reset to 0 so the real search socket gets its own ephemeral port instead of
/// rebinding the probe socket's port after the probe is dropped.
///
/// If the host has no route to that address (e.g. an isolated LAN without a default
/// gateway), this falls back to `0.0.0.0:0`, the address used on other platforms,
/// instead of failing the whole search. SSDP only needs the local link.
///
/// # Errors
/// Returns an error if the probe socket cannot be bound or its local address cannot
/// be read.
#[cfg(windows)]
async fn get_bind_addr() -> Result<SocketAddr, std::io::Error> {
    let any: SocketAddr = ([0, 0, 0, 0], 0).into();
    let socket = UdpSocket::bind(any).await?;
    let googledns: SocketAddr = ([8, 8, 8, 8], 80).into();
    if let Err(err) = socket.connect(googledns).await {
        log::warn!("no default route to pick an interface ({err}); binding to {any}");
        return Ok(any);
    }

    let mut bind_addr = socket.local_addr()?;
    // Keep only the interface IP; let the OS pick a fresh port for the real socket.
    bind_addr.set_port(0);

    Ok(bind_addr)
}
61
62pub async fn search(
66 search_target: &SearchTarget,
67 timeout: Duration,
68 mx: usize,
69 ttl: Option<u32>,
70) -> Result<impl Stream<Item = Result<SearchResponse, Error>>, Error> {
71 let bind_addr: SocketAddr = get_bind_addr().await?;
72 let broadcast_address: SocketAddr = ([239, 255, 255, 250], 1900).into();
73
74 let socket = UdpSocket::bind(&bind_addr).await?;
75 socket
76 .set_multicast_ttl_v4(ttl.unwrap_or(DEFAULT_SEARCH_TTL))
77 .ok();
78
79 let msg = format!(
80 "M-SEARCH * HTTP/1.1\r
81Host:239.255.255.250:1900\r
82Man:\"ssdp:discover\"\r
83ST: {search_target}\r
84MX: {mx}\r\n\r\n"
85 );
86 socket.send_to(msg.as_bytes(), &broadcast_address).await?;
87
88 Ok(Gen::new(move |co| socket_stream(socket, timeout, co)))
89}
90
/// `yield_try!(co => expr)` — a `?`-like helper for the generator receive loop.
///
/// Evaluates to the `Ok` value of `expr`. On `Err`, the error is converted with `Into`,
/// yielded through the coroutine `co` (awaiting the yield), and the enclosing loop jumps
/// to its next iteration with `continue`. It therefore only works inside a loop in an
/// async context.
macro_rules! yield_try {
    ( $co:expr => $expr:expr ) => {
        match $expr {
            Err(err) => {
                $co.yield_(Err(err.into())).await;
                continue;
            }
            Ok(value) => value,
        }
    };
}
102
103async fn socket_stream(
104 socket: UdpSocket,
105 timeout: Duration,
106 co: Co<Result<SearchResponse, Error>>,
107) {
108 loop {
109 let mut buf = [0u8; 2048];
110 let text = match tokio::time::timeout(timeout, socket.recv(&mut buf)).await {
111 Err(_) => break,
112 Ok(res) => match res {
113 Ok(2024) => {
114 log::warn!("{INSUFFICIENT_BUFFER_MSG}");
115 continue;
116 }
117 Ok(read) => yield_try!(co => std::str::from_utf8(&buf[..read])),
118 Err(e) => {
119 co.yield_(Err(e.into())).await;
120 continue;
121 }
122 },
123 };
124
125 let headers = yield_try!(co => parse_headers(text));
126
127 let mut location = None;
128 let mut st = None;
129 let mut usn = None;
130 let mut server = None;
131 let mut extra_headers = HashMap::new();
132
133 for (header, value) in headers {
134 if header.eq_ignore_ascii_case("location") {
135 location = Some(value);
136 } else if header.eq_ignore_ascii_case("st") {
137 st = Some(value);
138 } else if header.eq_ignore_ascii_case("usn") {
139 usn = Some(value);
140 } else if header.eq_ignore_ascii_case("server") {
141 server = Some(value);
142 } else {
143 extra_headers.insert(header.to_owned(), value.to_owned());
144 }
145 }
146
147 let location = yield_try!(co => location
148 .ok_or(Error::MissingHeader("location")))
149 .to_string();
150 let st = yield_try!(co => yield_try!(co => st.ok_or(Error::MissingHeader("st"))).parse::<SearchTarget>());
151 let usn = yield_try!(co => usn.ok_or(Error::MissingHeader("urn"))).to_string();
152 let server = yield_try!(co => server.ok_or(Error::MissingHeader("server"))).to_string();
153
154 co.yield_(Ok(SearchResponse {
155 location,
156 st,
157 usn,
158 server,
159 extra_headers,
160 }))
161 .await;
162 }
163}
164
165fn parse_headers(response: &str) -> Result<impl Iterator<Item = (&str, &str)>, Error> {
166 let mut response = response.split("\r\n");
167 let status_code = response
168 .next()
169 .ok_or(Error::InvalidHTTP("http response is empty"))?
170 .trim_start_matches("HTTP/1.1 ")
171 .chars()
172 .take_while(|x| x.is_numeric())
173 .collect::<String>()
174 .parse::<u32>()
175 .map_err(|_| Error::InvalidHTTP("status code is not a number"))?;
176
177 if status_code != 200 {
178 return Err(Error::HTTPError(status_code));
179 }
180
181 let iter = response.filter_map(|l| {
182 let mut split = l.splitn(2, ':');
183 match (split.next(), split.next()) {
184 (Some(header), Some(value)) => Some((header, value.trim())),
185 _ => None,
186 }
187 });
188
189 Ok(iter)
190}