stac_client/
client.rs

1//! The core STAC API client and search builder.
2
3use crate::error::{Error, Result};
4use crate::models::{
5    Asset, Catalog, Collection, Conformance, DownloadedAsset, FieldsFilter, Item, ItemCollection,
6    SearchParams, SortBy, SortDirection,
7};
8use reqwest;
9use serde_json;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::OnceCell;
13use url::Url;
14
15/// An async client for a STAC API.
16///
17/// This client provides methods for interacting with a STAC-compliant API,
18/// allowing you to fetch `Catalog`, `Collection`, and `Item` objects, and to
19/// perform searches.
20///
21/// The client is inexpensive to clone, as it wraps its internal state in an `Arc`.
22#[derive(Debug, Clone)]
23pub struct Client {
24    inner: Arc<ClientInner>,
25}
26
27#[derive(Debug)]
28struct ClientInner {
29    base_url: Url,
30    client: reqwest::Client,
31    conformance: OnceCell<Conformance>,
32    #[cfg(feature = "resilience")]
33    resilience_policy: Option<crate::resilience::ResiliencePolicy>,
34    #[cfg(feature = "auth")]
35    auth_layers: Vec<Box<dyn crate::auth::AuthLayer>>,
36}
37
38impl Client {
39    /// Creates a new `Client` for a given STAC API base URL.
40    ///
41    /// # Arguments
42    ///
43    /// * `base_url` - The base URL of the STAC API (e.g.,
44    ///   `"https://planetarycomputer.microsoft.com/api/stac/v1"`).
45    ///
46    /// # Errors
47    ///
48    /// Returns an [`Error::Url`] if the provided `base_url` is not a valid URL.
49    pub fn new(base_url: &str) -> Result<Self> {
50        let base_url = Url::parse(base_url)?;
51        let client = reqwest::Client::new();
52        Ok(Self {
53            inner: Arc::new(ClientInner {
54                base_url,
55                client,
56                conformance: OnceCell::new(),
57                #[cfg(feature = "resilience")]
58                resilience_policy: None,
59                #[cfg(feature = "auth")]
60                auth_layers: Vec::new(),
61            }),
62        })
63    }
64
65    /// Creates a new `Client` from an existing `reqwest::Client`.
66    ///
67    /// This allows for customization of the underlying HTTP client, such as
68    /// setting default headers, proxies, or timeouts.
69    ///
70    /// # Errors
71    ///
72    /// Returns an [`Error::Url`] if the provided `base_url` is not a valid URL.
73    pub fn with_client(base_url: &str, client: reqwest::Client) -> Result<Self> {
74        let base_url = Url::parse(base_url)?;
75        Ok(Self {
76            inner: Arc::new(ClientInner {
77                base_url,
78                client,
79                conformance: OnceCell::new(),
80                #[cfg(feature = "resilience")]
81                resilience_policy: None,
82                #[cfg(feature = "auth")]
83                auth_layers: Vec::new(),
84            }),
85        })
86    }
87
88    /// Returns the base URL of the STAC API.
89    #[must_use]
90    pub fn base_url(&self) -> &Url {
91        &self.inner.base_url
92    }
93
94    /// Applies all configured authentication layers to a request builder.
95    #[cfg(feature = "auth")]
96    fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
97        self.inner
98            .auth_layers
99            .iter()
100            .fold(req, |req, layer| layer.apply(req))
101    }
102
103    /// No-op when auth feature is disabled.
104    #[cfg(not(feature = "auth"))]
105    fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
106        req
107    }
108
109    /// Fetches the root `Catalog` or `Collection` from the API.
110    ///
111    /// # Errors
112    ///
113    /// Returns an `Error` if the request fails or the response cannot be parsed.
114    pub async fn get_catalog(&self) -> Result<Catalog> {
115        let url = self.inner.base_url.clone();
116        self.fetch_json(&url).await
117    }
118
119    /// Fetches all `Collection` objects from the `/collections` endpoint.
120    ///
121    /// # Errors
122    ///
123    /// Returns an `Error` if the request fails or the response cannot be parsed.
124    pub async fn get_collections(&self) -> Result<Vec<Collection>> {
125        #[derive(serde::Deserialize)]
126        struct CollectionsResponse {
127            collections: Vec<Collection>,
128        }
129
130        let mut url = self.inner.base_url.clone();
131        url.path_segments_mut()
132            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
133            .push("collections");
134
135        let response: CollectionsResponse = self.fetch_json(&url).await?;
136        Ok(response.collections)
137    }
138
139    /// Fetches a single `Collection` by its ID from the `/collections/{collection_id}` endpoint.
140    ///
141    /// # Errors
142    ///
143    /// Returns an `Error` if the request fails or the response cannot be parsed.
144    pub async fn get_collection(&self, collection_id: &str) -> Result<Collection> {
145        let mut url = self.inner.base_url.clone();
146        url.path_segments_mut()
147            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
148            .push("collections")
149            .push(collection_id);
150
151        self.fetch_json(&url).await
152    }
153
154    /// Fetches an `ItemCollection` of `Item` objects from a specific collection.
155    ///
156    /// This method retrieves items from the `/collections/{collection_id}/items` endpoint.
157    /// Note that this retrieves only a single page of items; the `limit` parameter
158    /// can be used to control the page size.
159    ///
160    /// # Errors
161    ///
162    /// Returns an `Error` if the request fails or the response cannot be parsed.
163    pub async fn get_collection_items(
164        &self,
165        collection_id: &str,
166        limit: Option<u32>,
167    ) -> Result<ItemCollection> {
168        let mut url = self.inner.base_url.clone();
169        url.path_segments_mut()
170            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
171            .push("collections")
172            .push(collection_id)
173            .push("items");
174
175        if let Some(limit) = limit {
176            url.query_pairs_mut()
177                .append_pair("limit", &limit.to_string());
178        }
179
180        self.fetch_json(&url).await
181    }
182
183    /// Fetches a single `Item` by its collection ID and item ID.
184    ///
185    /// # Errors
186    ///
187    /// Returns an `Error` if the request fails or the response cannot be parsed.
188    pub async fn get_item(&self, collection_id: &str, item_id: &str) -> Result<Item> {
189        let mut url = self.inner.base_url.clone();
190        url.path_segments_mut()
191            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
192            .push("collections")
193            .push(collection_id)
194            .push("items")
195            .push(item_id);
196
197        self.fetch_json(&url).await
198    }
199
200    /// Searches for `Item` objects using the `POST /search` endpoint.
201    ///
202    /// This is the preferred method for searching, as it supports complex queries
203    /// that may be too long for a GET request's URL.
204    ///
205    /// # Errors
206    ///
207    /// Returns an `Error` if the request fails or the response cannot be parsed.
208    pub async fn search(&self, params: &SearchParams) -> Result<ItemCollection> {
209        let mut url = self.inner.base_url.clone();
210        url.path_segments_mut()
211            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
212            .push("search");
213
214        #[cfg(feature = "resilience")]
215        if let Some(ref policy) = self.inner.resilience_policy {
216            return self.post_with_retry(&url, params, policy).await;
217        }
218
219        let req = self.inner.client.post(url).json(params);
220        let req = self.apply_auth(req);
221        let response = req.send().await?;
222
223        self.handle_response(response).await
224    }
225
226    #[cfg(feature = "resilience")]
227    /// Posts JSON with retry logic according to the resilience policy.
228    async fn post_with_retry<T, B>(
229        &self,
230        url: &Url,
231        body: &B,
232        policy: &crate::resilience::ResiliencePolicy,
233    ) -> Result<T>
234    where
235        T: for<'de> serde::Deserialize<'de>,
236        B: serde::Serialize,
237    {
238        use std::time::Instant;
239
240        let start_time = Instant::now();
241        let mut attempt = 0;
242
243        loop {
244            // Check total timeout
245            if let Some(total_timeout) = policy.total_timeout {
246                if start_time.elapsed() >= total_timeout {
247                    return Err(Error::Api {
248                        status: 0,
249                        message: "Total operation timeout exceeded".to_string(),
250                    });
251                }
252            }
253
254            let req = self.inner.client.post(url.clone()).json(body);
255            let req = self.apply_auth(req);
256            let result = req.send().await;
257
258            match result {
259                Ok(response) => {
260                    let status = response.status().as_u16();
261
262                    // Check if we should retry based on status
263                    if policy.should_retry_status(status) && attempt < policy.max_attempts {
264                        let delay = if status == 429 {
265                            // Handle 429 with Retry-After header
266                            let retry_after = response
267                                .headers()
268                                .get(reqwest::header::RETRY_AFTER)
269                                .and_then(|v| v.to_str().ok())
270                                .and_then(|s| s.parse::<u64>().ok())
271                                .map(std::time::Duration::from_secs);
272
273                            retry_after
274                                .unwrap_or_else(|| policy.calculate_delay(attempt))
275                                .min(policy.max_delay)
276                        } else {
277                            policy.calculate_delay(attempt)
278                        };
279
280                        attempt += 1;
281                        tokio::time::sleep(delay).await;
282                        continue;
283                    }
284
285                    // Not retryable or max attempts reached, handle response
286                    return self.handle_response(response).await;
287                }
288                Err(e) => {
289                    // Check if network error is retryable
290                    if (e.is_timeout() || e.is_connect()) && attempt < policy.max_attempts {
291                        let delay = policy.calculate_delay(attempt);
292                        attempt += 1;
293                        tokio::time::sleep(delay).await;
294                        continue;
295                    }
296                    return Err(Error::Http(e));
297                }
298            }
299        }
300    }
301
302    /// Searches for `Item` objects using the `GET /search` endpoint.
303    ///
304    /// The `SearchParams` are converted into URL query parameters.
305    ///
306    /// # Errors
307    ///
308    /// Returns an `Error` if the request fails or the response cannot be parsed.
309    pub async fn search_get(&self, params: &SearchParams) -> Result<ItemCollection> {
310        let mut url = self.inner.base_url.clone();
311        url.path_segments_mut()
312            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
313            .push("search");
314
315        // Convert search params to query parameters
316        let query_params = Client::search_params_to_query(params)?;
317        for (key, value) in query_params {
318            url.query_pairs_mut().append_pair(&key, &value);
319        }
320
321        self.fetch_json(&url).await
322    }
323
324    /// Fetches the API's conformance classes from the `/conformance` endpoint.
325    ///
326    /// The result is cached, so subsequent calls will not make a new network request.
327    ///
328    /// # Errors
329    ///
330    /// Returns an `Error` if the request fails or the response cannot be parsed.
331    pub async fn conformance(&self) -> Result<&Conformance> {
332        self.inner
333            .conformance
334            .get_or_try_init(|| self.fetch_conformance())
335            .await
336    }
337
338    /// Fetches the API's conformance classes from the `/conformance` endpoint.
339    async fn fetch_conformance(&self) -> Result<Conformance> {
340        let mut url = self.inner.base_url.clone();
341        url.path_segments_mut()
342            .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
343            .push("conformance");
344
345        self.fetch_json(&url).await
346    }
347
348    /// Fetches JSON from a URL and deserializes it into a target type.
349    async fn fetch_json<T>(&self, url: &Url) -> Result<T>
350    where
351        T: for<'de> serde::Deserialize<'de>,
352    {
353        #[cfg(feature = "resilience")]
354        if let Some(ref policy) = self.inner.resilience_policy {
355            return self.fetch_json_with_retry(url, policy).await;
356        }
357
358        let req = self.inner.client.get(url.clone());
359        let req = self.apply_auth(req);
360        let response = req.send().await?;
361        self.handle_response(response).await
362    }
363
364    #[cfg(feature = "resilience")]
365    /// Fetches JSON with retry logic according to the resilience policy.
366    async fn fetch_json_with_retry<T>(
367        &self,
368        url: &Url,
369        policy: &crate::resilience::ResiliencePolicy,
370    ) -> Result<T>
371    where
372        T: for<'de> serde::Deserialize<'de>,
373    {
374        use std::time::Instant;
375
376        let start_time = Instant::now();
377        let mut attempt = 0;
378
379        loop {
380            // Check total timeout
381            if let Some(total_timeout) = policy.total_timeout {
382                if start_time.elapsed() >= total_timeout {
383                    return Err(Error::Api {
384                        status: 0,
385                        message: "Total operation timeout exceeded".to_string(),
386                    });
387                }
388            }
389
390            let req = self.inner.client.get(url.clone());
391            let req = self.apply_auth(req);
392            let result = req.send().await;
393
394            match result {
395                Ok(response) => {
396                    let status = response.status().as_u16();
397
398                    // Check if we should retry based on status
399                    if policy.should_retry_status(status) && attempt < policy.max_attempts {
400                        let delay = if status == 429 {
401                            // Handle 429 with Retry-After header
402                            let retry_after = response
403                                .headers()
404                                .get(reqwest::header::RETRY_AFTER)
405                                .and_then(|v| v.to_str().ok())
406                                .and_then(|s| s.parse::<u64>().ok())
407                                .map(std::time::Duration::from_secs);
408
409                            retry_after
410                                .unwrap_or_else(|| policy.calculate_delay(attempt))
411                                .min(policy.max_delay)
412                        } else {
413                            policy.calculate_delay(attempt)
414                        };
415
416                        attempt += 1;
417                        tokio::time::sleep(delay).await;
418                        continue;
419                    }
420
421                    // Not retryable or max attempts reached, handle response
422                    return self.handle_response(response).await;
423                }
424                Err(e) => {
425                    // Check if network error is retryable
426                    if (e.is_timeout() || e.is_connect()) && attempt < policy.max_attempts {
427                        let delay = policy.calculate_delay(attempt);
428                        attempt += 1;
429                        tokio::time::sleep(delay).await;
430                        continue;
431                    }
432                    return Err(Error::Http(e));
433                }
434            }
435        }
436    }
437
438    /// Handles a `reqwest::Response`, deserializing a successful response body
439    /// or converting an error status into an `Error`.
440    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
441    where
442        T: for<'de> serde::Deserialize<'de>,
443    {
444        let status = response.status();
445        if status.is_success() {
446            let text = response.text().await?;
447            let result = serde_json::from_str(&text)?;
448            return Ok(result);
449        }
450
451        if status.as_u16() == 429 {
452            // Retry-After may be delta-seconds or an HTTP-date; we only parse integer seconds.
453            let retry_after = response
454                .headers()
455                .get(reqwest::header::RETRY_AFTER)
456                .and_then(|v| v.to_str().ok())
457                .and_then(|s| s.parse::<u64>().ok());
458            return Err(Error::RateLimited { retry_after });
459        }
460
461        let error_text = response
462            .text()
463            .await
464            .unwrap_or_else(|_| "Unknown error".to_string());
465        Err(Error::Api {
466            status: status.as_u16(),
467            message: error_text,
468        })
469    }
470
471    /// Converts `SearchParams` into a vector of key-value pairs for a GET request.
472    ///
473    /// # Errors
474    ///
475    /// Returns an [`Error::Json`] if any part of the search parameters
476    /// cannot be serialized into a string.
477    fn search_params_to_query(params: &SearchParams) -> Result<Vec<(String, String)>> {
478        let mut query_params = Vec::new();
479
480        if let Some(limit) = params.limit {
481            query_params.push(("limit".to_string(), limit.to_string()));
482        }
483
484        if let Some(bbox) = &params.bbox {
485            let bbox_str = bbox
486                .iter()
487                .map(std::string::ToString::to_string)
488                .collect::<Vec<_>>()
489                .join(",");
490            query_params.push(("bbox".to_string(), bbox_str));
491        }
492
493        if let Some(datetime) = &params.datetime {
494            query_params.push(("datetime".to_string(), datetime.clone()));
495        }
496
497        if let Some(collections) = &params.collections {
498            let collections_str = collections.join(",");
499            query_params.push(("collections".to_string(), collections_str));
500        }
501
502        if let Some(ids) = &params.ids {
503            let ids_str = ids.join(",");
504            query_params.push(("ids".to_string(), ids_str));
505        }
506
507        if let Some(intersects) = &params.intersects {
508            let intersects_str = serde_json::to_string(intersects)?;
509            query_params.push(("intersects".to_string(), intersects_str));
510        }
511
512        // Handle query parameters (simplified - full implementation would need more complex handling)
513        if let Some(query) = &params.query {
514            for (key, value) in query {
515                let value_str = serde_json::to_string(value)?;
516                query_params.push((format!("query[{key}]"), value_str));
517            }
518        }
519
520        if let Some(sort_by) = &params.sortby {
521            let sort_str = sort_by
522                .iter()
523                .map(|s| {
524                    let prefix = match s.direction {
525                        SortDirection::Asc => "+",
526                        SortDirection::Desc => "-",
527                    };
528                    format!("{}{}", prefix, s.field)
529                })
530                .collect::<Vec<_>>()
531                .join(",");
532            query_params.push(("sortby".to_string(), sort_str));
533        }
534
535        if let Some(fields) = &params.fields {
536            let mut field_specs = Vec::new();
537            if let Some(include) = &fields.include {
538                field_specs.extend(include.iter().cloned());
539            }
540            if let Some(exclude) = &fields.exclude {
541                field_specs.extend(exclude.iter().map(|f| format!("-{f}")));
542            }
543
544            if !field_specs.is_empty() {
545                query_params.push(("fields".to_string(), field_specs.join(",")));
546            }
547        }
548
549        Ok(query_params)
550    }
551
552    /// Fetches the next page of results from an `ItemCollection`.
553    ///
554    /// This is a convenience helper available when the `pagination` feature is enabled.
555    /// It searches the `ItemCollection` links for one with `rel="next"` and, if
556    /// found, fetches the corresponding URL.
557    ///
558    /// Returns `Ok(None)` if no "next" link is present.
559    ///
560    /// # Errors
561    ///
562    /// Returns an `Error` if the request for the next page fails.
563    #[cfg(feature = "pagination")]
564    pub async fn search_next_page(
565        &self,
566        current: &ItemCollection,
567    ) -> Result<Option<ItemCollection>> {
568        let next_href = match &current.links {
569            Some(links) => links
570                .iter()
571                .find(|l| l.rel == "next")
572                .map(|l| l.href.clone()),
573            None => None,
574        };
575        let Some(href) = next_href else {
576            return Ok(None);
577        };
578        let url = Url::parse(&href).map_err(|e| Error::InvalidEndpoint(e.to_string()))?;
579        let page: ItemCollection = self.fetch_json(&url).await?;
580        Ok(Some(page))
581    }
582
583    /// Downloads a STAC `Asset` into memory.
584    ///
585    /// This method fetches the asset's data from its `href` URL and returns a
586    /// `DownloadedAsset` containing the raw bytes. It reuses the client's
587    /// authentication and resilience settings.
588    ///
589    /// # Arguments
590    ///
591    /// * `asset` - A reference to the `Asset` to be downloaded.
592    ///
593    /// # Errors
594    ///
595    /// Returns an `Error` if the `href` is not a valid URL, the request fails,
596    /// or the response body cannot be read.
597    pub async fn download_asset(&self, asset: &Asset) -> Result<DownloadedAsset> {
598        let url = Url::parse(&asset.href)?;
599        let req = self.inner.client.get(url);
600        let req = self.apply_auth(req);
601        let response = req.send().await?;
602
603        if !response.status().is_success() {
604            return Err(Error::Api {
605                status: response.status().as_u16(),
606                message: format!("Failed to download asset: {}", asset.href),
607            });
608        }
609
610        let bytes = response.bytes().await?;
611        Ok(DownloadedAsset {
612            content: bytes.to_vec(),
613        })
614    }
615}
616
#[cfg(any(feature = "resilience", feature = "auth"))]
/// A builder for constructing a `Client` with resilience and/or authentication features.
///
/// This builder allows for fluent configuration of the STAC client,
/// including resilience policies for retries/timeouts and pluggable
/// authentication layers.
///
/// # Example
///
/// ```rust,ignore
/// use stac_client::{ClientBuilder, ResiliencePolicy};
/// use std::time::Duration;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let policy = ResiliencePolicy::new()
///     .max_attempts(5)
///     .base_delay(Duration::from_millis(200));
///
/// let client = ClientBuilder::new("https://api.example.com/stac")
///     .resilience_policy(policy) // resilience feature
///     .auth_layer(stac_client::auth::ApiKey::new("X-API-Key", "secret")) // auth feature
///     .build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Default)]
pub struct ClientBuilder {
    /// Unparsed base URL; validated in [`ClientBuilder::build`].
    base_url: String,
    /// Policy handed to the built client (and used for its HTTP timeouts).
    #[cfg(feature = "resilience")]
    resilience_policy: Option<crate::resilience::ResiliencePolicy>,
    /// Auth layers handed to the built client, applied in insertion order.
    #[cfg(feature = "auth")]
    auth_layers: Vec<Box<dyn crate::auth::AuthLayer>>,
}

#[cfg(any(feature = "resilience", feature = "auth"))]
impl ClientBuilder {
    /// Creates a new `ClientBuilder` for the given base URL.
    ///
    /// The URL is only validated when [`ClientBuilder::build`] is called.
    ///
    /// # Arguments
    ///
    /// * `base_url` - The base URL of the STAC API.
    #[must_use]
    pub fn new(base_url: &str) -> Self {
        Self {
            base_url: base_url.to_string(),
            ..Default::default()
        }
    }

    /// Sets the resilience policy for the client.
    ///
    /// Requires the `resilience` feature.
    ///
    /// # Arguments
    ///
    /// * `policy` - The `ResiliencePolicy` to use for retries and timeouts.
    #[cfg(feature = "resilience")]
    #[must_use]
    pub fn resilience_policy(mut self, policy: crate::resilience::ResiliencePolicy) -> Self {
        self.resilience_policy = Some(policy);
        self
    }

    /// Adds an authentication layer to the client.
    ///
    /// Requires the `auth` feature. Layers are applied to each request in the
    /// order they were added.
    ///
    /// # Arguments
    ///
    /// * `layer` - An implementation of `AuthLayer` to apply to all requests.
    #[cfg(feature = "auth")]
    #[must_use]
    pub fn auth_layer(mut self, layer: impl crate::auth::AuthLayer + 'static) -> Self {
        self.auth_layers.push(Box::new(layer));
        self
    }

    /// Builds and returns a configured `Client`.
    ///
    /// When a resilience policy is set, its `request_timeout` and
    /// `connect_timeout` are also applied to the underlying `reqwest::Client`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error::Url`] if the provided `base_url` is not a valid URL,
    /// or an error if the underlying `reqwest::Client` cannot be built.
    pub fn build(self) -> Result<Client> {
        let base_url = Url::parse(&self.base_url)?;

        // Shadowing (instead of `let mut`) keeps `auth`-only builds free of an
        // `unused_mut` warning, since only the resilience path modifies it.
        let http_builder = reqwest::Client::builder();
        #[cfg(feature = "resilience")]
        let http_builder = match &self.resilience_policy {
            Some(policy) => Self::apply_policy_timeouts(http_builder, policy),
            None => http_builder,
        };

        let client = http_builder.build()?;
        let inner = ClientInner {
            base_url,
            client,
            conformance: OnceCell::new(),
            #[cfg(feature = "resilience")]
            resilience_policy: self.resilience_policy,
            #[cfg(feature = "auth")]
            auth_layers: self.auth_layers,
        };

        Ok(Client {
            inner: Arc::new(inner),
        })
    }

    /// Copies the policy's per-request and connect timeouts (when set) onto
    /// the HTTP client builder.
    #[cfg(feature = "resilience")]
    fn apply_policy_timeouts(
        mut http_builder: reqwest::ClientBuilder,
        policy: &crate::resilience::ResiliencePolicy,
    ) -> reqwest::ClientBuilder {
        if let Some(timeout) = policy.request_timeout {
            http_builder = http_builder.timeout(timeout);
        }
        if let Some(connect_timeout) = policy.connect_timeout {
            http_builder = http_builder.connect_timeout(connect_timeout);
        }
        http_builder
    }
}
729
730/// A fluent builder for constructing `SearchParams`.
731///
732/// This builder helps create a `SearchParams` struct, which can be passed to
733/// the `Client::search` or `Client::search_get` methods.
734pub struct SearchBuilder {
735    params: SearchParams,
736}
737
738impl SearchBuilder {
739    /// Creates a new, empty `SearchBuilder`.
740    #[must_use]
741    pub fn new() -> Self {
742        Self {
743            params: SearchParams::default(),
744        }
745    }
746
747    /// Sets the maximum number of items to return (the `limit` parameter).
748    #[must_use]
749    pub fn limit(mut self, limit: u32) -> Self {
750        self.params.limit = Some(limit);
751        self
752    }
753
754    /// Sets the spatial bounding box for the search.
755    ///
756    /// The coordinates must be in the order: `[west, south, east, north]`.
757    /// An optional fifth and sixth element can be used to specify a vertical
758    /// range (`[min_elevation, max_elevation]`).
759    #[must_use]
760    pub fn bbox(mut self, bbox: Vec<f64>) -> Self {
761        self.params.bbox = Some(bbox);
762        self
763    }
764
765    /// Sets the temporal window for the search using a `datetime` string.
766    ///
767    /// This can be a single datetime or a closed/open interval.
768    /// See the [STAC API spec](https://github.com/radiantearth/stac-api-spec/blob/master/fragments/datetime/README.md)
769    /// for valid formats.
770    #[must_use]
771    pub fn datetime(mut self, datetime: &str) -> Self {
772        self.params.datetime = Some(datetime.to_string());
773        self
774    }
775
776    /// Restricts the search to a set of collection IDs.
777    #[must_use]
778    pub fn collections(mut self, collections: Vec<String>) -> Self {
779        self.params.collections = Some(collections);
780        self
781    }
782
783    /// Restricts the search to a set of item IDs.
784    #[must_use]
785    pub fn ids(mut self, ids: Vec<String>) -> Self {
786        self.params.ids = Some(ids);
787        self
788    }
789
790    /// Filters items that intersect a `GeoJSON` geometry.
791    #[must_use]
792    pub fn intersects(mut self, geometry: serde_json::Value) -> Self {
793        self.params.intersects = Some(geometry);
794        self
795    }
796
797    /// Adds a filter expression using the STAC Query Extension.
798    ///
799    /// If a query already exists for the given key, it will be overwritten.
800    #[must_use]
801    pub fn query(mut self, key: &str, value: serde_json::Value) -> Self {
802        self.params
803            .query
804            .get_or_insert_with(HashMap::new)
805            .insert(key.to_string(), value);
806        self
807    }
808
809    /// Adds a sorting rule. Multiple calls will append additional sort rules.
810    #[must_use]
811    pub fn sort_by(mut self, field: &str, direction: SortDirection) -> Self {
812        self.params
813            .sortby
814            .get_or_insert_with(Vec::new)
815            .push(SortBy {
816                field: field.to_string(),
817                direction,
818            });
819        self
820    }
821
822    /// Includes only the specified fields in the response.
823    ///
824    /// This will overwrite any previously set `include` fields.
825    #[must_use]
826    pub fn include_fields(mut self, fields: Vec<String>) -> Self {
827        self.params
828            .fields
829            .get_or_insert_with(FieldsFilter::default)
830            .include = Some(fields);
831        self
832    }
833
834    /// Excludes the specified fields from the response.
835    ///
836    /// This will overwrite any previously set `exclude` fields.
837    #[must_use]
838    pub fn exclude_fields(mut self, fields: Vec<String>) -> Self {
839        self.params
840            .fields
841            .get_or_insert_with(FieldsFilter::default)
842            .exclude = Some(fields);
843        self
844    }
845
846    /// Finalizes the builder and returns the constructed `SearchParams`.
847    #[must_use]
848    pub fn build(self) -> SearchParams {
849        self.params
850    }
851}
852
853impl Default for SearchBuilder {
854    fn default() -> Self {
855        Self::new()
856    }
857}
858
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_client_creation() {
        let client = Client::new("https://example.com/stac").unwrap();
        assert_eq!(client.base_url().as_str(), "https://example.com/stac");
    }

    #[test]
    fn test_invalid_url() {
        let result = Client::new("not-a-valid-url");
        assert!(result.is_err());
    }

    #[test]
    fn test_search_builder() {
        let params = SearchBuilder::new()
            .limit(10)
            .bbox(vec![-180.0, -90.0, 180.0, 90.0])
            .datetime("2023-01-01T00:00:00Z/2023-12-31T23:59:59Z")
            .collections(vec!["collection1".to_string(), "collection2".to_string()])
            .ids(vec!["item1".to_string(), "item2".to_string()])
            .query("eo:cloud_cover", json!({"lt": 10}))
            .sort_by("datetime", SortDirection::Desc)
            .include_fields(vec!["id".to_string(), "geometry".to_string()])
            .build();

        assert_eq!(params.limit, Some(10));
        assert_eq!(params.bbox, Some(vec![-180.0, -90.0, 180.0, 90.0]));
        assert_eq!(
            params.datetime,
            Some("2023-01-01T00:00:00Z/2023-12-31T23:59:59Z".to_string())
        );
        assert_eq!(
            params.collections,
            Some(vec!["collection1".to_string(), "collection2".to_string()])
        );
        assert_eq!(
            params.ids,
            Some(vec!["item1".to_string(), "item2".to_string()])
        );

        // Check the contents, not just presence, so a builder that stores the
        // wrong key/value/rule would be caught.
        let query = params.query.expect("query should be set");
        assert_eq!(query.get("eo:cloud_cover"), Some(&json!({"lt": 10})));

        let sortby = params.sortby.expect("sortby should be set");
        assert_eq!(sortby.len(), 1);
        assert_eq!(sortby[0].field, "datetime");
        assert!(matches!(sortby[0].direction, SortDirection::Desc));

        let fields = params.fields.expect("fields should be set");
        assert_eq!(
            fields.include,
            Some(vec!["id".to_string(), "geometry".to_string()])
        );
        assert!(fields.exclude.is_none());
    }

    #[tokio::test]
    async fn test_get_catalog_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_catalog = json!({
            "type": "Catalog",
            "stac_version": "1.0.0",
            "id": "test-catalog",
            "description": "Test catalog",
            "links": []
        });

        let mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_catalog.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let catalog = client.get_catalog().await.unwrap();

        mock.assert_async().await;
        assert_eq!(catalog.id, "test-catalog");
        assert_eq!(catalog.stac_version, "1.0.0");
    }

    #[tokio::test]
    async fn test_get_collections_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_response = json!({
            "collections": [
                {
                    "type": "Collection",
                    "stac_version": "1.0.0",
                    "id": "test-collection",
                    "description": "Test collection",
                    "license": "MIT",
                    "extent": {
                        "spatial": {
                            "bbox": [[-180.0, -90.0, 180.0, 90.0]]
                        },
                        "temporal": {
                            "interval": [["2023-01-01T00:00:00Z", "2023-12-31T23:59:59Z"]]
                        }
                    },
                    "links": []
                }
            ]
        });

        let mock = server
            .mock("GET", "/collections")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let collections = client.get_collections().await.unwrap();

        mock.assert_async().await;
        assert_eq!(collections.len(), 1);
        assert_eq!(collections[0].id, "test-collection");
    }

    #[tokio::test]
    async fn test_search_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_response = json!({
            "type": "FeatureCollection",
            "features": [
                {
                    "type": "Feature",
                    "stac_version": "1.0.0",
                    "id": "test-item",
                    "geometry": null,
                    "properties": {
                        "datetime": "2023-01-01T12:00:00Z"
                    },
                    "links": [],
                    "assets": {},
                    "collection": "test-collection"
                }
            ]
        });

        let mock = server
            .mock("POST", "/search")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let search_params = SearchBuilder::new()
            .limit(10)
            .collections(vec!["test-collection".to_string()])
            .build();

        let results = client.search(&search_params).await.unwrap();

        mock.assert_async().await;
        assert_eq!(results.features.len(), 1);
        assert_eq!(results.features[0].id, "test-item");
        assert_eq!(
            results.features[0].collection.as_ref().unwrap(),
            "test-collection"
        );
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/")
            .with_status(404)
            .with_body("Not found")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let result = client.get_catalog().await;

        mock.assert_async().await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::Api { status, .. } => assert_eq!(status, 404),
            _ => panic!("Expected API error"),
        }
    }

    #[test]
    fn test_search_params_to_query() {
        let params = SearchParams {
            limit: Some(10),
            bbox: Some(vec![-180.0, -90.0, 180.0, 90.0]),
            datetime: Some("2023-01-01T00:00:00Z".to_string()),
            collections: Some(vec!["col1".to_string(), "col2".to_string()]),
            ids: Some(vec!["id1".to_string(), "id2".to_string()]),
            ..Default::default()
        };

        let query_params = Client::search_params_to_query(&params).unwrap();

        // Check that all expected parameters are present
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        assert_eq!(param_map.get("limit").unwrap(), "10");
        assert_eq!(param_map.get("bbox").unwrap(), "-180,-90,180,90");
        assert_eq!(param_map.get("datetime").unwrap(), "2023-01-01T00:00:00Z");
        assert_eq!(param_map.get("collections").unwrap(), "col1,col2");
        assert_eq!(param_map.get("ids").unwrap(), "id1,id2");
    }

    #[test]
    fn test_search_params_to_query_with_intersects_and_query() {
        let mut query_map = HashMap::new();
        query_map.insert("eo:cloud_cover".to_string(), json!({"lt": 5}));
        let geom = json!({
            "type": "Point",
            "coordinates": [0.0, 0.0]
        });
        // `geom` and `query_map` are not used again, so move them instead of cloning.
        let params = SearchParams {
            intersects: Some(geom),
            query: Some(query_map),
            ..Default::default()
        };

        let query_params = Client::search_params_to_query(&params).unwrap();
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        // Ensure intersects serialized and query expression present
        assert!(param_map.contains_key("intersects"));
        // URL encoding not applied yet (raw value) so we can check JSON substring
        assert!(param_map.get("intersects").unwrap().contains("\"Point\""));
        assert!(param_map.contains_key("query[eo:cloud_cover]"));
        assert_eq!(
            param_map.get("query[eo:cloud_cover]").unwrap(),
            &serde_json::to_string(&json!({"lt": 5})).unwrap()
        );
    }

    #[test]
    fn test_search_params_to_query_with_sortby_and_fields() {
        let params = SearchBuilder::new()
            .sort_by("datetime", SortDirection::Asc)
            .sort_by("eo:cloud_cover", SortDirection::Desc)
            .include_fields(vec!["id".to_string(), "properties".to_string()])
            .exclude_fields(vec!["geometry".to_string()])
            .build();

        let query_params = Client::search_params_to_query(&params).unwrap();
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        assert_eq!(
            param_map.get("sortby").unwrap(),
            "+datetime,-eo:cloud_cover"
        );
        assert_eq!(param_map.get("fields").unwrap(), "id,properties,-geometry");
    }

    #[tokio::test]
    async fn test_conformance_handling_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_conformance = json!({
            "conformsTo": [
                "https://api.stacspec.org/v1.0.0/core",
                "https://api.stacspec.org/v1.0.0/collections",
                "http://www.opengis.net/spec/ogcapi-features-1/1.0/conf/core"
            ]
        });

        // Expect exactly one request: the second `conformance()` call must be
        // served from the client's cache.
        let mock = server
            .mock("GET", "/conformance")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_conformance.to_string())
            .expect(1)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();

        // First call should fetch and cache
        let conformance = client.conformance().await.unwrap();
        assert!(conformance.conforms_to("https://api.stacspec.org/v1.0.0/core"));
        assert!(!conformance.conforms_to("https://api.stacspec.org/v1.0.0/item-search"));

        // Second call should use the cache
        let conformance_cached = client.conformance().await.unwrap();
        assert_eq!(conformance.conforms_to, conformance_cached.conforms_to);

        // The mock should have been called exactly once
        mock.assert_async().await;
    }

    #[test]
    fn test_search_builder_exclude_fields() {
        let params = SearchBuilder::new()
            .exclude_fields(vec!["geometry".to_string(), "assets".to_string()])
            .build();
        assert!(params.fields.is_some());
        let fields = params.fields.unwrap();
        assert!(fields.include.is_none());
        assert_eq!(
            fields.exclude.unwrap(),
            vec!["geometry".to_string(), "assets".to_string()]
        );
    }

    #[tokio::test]
    async fn test_download_asset_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_asset_content = "mock asset data";

        let mock = server
            .mock("GET", "/asset.txt")
            .with_status(200)
            .with_body(mock_asset_content)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let asset = Asset {
            href: server.url() + "/asset.txt",
            title: None,
            description: None,
            media_type: None,
            roles: None,
            extra: HashMap::default(),
        };

        let downloaded = client.download_asset(&asset).await.unwrap();
        mock.assert_async().await;
        assert_eq!(downloaded.content, mock_asset_content.as_bytes());
    }
}