Skip to main content

unifly_api/session/
client.rs

// Session API HTTP client
//
// Wraps `reqwest::Client` with UniFi-specific URL construction, envelope
// unwrapping, and platform-aware path prefixing. All endpoint modules
// (devices, clients, etc.) are implemented as inherent methods via
// separate files to keep this module focused on transport mechanics.
7
8use std::sync::{Arc, RwLock};
9
10use reqwest::cookie::{CookieStore, Jar};
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use tracing::{debug, trace};
14use url::Url;
15
16use crate::auth::ControllerPlatform;
17use crate::error::Error;
18use crate::session::models::SessionResponse;
19use crate::transport::TransportConfig;
20
21/// UniFi OS wraps some errors as `{"error":{"code":N,"message":"..."}}` with HTTP 200.
22#[derive(serde::Deserialize)]
23struct UnifiOsError {
24    error: Option<UnifiOsErrorInner>,
25}
26
27#[derive(serde::Deserialize)]
28struct UnifiOsErrorInner {
29    code: u16,
30    message: Option<String>,
31}
32
/// Authentication strategy a [`SessionClient`] uses against the controller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionAuth {
    /// Session cookie obtained through a username/password login.
    Cookie,
    /// `X-API-KEY` header authentication; no session cookie is held, so
    /// some session endpoints (e.g. `stat/event`) cannot be reached.
    ApiKey,
}
42
43/// Raw HTTP client for the UniFi controller's session API.
44///
45/// Handles the `{ data: [], meta: { rc, msg } }` envelope, site-scoped
46/// URL construction, and platform-aware path prefixing. All methods return
47/// unwrapped `data` payloads -- the envelope is stripped before the caller
48/// sees it.
49pub struct SessionClient {
50    http: reqwest::Client,
51    base_url: Url,
52    site: String,
53    platform: ControllerPlatform,
54    auth: SessionAuth,
55    /// CSRF token for UniFi OS. Required on all POST/PUT/DELETE requests
56    /// through the `/proxy/network/` path. Captured from login response
57    /// headers and rotated via `X-Updated-CSRF-Token`.
58    csrf_token: RwLock<Option<String>>,
59    /// Cookie jar reference for extracting session cookies (e.g. for WebSocket auth).
60    cookie_jar: Option<Arc<Jar>>,
61}
62
63impl SessionClient {
64    /// Create a new session client from a `TransportConfig`.
65    ///
66    /// If the config doesn't already include a cookie jar, one is created
67    /// automatically (session auth requires cookies). The `base_url` should be
68    /// the controller root (e.g. `https://192.168.1.1` for UniFi OS or
69    /// `https://controller:8443` for standalone).
70    pub fn new(
71        base_url: Url,
72        site: String,
73        platform: ControllerPlatform,
74        transport: &TransportConfig,
75    ) -> Result<Self, Error> {
76        let config = if transport.cookie_jar.is_some() {
77            transport.clone()
78        } else {
79            transport.clone().with_cookie_jar()
80        };
81        let cookie_jar = config.cookie_jar.clone();
82        let http = config.build_client()?;
83        Ok(Self {
84            http,
85            base_url,
86            site,
87            platform,
88            auth: SessionAuth::Cookie,
89            csrf_token: RwLock::new(None),
90            cookie_jar,
91        })
92    }
93
94    /// Create a session client with a pre-built `reqwest::Client`.
95    ///
96    /// Use this when you already have a client with a session cookie in its
97    /// jar (e.g. after authenticating via a shared client).
98    pub fn with_client(
99        http: reqwest::Client,
100        base_url: Url,
101        site: String,
102        platform: ControllerPlatform,
103        auth: SessionAuth,
104    ) -> Self {
105        Self {
106            http,
107            base_url,
108            site,
109            platform,
110            auth,
111            csrf_token: RwLock::new(None),
112            cookie_jar: None,
113        }
114    }
115
116    /// The authentication method used by this client.
117    pub fn auth(&self) -> SessionAuth {
118        self.auth
119    }
120
121    /// The current site identifier.
122    pub fn site(&self) -> &str {
123        &self.site
124    }
125
126    /// The underlying HTTP client (for auth flows that need direct access).
127    pub fn http(&self) -> &reqwest::Client {
128        &self.http
129    }
130
131    /// The controller base URL.
132    pub fn base_url(&self) -> &Url {
133        &self.base_url
134    }
135
136    /// The detected controller platform.
137    pub fn platform(&self) -> ControllerPlatform {
138        self.platform
139    }
140
141    /// Extract the session cookie header value for WebSocket auth.
142    ///
143    /// Returns the `Cookie` header string (e.g. `"TOKEN=abc123"`) if a
144    /// cookie jar is available and contains cookies for the controller URL.
145    pub fn cookie_header(&self) -> Option<String> {
146        let jar = self.cookie_jar.as_ref()?;
147        let cookies = jar.cookies(&self.base_url)?;
148        cookies.to_str().ok().map(String::from)
149    }
150
151    // ── Cookie injection (for MFA flow) ───────────────────────────────
152
153    /// Inject a `Set-Cookie` header value into the client's cookie jar.
154    ///
155    /// Used by the MFA flow to inject the `UBIC_2FA` cookie before retrying
156    /// login with the TOTP token.
157    pub(crate) fn add_cookie(&self, set_cookie_value: &str, url: &Url) -> Result<(), Error> {
158        let jar = self
159            .cookie_jar
160            .as_ref()
161            .ok_or_else(|| Error::Authentication {
162                message: "no cookie jar available for MFA flow".into(),
163            })?;
164        let header_value: reqwest::header::HeaderValue =
165            set_cookie_value
166                .parse()
167                .map_err(|_| Error::Authentication {
168                    message: "failed to parse MFA cookie value".into(),
169                })?;
170        jar.set_cookies(&mut std::iter::once(&header_value), url);
171        Ok(())
172    }
173
174    // ── CSRF token management ─────────────────────────────────────────
175
176    /// Read the current CSRF token value (for session caching).
177    pub(crate) fn csrf_token_value(&self) -> Option<String> {
178        self.csrf_token.read().expect("CSRF lock poisoned").clone()
179    }
180
181    /// Store a CSRF token (captured from login response headers).
182    pub(crate) fn set_csrf_token(&self, token: String) {
183        debug!("storing CSRF token");
184        *self.csrf_token.write().expect("CSRF lock poisoned") = Some(token);
185    }
186
187    /// Update CSRF token if the response contains a rotated value.
188    fn update_csrf_from_response(&self, headers: &reqwest::header::HeaderMap) {
189        // UniFi OS may rotate tokens — prefer the updated one.
190        let new_token = headers
191            .get("X-Updated-CSRF-Token")
192            .or_else(|| headers.get("x-csrf-token"))
193            .and_then(|v| v.to_str().ok())
194            .map(String::from);
195
196        if let Some(token) = new_token {
197            trace!("CSRF token rotated");
198            *self.csrf_token.write().expect("CSRF lock poisoned") = Some(token);
199        }
200    }
201
202    /// Apply the stored CSRF token to a request builder.
203    fn apply_csrf(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
204        let guard = self.csrf_token.read().expect("CSRF lock poisoned");
205        match guard.as_deref() {
206            Some(token) => builder.header("X-CSRF-Token", token),
207            None => builder,
208        }
209    }
210
211    /// Classify a session 401 based on the active auth strategy.
212    ///
213    /// Cookie-backed session clients surface 401s as an expired session.
214    /// API-key clients surface the same status as a rejected key.
215    fn unauthorized_error(&self) -> Error {
216        match self.auth {
217            SessionAuth::Cookie => Error::SessionExpired,
218            SessionAuth::ApiKey => Error::InvalidApiKey,
219        }
220    }
221
222    // ── URL builders ─────────────────────────────────────────────────
223
224    /// Build a full URL for a controller-level API path.
225    ///
226    /// Applies the platform-specific session prefix, then appends `/api/{path}`.
227    /// For example, on UniFi OS: `https://host/proxy/network/api/{path}`
228    pub(crate) fn api_url(&self, path: &str) -> Url {
229        let prefix = self.platform.session_prefix().unwrap_or("");
230        let base = self.base_url.as_str().trim_end_matches('/');
231        let prefix = prefix.trim_end_matches('/');
232        let full = format!("{base}{prefix}/api/{path}");
233        Url::parse(&full).expect("invalid API URL")
234    }
235
236    /// Build a site-scoped URL: `{base}{prefix}/api/s/{site}/{path}`
237    ///
238    /// Most session endpoints are site-scoped: stat/device, cmd/devmgr, etc.
239    pub(crate) fn site_url(&self, path: &str) -> Url {
240        let prefix = self.platform.session_prefix().unwrap_or("");
241        let base = self.base_url.as_str().trim_end_matches('/');
242        let prefix = prefix.trim_end_matches('/');
243        let full = format!("{base}{prefix}/api/s/{}/{path}", self.site);
244        Url::parse(&full).expect("invalid site URL")
245    }
246
247    /// Build a v2 site-scoped URL: `{base}{prefix}/v2/api/site/{site}/{path}`
248    ///
249    /// Used by newer endpoints (Network Application 9+) that use the v2 path
250    /// format, e.g. traffic-flow-latest-statistics.
251    pub(crate) fn site_url_v2(&self, path: &str) -> Url {
252        let prefix = self.platform.session_prefix().unwrap_or("");
253        let base = self.base_url.as_str().trim_end_matches('/');
254        let prefix = prefix.trim_end_matches('/');
255        let full = format!("{base}{prefix}/v2/api/site/{}/{path}", self.site);
256        Url::parse(&full).expect("invalid v2 site URL")
257    }
258
259    // ── Request helpers ──────────────────────────────────────────────
260
261    /// Send a GET request and unwrap the session envelope.
262    pub(crate) async fn get<T: DeserializeOwned>(&self, url: Url) -> Result<Vec<T>, Error> {
263        debug!("GET {}", url);
264
265        let resp = self.http.get(url).send().await.map_err(Error::Transport)?;
266
267        self.parse_envelope(resp).await
268    }
269
270    /// Send a GET request and return the raw JSON response (no envelope unwrapping).
271    ///
272    /// Used for v2 API endpoints that return plain JSON instead of the
273    /// session `{ meta, data }` envelope.
274    pub(crate) async fn get_raw(&self, url: Url) -> Result<serde_json::Value, Error> {
275        debug!("GET (raw) {}", url);
276
277        let resp = self.http.get(url).send().await.map_err(Error::Transport)?;
278        let status = resp.status();
279
280        if status == reqwest::StatusCode::UNAUTHORIZED {
281            return Err(self.unauthorized_error());
282        }
283        if !status.is_success() {
284            let body = resp.text().await.unwrap_or_default();
285            return Err(Error::SessionApi {
286                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
287            });
288        }
289
290        let body = resp.text().await.map_err(Error::Transport)?;
291        serde_json::from_str(&body).map_err(|e| Error::Deserialization {
292            message: format!("{e}"),
293            body,
294        })
295    }
296
297    /// Send a POST request with JSON body and unwrap the session envelope.
298    pub(crate) async fn post<T: DeserializeOwned>(
299        &self,
300        url: Url,
301        body: &(impl Serialize + Sync),
302    ) -> Result<Vec<T>, Error> {
303        debug!("POST {}", url);
304
305        let builder = self.apply_csrf(self.http.post(url).json(body));
306        let resp = builder.send().await.map_err(Error::Transport)?;
307
308        self.parse_envelope(resp).await
309    }
310
311    /// Send a PUT request with JSON body and unwrap the session envelope.
312    #[allow(dead_code)]
313    pub(crate) async fn put<T: DeserializeOwned>(
314        &self,
315        url: Url,
316        body: &(impl Serialize + Sync),
317    ) -> Result<Vec<T>, Error> {
318        debug!("PUT {}", url);
319
320        let builder = self.apply_csrf(self.http.put(url).json(body));
321        let resp = builder.send().await.map_err(Error::Transport)?;
322
323        self.parse_envelope(resp).await
324    }
325
326    /// Send a DELETE request and unwrap the session envelope.
327    #[allow(dead_code)]
328    pub(crate) async fn delete<T: DeserializeOwned>(&self, url: Url) -> Result<Vec<T>, Error> {
329        debug!("DELETE {}", url);
330
331        let builder = self.apply_csrf(self.http.delete(url));
332        let resp = builder.send().await.map_err(Error::Transport)?;
333
334        self.parse_envelope(resp).await
335    }
336
337    /// Send a raw GET to an arbitrary path (no envelope unwrapping).
338    ///
339    /// The `path` is appended directly after `{base}{prefix}/`.
340    pub async fn raw_get(&self, path: &str) -> Result<serde_json::Value, Error> {
341        let prefix = self.platform.session_prefix().unwrap_or("");
342        let base = self.base_url.as_str().trim_end_matches('/');
343        let prefix = prefix.trim_end_matches('/');
344        let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
345        self.get_raw(url).await
346    }
347
348    /// Send a raw POST to an arbitrary path (no envelope unwrapping).
349    pub async fn raw_post(
350        &self,
351        path: &str,
352        body: &serde_json::Value,
353    ) -> Result<serde_json::Value, Error> {
354        let prefix = self.platform.session_prefix().unwrap_or("");
355        let base = self.base_url.as_str().trim_end_matches('/');
356        let prefix = prefix.trim_end_matches('/');
357        let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
358        debug!("POST (raw) {}", url);
359
360        let builder = self.apply_csrf(self.http.post(url).json(body));
361        let resp = builder.send().await.map_err(Error::Transport)?;
362        let status = resp.status();
363
364        if status == reqwest::StatusCode::UNAUTHORIZED {
365            return Err(self.unauthorized_error());
366        }
367        if !status.is_success() {
368            let body = resp.text().await.unwrap_or_default();
369            return Err(Error::SessionApi {
370                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
371            });
372        }
373
374        let body = resp.text().await.map_err(Error::Transport)?;
375        serde_json::from_str(&body).map_err(|e| Error::Deserialization {
376            message: format!("{e}"),
377            body,
378        })
379    }
380
381    /// Send a raw PUT to an arbitrary path (no envelope unwrapping).
382    pub async fn raw_put(
383        &self,
384        path: &str,
385        body: &serde_json::Value,
386    ) -> Result<serde_json::Value, Error> {
387        let prefix = self.platform.session_prefix().unwrap_or("");
388        let base = self.base_url.as_str().trim_end_matches('/');
389        let prefix = prefix.trim_end_matches('/');
390        let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
391        debug!("PUT (raw) {}", url);
392
393        let builder = self.apply_csrf(self.http.put(url).json(body));
394        let resp = builder.send().await.map_err(Error::Transport)?;
395        let status = resp.status();
396
397        if status == reqwest::StatusCode::UNAUTHORIZED {
398            return Err(self.unauthorized_error());
399        }
400        if !status.is_success() {
401            let body = resp.text().await.unwrap_or_default();
402            return Err(Error::SessionApi {
403                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
404            });
405        }
406
407        let body = resp.text().await.map_err(Error::Transport)?;
408        serde_json::from_str(&body).map_err(|e| Error::Deserialization {
409            message: format!("{e}"),
410            body,
411        })
412    }
413
414    /// Send a raw PATCH to an arbitrary path (no envelope unwrapping).
415    pub async fn raw_patch(
416        &self,
417        path: &str,
418        body: &serde_json::Value,
419    ) -> Result<serde_json::Value, Error> {
420        let prefix = self.platform.session_prefix().unwrap_or("");
421        let base = self.base_url.as_str().trim_end_matches('/');
422        let prefix = prefix.trim_end_matches('/');
423        let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
424        debug!("PATCH (raw) {}", url);
425
426        let builder = self.apply_csrf(self.http.patch(url).json(body));
427        let resp = builder.send().await.map_err(Error::Transport)?;
428        let status = resp.status();
429
430        if status == reqwest::StatusCode::UNAUTHORIZED {
431            return Err(self.unauthorized_error());
432        }
433        if !status.is_success() {
434            let body = resp.text().await.unwrap_or_default();
435            return Err(Error::SessionApi {
436                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
437            });
438        }
439
440        let body = resp.text().await.map_err(Error::Transport)?;
441        serde_json::from_str(&body).map_err(|e| Error::Deserialization {
442            message: format!("{e}"),
443            body,
444        })
445    }
446
447    /// Send a raw DELETE to an arbitrary path (no envelope unwrapping).
448    pub async fn raw_delete(&self, path: &str) -> Result<(), Error> {
449        let prefix = self.platform.session_prefix().unwrap_or("");
450        let base = self.base_url.as_str().trim_end_matches('/');
451        let prefix = prefix.trim_end_matches('/');
452        let url = Url::parse(&format!("{base}{prefix}/{path}")).expect("invalid raw URL");
453        debug!("DELETE (raw) {}", url);
454
455        let builder = self.apply_csrf(self.http.delete(url));
456        let resp = builder.send().await.map_err(Error::Transport)?;
457        let status = resp.status();
458
459        if status == reqwest::StatusCode::UNAUTHORIZED {
460            return Err(self.unauthorized_error());
461        }
462        if !status.is_success() {
463            let body = resp.text().await.unwrap_or_default();
464            return Err(Error::SessionApi {
465                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
466            });
467        }
468
469        Ok(())
470    }
471
472    /// Parse the `{ meta, data }` envelope, returning `data` on success
473    /// or an `Error::SessionApi` if `meta.rc != "ok"`.
474    ///
475    /// Also handles UniFi OS error responses that use a different shape:
476    /// `{"error": {"code": 403, "message": "..."}}` (returned with HTTP 200).
477    async fn parse_envelope<T: DeserializeOwned>(
478        &self,
479        resp: reqwest::Response,
480    ) -> Result<Vec<T>, Error> {
481        let status = resp.status();
482
483        // Capture any CSRF token rotation before consuming the response.
484        self.update_csrf_from_response(resp.headers());
485
486        if status == reqwest::StatusCode::UNAUTHORIZED {
487            return Err(self.unauthorized_error());
488        }
489
490        if status == reqwest::StatusCode::FORBIDDEN {
491            return Err(Error::SessionApi {
492                message: "insufficient permissions (HTTP 403)".into(),
493            });
494        }
495
496        if !status.is_success() {
497            let body = resp.text().await.unwrap_or_default();
498            return Err(Error::SessionApi {
499                message: format!("HTTP {status}: {}", &body[..body.len().min(200)]),
500            });
501        }
502
503        let body = resp.text().await.map_err(Error::Transport)?;
504
505        // UniFi OS sometimes returns `{"error":{"code":N,"message":"..."}}` with HTTP 200.
506        if let Ok(wrapper) = serde_json::from_str::<UnifiOsError>(&body)
507            && let Some(err) = wrapper.error
508        {
509            let msg = err.message.unwrap_or_default();
510            return Err(if err.code == 401 {
511                if msg.is_empty() {
512                    self.unauthorized_error()
513                } else {
514                    match self.unauthorized_error() {
515                        Error::SessionExpired => Error::Authentication {
516                            message: format!("session expired: {msg}"),
517                        },
518                        Error::InvalidApiKey => Error::Authentication {
519                            message: format!("API key rejected: {msg}"),
520                        },
521                        other => other,
522                    }
523                }
524            } else {
525                Error::SessionApi {
526                    message: format!("UniFi OS error {}: {msg}", err.code),
527                }
528            });
529        }
530
531        let envelope: SessionResponse<T> = serde_json::from_str(&body).map_err(|e| {
532            let preview = &body[..body.len().min(200)];
533            Error::Deserialization {
534                message: format!("{e} (body preview: {preview:?})"),
535                body: body.clone(),
536            }
537        })?;
538
539        match envelope.meta.rc.as_str() {
540            "ok" => Ok(envelope.data),
541            _ => Err(Error::SessionApi {
542                message: envelope
543                    .meta
544                    .msg
545                    .unwrap_or_else(|| format!("rc={}", envelope.meta.rc)),
546            }),
547        }
548    }
549}
550
#[cfg(test)]
mod tests {
    use url::Url;

    use super::{SessionAuth, SessionClient};
    use crate::{ControllerPlatform, Error};

    /// Build a classic-controller session client for the "default" site
    /// using the given auth strategy.
    fn make_client(auth: SessionAuth) -> SessionClient {
        let base = Url::parse("https://controller.example").expect("valid test URL");
        let site = String::from("default");
        SessionClient::with_client(
            reqwest::Client::new(),
            base,
            site,
            ControllerPlatform::ClassicController,
            auth,
        )
    }

    #[test]
    fn unauthorized_cookie_client_reports_session_expired() {
        let err = make_client(SessionAuth::Cookie).unauthorized_error();
        assert!(matches!(err, Error::SessionExpired));
    }

    #[test]
    fn unauthorized_api_key_client_reports_invalid_api_key() {
        let err = make_client(SessionAuth::ApiKey).unauthorized_error();
        assert!(matches!(err, Error::InvalidApiKey));
    }
}