upnp_client/
device_client.rs

1use std::{collections::HashMap, env, net::TcpListener, sync::Arc, time::Duration};
2
3use crate::{
4    parser::{
5        deserialize_metadata, parse_av_transport_uri_metadata, parse_current_play_mode,
6        parse_current_track_metadata, parse_last_change, parse_location, parse_transport_state,
7    },
8    types::{AVTransportEvent, Device, Event, Service},
9    BROADCAST_EVENT,
10};
11use anyhow::{anyhow, Result};
12use hyper::{
13    server::conn::AddrStream,
14    service::{make_service_fn, service_fn},
15};
16use hyper::{Body, Request, Response, Server};
17use surf::{Client, Config, Url};
18use tokio::sync::Mutex;
19use xml_builder::{XMLBuilder, XMLElement, XMLVersion};
20
21#[derive(Clone)]
22pub struct DeviceClient {
23    base_url: Url,
24    http_client: Client,
25    device: Option<Device>,
26    stop: Arc<Mutex<bool>>,
27}
28
29impl DeviceClient {
30    pub fn new(url: &str) -> Result<Self> {
31        Ok(Self {
32            base_url: Url::parse(url)?,
33            http_client: Config::new()
34                .set_timeout(Some(Duration::from_secs(5)))
35                .try_into()?,
36            device: None,
37            stop: Arc::new(Mutex::new(false)),
38        })
39    }
40
41    pub async fn connect(&mut self) -> Result<Self> {
42        self.device = Some(parse_location(self.base_url.as_str()).await?);
43        Ok(Self {
44            base_url: self.base_url.clone(),
45            http_client: self.http_client.clone(),
46            device: self.device.clone(),
47            stop: self.stop.clone(),
48        })
49    }
50
51    pub fn ip(&self) -> String {
52        self.base_url.host_str().unwrap().to_string()
53    }
54
55    pub async fn call_action(
56        &self,
57        service_id: &str,
58        action_name: &str,
59        params: HashMap<String, String>,
60    ) -> Result<String> {
61        if self.device.is_none() {
62            return Err(anyhow!("Device not connected"));
63        }
64        let service_id = resolve_service(service_id);
65        let service = self.get_service_description(&service_id).await?;
66
67        // check if action is available
68        let action = service.actions.iter().find(|a| a.name == action_name);
69        match action {
70            Some(_) => {
71                self.call_action_internal(&service, action_name, params)
72                    .await
73            }
74            None => Err(anyhow!("Action not found")),
75        }
76    }
77
78    async fn call_action_internal(
79        &self,
80        service: &Service,
81        action_name: &str,
82        params: HashMap<String, String>,
83    ) -> Result<String> {
84        let control_url = Url::parse(&service.control_url)?;
85
86        let mut xml = XMLBuilder::new()
87            .version(XMLVersion::XML1_1)
88            .encoding("UTF-8".into())
89            .build();
90
91        let mut envelope = XMLElement::new("s:Envelope");
92        envelope.add_attribute("xmlns:s", "http://schemas.xmlsoap.org/soap/envelope/");
93        envelope.add_attribute(
94            "s:encodingStyle",
95            "http://schemas.xmlsoap.org/soap/encoding/",
96        );
97
98        let mut body = XMLElement::new("s:Body");
99        let action = format!("u:{}", action_name);
100        let mut action = XMLElement::new(action.as_str());
101        action.add_attribute("xmlns:u", service.service_type.as_str());
102
103        for (name, value) in params {
104            let mut param = XMLElement::new(name.as_str());
105            param.add_text(value).map_err(|e| anyhow!("{:?}", e))?;
106            action.add_child(param).map_err(|e| anyhow!("{:?}", e))?;
107        }
108
109        body.add_child(action).map_err(|e| anyhow!("{:?}", e))?;
110        envelope.add_child(body).map_err(|e| anyhow!("{:?}", e))?;
111
112        xml.set_root_element(envelope);
113
114        let mut writer: Vec<u8> = Vec::new();
115        xml.generate(&mut writer).map_err(|e| anyhow!("{:?}", e))?;
116        let xml = String::from_utf8(writer)?;
117
118        let soap_action = format!("\"{}#{}\"", service.service_type, action_name);
119
120        let mut res = self
121            .http_client
122            .post(control_url)
123            .header("Content-Type", "text/xml; charset=\"utf-8\"")
124            .header("Content-Length", xml.len().to_string())
125            .header("SOAPACTION", soap_action)
126            .header("Connection", "close")
127            .body_string(xml.clone())
128            .send()
129            .await
130            .map_err(|e| anyhow!(e.to_string()))?;
131        res.body_string().await.map_err(|e| anyhow!(e.to_string()))
132    }
133
134    async fn get_service_description(&self, service_id: &str) -> Result<Service> {
135        if let Some(device) = &self.device {
136            let service = device
137                .services
138                .iter()
139                .find(|s| s.service_id == service_id)
140                .ok_or_else(|| {
141                    anyhow!(
142                        "Service with requested service_id {} does not exist",
143                        service_id
144                    )
145                })?;
146            return Ok(service.clone());
147        }
148        Err(anyhow!("Device not connected"))
149    }
150
151    pub async fn subscribe(&mut self, service_id: &str) -> Result<()> {
152        if self.device.is_none() {
153            return Err(anyhow!("Device not connected"));
154        }
155        let service_id = resolve_service(service_id);
156        let service = self.get_service_description(&service_id).await?;
157
158        let user_agent = format!(
159            "upnp-client/{} ({})",
160            env!("CARGO_PKG_VERSION"),
161            env::consts::OS
162        );
163
164        let (address, port) = self.ensure_eventing_server().await?;
165        let callback = format!("<http://{}:{}>", address, port);
166
167        let client = hyper::Client::new();
168        let req = hyper::Request::builder()
169            .method("SUBSCRIBE")
170            .uri(service.event_sub_url.clone())
171            .header("CALLBACK", callback)
172            .header("NT", "upnp:event")
173            .header("TIMEOUT", "Second-1800")
174            .header("USER-AGENT", user_agent)
175            .body(hyper::Body::empty())?;
176        client.request(req).await?;
177        Ok(())
178    }
179
180    pub async fn unsubscribe(&mut self, service_id: &str, sid: &str) -> Result<()> {
181        if self.device.is_none() {
182            return Err(anyhow!("Device not connected"));
183        }
184        let service_id = resolve_service(service_id);
185        let service = self.get_service_description(&service_id).await?;
186        let client = hyper::Client::new();
187        let req = hyper::Request::builder()
188            .method("UNSUBSCRIBE")
189            .uri(service.event_sub_url.clone())
190            .header("SID", sid)
191            .body(hyper::Body::empty())?;
192
193        client.request(req).await?;
194
195        self.release_eventing_server().await?;
196        Ok(())
197    }
198
199    async fn ensure_eventing_server(&mut self) -> Result<(String, u16)> {
200        let addr: &str = "0.0.0.0:0";
201        let listener = TcpListener::bind(addr)?;
202
203        let service = make_service_fn(|_: &AddrStream| async {
204            Ok::<_, hyper::Error>(service_fn(|req: Request<Body>| async move {
205                let sid = req
206                    .headers()
207                    .get("sid")
208                    .unwrap()
209                    .to_str()
210                    .unwrap()
211                    .to_string();
212                let body = hyper::body::to_bytes(req.into_body()).await?;
213                let xml = String::from_utf8(body.to_vec()).unwrap();
214
215                let last_change = parse_last_change(xml.as_str()).unwrap();
216                let last_change = last_change.unwrap_or_default();
217
218                let transport_state = parse_transport_state(last_change.as_str()).unwrap();
219                let play_mode = parse_current_play_mode(last_change.as_str()).unwrap();
220                let av_transport_uri_metadata =
221                    parse_av_transport_uri_metadata(last_change.as_str()).unwrap();
222                let current_track_metadata =
223                    parse_current_track_metadata(last_change.as_str()).unwrap();
224
225                if let Some(state) = transport_state {
226                    let tx = BROADCAST_EVENT.lock().unwrap();
227                    let tx = tx.as_ref();
228                    let ev = AVTransportEvent::TransportState {
229                        sid: sid.clone(),
230                        transport_state: state,
231                    };
232                    tx.unwrap().send(Event::AVTransport(ev)).unwrap();
233                }
234
235                if let Some(mode) = play_mode {
236                    let tx = BROADCAST_EVENT.lock().unwrap();
237                    let tx = tx.as_ref();
238                    let ev = AVTransportEvent::CurrentPlayMode {
239                        sid: sid.clone(),
240                        play_mode: mode,
241                    };
242                    tx.unwrap().send(Event::AVTransport(ev)).unwrap();
243                }
244
245                if let Some(metadata) = av_transport_uri_metadata {
246                    let tx = BROADCAST_EVENT.lock().unwrap();
247                    let tx = tx.as_ref();
248                    let m = deserialize_metadata(metadata.as_str()).unwrap();
249                    let ev = AVTransportEvent::AVTransportURIMetaData {
250                        sid: sid.clone(),
251                        url: m.url,
252                        title: m.title,
253                        artist: m.artist,
254                        album: m.album,
255                        album_art_uri: m.album_art_uri,
256                        genre: m.genre,
257                    };
258                    tx.unwrap().send(Event::AVTransport(ev)).unwrap();
259                }
260
261                if let Some(metadata) = current_track_metadata {
262                    let m = deserialize_metadata(metadata.as_str()).unwrap();
263                    let tx = BROADCAST_EVENT.lock().unwrap();
264                    let tx = tx.as_ref();
265                    let ev = AVTransportEvent::CurrentTrackMetadata {
266                        sid: sid.clone(),
267                        url: m.url,
268                        title: m.title,
269                        artist: m.artist,
270                        album: m.album,
271                        album_art_uri: m.album_art_uri,
272                        genre: m.genre,
273                    };
274                    tx.unwrap().send(Event::AVTransport(ev)).unwrap();
275                }
276
277                Ok::<_, hyper::Error>(Response::new(Body::empty()))
278            }))
279        });
280
281        let server = Server::from_tcp(listener).unwrap().serve(service);
282
283        let address = server.local_addr().ip().to_string();
284        let port = server.local_addr().port();
285
286        let stop = self.stop.clone();
287
288        tokio::spawn(async move {
289            server.await.unwrap();
290        });
291
292        tokio::spawn(async move {
293            while !*stop.lock().await {
294                tokio::time::sleep(Duration::from_millis(100)).await;
295            }
296        });
297
298        Ok((address, port))
299    }
300
301    async fn release_eventing_server(&mut self) -> Result<()> {
302        let mut stop = self.stop.lock().await;
303        *stop = true;
304        Ok(())
305    }
306}
307
/// Expands a short service name such as `"AVTransport"` into a full UPnP service id.
///
/// Any id that already contains a `:` is treated as fully qualified and returned as-is.
fn resolve_service(service_id: &str) -> String {
    if service_id.contains(':') {
        return service_id.to_string();
    }
    format!("urn:upnp-org:serviceId:{}", service_id)
}