1use crate::error::{Error, Result};
4use crate::models::{
5 Catalog, Collection, FieldsFilter, Item, ItemCollection, SearchParams, SortBy, SortDirection,
6};
7use reqwest;
8use serde_json;
9use std::collections::HashMap;
10use url::Url;
11
/// Asynchronous HTTP client for a STAC API.
///
/// Endpoint URLs are derived from `base_url`; every request is sent through
/// the wrapped [`reqwest::Client`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Root URL of the API; endpoint path segments are appended to it.
    base_url: Url,
    /// Underlying HTTP client used for all requests.
    client: reqwest::Client,
}
24
25impl Client {
26 pub fn new(base_url: &str) -> Result<Self> {
37 let base_url = Url::parse(base_url)?;
38 let client = reqwest::Client::new();
39 Ok(Self { base_url, client })
40 }
41
42 pub fn with_client(base_url: &str, client: reqwest::Client) -> Result<Self> {
51 let base_url = Url::parse(base_url)?;
52 Ok(Self { base_url, client })
53 }
54
55 #[must_use]
57 pub fn base_url(&self) -> &Url {
58 &self.base_url
59 }
60
61 pub async fn get_catalog(&self) -> Result<Catalog> {
67 let url = self.base_url.clone();
68 self.fetch_json(&url).await
69 }
70
71 pub async fn get_collections(&self) -> Result<Vec<Collection>> {
77 #[derive(serde::Deserialize)]
78 struct CollectionsResponse {
79 collections: Vec<Collection>,
80 }
81
82 let mut url = self.base_url.clone();
83 url.path_segments_mut()
84 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
85 .push("collections");
86
87 let response: CollectionsResponse = self.fetch_json(&url).await?;
88 Ok(response.collections)
89 }
90
91 pub async fn get_collection(&self, collection_id: &str) -> Result<Collection> {
97 let mut url = self.base_url.clone();
98 url.path_segments_mut()
99 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
100 .push("collections")
101 .push(collection_id);
102
103 self.fetch_json(&url).await
104 }
105
106 pub async fn get_collection_items(
116 &self,
117 collection_id: &str,
118 limit: Option<u32>,
119 ) -> Result<ItemCollection> {
120 let mut url = self.base_url.clone();
121 url.path_segments_mut()
122 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
123 .push("collections")
124 .push(collection_id)
125 .push("items");
126
127 if let Some(limit) = limit {
128 url.query_pairs_mut()
129 .append_pair("limit", &limit.to_string());
130 }
131
132 self.fetch_json(&url).await
133 }
134
135 pub async fn get_item(&self, collection_id: &str, item_id: &str) -> Result<Item> {
141 let mut url = self.base_url.clone();
142 url.path_segments_mut()
143 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
144 .push("collections")
145 .push(collection_id)
146 .push("items")
147 .push(item_id);
148
149 self.fetch_json(&url).await
150 }
151
152 pub async fn search(&self, params: &SearchParams) -> Result<ItemCollection> {
161 let mut url = self.base_url.clone();
162 url.path_segments_mut()
163 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
164 .push("search");
165
166 let response = self.client.post(url).json(params).send().await?;
167
168 self.handle_response(response).await
169 }
170
171 pub async fn search_get(&self, params: &SearchParams) -> Result<ItemCollection> {
179 let mut url = self.base_url.clone();
180 url.path_segments_mut()
181 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
182 .push("search");
183
184 let query_params = self.search_params_to_query(params)?;
186 for (key, value) in query_params {
187 url.query_pairs_mut().append_pair(&key, &value);
188 }
189
190 self.fetch_json(&url).await
191 }
192
193 pub async fn get_conformance(&self) -> Result<serde_json::Value> {
199 let mut url = self.base_url.clone();
200 url.path_segments_mut()
201 .map_err(|()| Error::InvalidEndpoint("Cannot modify URL path".to_string()))?
202 .push("conformance");
203
204 self.fetch_json(&url).await
205 }
206
207 async fn fetch_json<T>(&self, url: &Url) -> Result<T>
209 where
210 T: for<'de> serde::Deserialize<'de>,
211 {
212 let response = self.client.get(url.clone()).send().await?;
213 self.handle_response(response).await
214 }
215
216 async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
219 where
220 T: for<'de> serde::Deserialize<'de>,
221 {
222 let status = response.status();
223 if status.is_success() {
224 let text = response.text().await?;
225 let result = serde_json::from_str(&text)?;
226 return Ok(result);
227 }
228
229 if status.as_u16() == 429 {
230 let retry_after = response
232 .headers()
233 .get(reqwest::header::RETRY_AFTER)
234 .and_then(|v| v.to_str().ok())
235 .and_then(|s| s.parse::<u64>().ok());
236 return Err(Error::RateLimited { retry_after });
237 }
238
239 let error_text = response
240 .text()
241 .await
242 .unwrap_or_else(|_| "Unknown error".to_string());
243 Err(Error::Api {
244 status: status.as_u16(),
245 message: error_text,
246 })
247 }
248
249 fn search_params_to_query(&self, params: &SearchParams) -> Result<Vec<(String, String)>> {
256 _ = self.base_url; let mut query_params = Vec::new();
258
259 if let Some(limit) = params.limit {
260 query_params.push(("limit".to_string(), limit.to_string()));
261 }
262
263 if let Some(bbox) = ¶ms.bbox {
264 let bbox_str = bbox
265 .iter()
266 .map(std::string::ToString::to_string)
267 .collect::<Vec<_>>()
268 .join(",");
269 query_params.push(("bbox".to_string(), bbox_str));
270 }
271
272 if let Some(datetime) = ¶ms.datetime {
273 query_params.push(("datetime".to_string(), datetime.clone()));
274 }
275
276 if let Some(collections) = ¶ms.collections {
277 let collections_str = collections.join(",");
278 query_params.push(("collections".to_string(), collections_str));
279 }
280
281 if let Some(ids) = ¶ms.ids {
282 let ids_str = ids.join(",");
283 query_params.push(("ids".to_string(), ids_str));
284 }
285
286 if let Some(intersects) = ¶ms.intersects {
287 let intersects_str = serde_json::to_string(intersects)?;
288 query_params.push(("intersects".to_string(), intersects_str));
289 }
290
291 if let Some(query) = ¶ms.query {
293 for (key, value) in query {
294 let value_str = serde_json::to_string(value)?;
295 query_params.push((format!("query[{key}]"), value_str));
296 }
297 }
298
299 if let Some(sort_by) = ¶ms.sortby {
300 let sort_str = sort_by
301 .iter()
302 .map(|s| {
303 let prefix = match s.direction {
304 SortDirection::Asc => "+",
305 SortDirection::Desc => "-",
306 };
307 format!("{}{}", prefix, s.field)
308 })
309 .collect::<Vec<_>>()
310 .join(",");
311 query_params.push(("sortby".to_string(), sort_str));
312 }
313
314 if let Some(fields) = ¶ms.fields {
315 let mut field_specs = Vec::new();
316 if let Some(include) = &fields.include {
317 field_specs.extend(include.iter().cloned());
318 }
319 if let Some(exclude) = &fields.exclude {
320 field_specs.extend(exclude.iter().map(|f| format!("-{f}")));
321 }
322
323 if !field_specs.is_empty() {
324 query_params.push(("fields".to_string(), field_specs.join(",")));
325 }
326 }
327
328 Ok(query_params)
329 }
330
331 #[cfg(feature = "pagination")]
343 pub async fn search_next_page(
344 &self,
345 current: &ItemCollection,
346 ) -> Result<Option<ItemCollection>> {
347 let next_href = match ¤t.links {
348 Some(links) => links
349 .iter()
350 .find(|l| l.rel == "next")
351 .map(|l| l.href.clone()),
352 None => None,
353 };
354 let Some(href) = next_href else {
355 return Ok(None);
356 };
357 let url = Url::parse(&href).map_err(|e| Error::InvalidEndpoint(e.to_string()))?;
358 let page: ItemCollection = self.fetch_json(&url).await?;
359 Ok(Some(page))
360 }
361}
362
/// Fluent builder for [`SearchParams`].
///
/// Each setter consumes the builder and returns it. Call
/// [`SearchBuilder::build`] to obtain the finished parameters.
pub struct SearchBuilder {
    /// Parameters accumulated so far; starts as `SearchParams::default()`.
    params: SearchParams,
}
370
371impl SearchBuilder {
372 #[must_use]
374 pub fn new() -> Self {
375 Self {
376 params: SearchParams::default(),
377 }
378 }
379
380 #[must_use]
382 pub fn limit(mut self, limit: u32) -> Self {
383 self.params.limit = Some(limit);
384 self
385 }
386
387 #[must_use]
393 pub fn bbox(mut self, bbox: Vec<f64>) -> Self {
394 self.params.bbox = Some(bbox);
395 self
396 }
397
398 #[must_use]
404 pub fn datetime(mut self, datetime: &str) -> Self {
405 self.params.datetime = Some(datetime.to_string());
406 self
407 }
408
409 #[must_use]
411 pub fn collections(mut self, collections: Vec<String>) -> Self {
412 self.params.collections = Some(collections);
413 self
414 }
415
416 #[must_use]
418 pub fn ids(mut self, ids: Vec<String>) -> Self {
419 self.params.ids = Some(ids);
420 self
421 }
422
423 #[must_use]
425 pub fn intersects(mut self, geometry: serde_json::Value) -> Self {
426 self.params.intersects = Some(geometry);
427 self
428 }
429
430 #[must_use]
434 pub fn query(mut self, key: &str, value: serde_json::Value) -> Self {
435 self.params
436 .query
437 .get_or_insert_with(HashMap::new)
438 .insert(key.to_string(), value);
439 self
440 }
441
442 #[must_use]
444 pub fn sort_by(mut self, field: &str, direction: SortDirection) -> Self {
445 self.params
446 .sortby
447 .get_or_insert_with(Vec::new)
448 .push(SortBy {
449 field: field.to_string(),
450 direction,
451 });
452 self
453 }
454
455 #[must_use]
459 pub fn include_fields(mut self, fields: Vec<String>) -> Self {
460 self.params
461 .fields
462 .get_or_insert_with(FieldsFilter::default)
463 .include = Some(fields);
464 self
465 }
466
467 #[must_use]
471 pub fn exclude_fields(mut self, fields: Vec<String>) -> Self {
472 self.params
473 .fields
474 .get_or_insert_with(FieldsFilter::default)
475 .exclude = Some(fields);
476 self
477 }
478
479 #[must_use]
481 pub fn build(self) -> SearchParams {
482 self.params
483 }
484}
485
486impl Default for SearchBuilder {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_client_creation() {
        let client = Client::new("https://example.com/stac").unwrap();
        assert_eq!(client.base_url.as_str(), "https://example.com/stac");
    }

    #[test]
    fn test_invalid_url() {
        let result = Client::new("not-a-valid-url");
        assert!(result.is_err());
    }

    #[test]
    fn test_search_builder() {
        let params = SearchBuilder::new()
            .limit(10)
            .bbox(vec![-180.0, -90.0, 180.0, 90.0])
            .datetime("2023-01-01T00:00:00Z/2023-12-31T23:59:59Z")
            .collections(vec!["collection1".to_string(), "collection2".to_string()])
            .ids(vec!["item1".to_string(), "item2".to_string()])
            .query("eo:cloud_cover", json!({"lt": 10}))
            .sort_by("datetime", SortDirection::Desc)
            .include_fields(vec!["id".to_string(), "geometry".to_string()])
            .build();

        assert_eq!(params.limit, Some(10));
        assert_eq!(params.bbox, Some(vec![-180.0, -90.0, 180.0, 90.0]));
        assert_eq!(
            params.datetime,
            Some("2023-01-01T00:00:00Z/2023-12-31T23:59:59Z".to_string())
        );
        assert_eq!(
            params.collections,
            Some(vec!["collection1".to_string(), "collection2".to_string()])
        );
        assert_eq!(
            params.ids,
            Some(vec!["item1".to_string(), "item2".to_string()])
        );
        assert!(params.query.is_some());
        assert!(params.sortby.is_some());
        assert!(params.fields.is_some());
    }

    #[tokio::test]
    async fn test_get_catalog_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_catalog = json!({
            "type": "Catalog",
            "stac_version": "1.0.0",
            "id": "test-catalog",
            "description": "Test catalog",
            "links": []
        });

        let mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_catalog.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let catalog = client.get_catalog().await.unwrap();

        mock.assert_async().await;
        assert_eq!(catalog.id, "test-catalog");
        assert_eq!(catalog.stac_version, "1.0.0");
    }

    #[tokio::test]
    async fn test_get_collections_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_response = json!({
            "collections": [
                {
                    "type": "Collection",
                    "stac_version": "1.0.0",
                    "id": "test-collection",
                    "description": "Test collection",
                    "license": "MIT",
                    "extent": {
                        "spatial": {
                            "bbox": [[-180.0, -90.0, 180.0, 90.0]]
                        },
                        "temporal": {
                            "interval": [["2023-01-01T00:00:00Z", "2023-12-31T23:59:59Z"]]
                        }
                    },
                    "links": []
                }
            ]
        });

        let mock = server
            .mock("GET", "/collections")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let collections = client.get_collections().await.unwrap();

        mock.assert_async().await;
        assert_eq!(collections.len(), 1);
        assert_eq!(collections[0].id, "test-collection");
    }

    #[tokio::test]
    async fn test_search_mock() {
        let mut server = mockito::Server::new_async().await;
        let mock_response = json!({
            "type": "FeatureCollection",
            "features": [
                {
                    "type": "Feature",
                    "stac_version": "1.0.0",
                    "id": "test-item",
                    "geometry": null,
                    "properties": {
                        "datetime": "2023-01-01T12:00:00Z"
                    },
                    "links": [],
                    "assets": {},
                    "collection": "test-collection"
                }
            ]
        });

        let mock = server
            .mock("POST", "/search")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(mock_response.to_string())
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let search_params = SearchBuilder::new()
            .limit(10)
            .collections(vec!["test-collection".to_string()])
            .build();

        let results = client.search(&search_params).await.unwrap();

        mock.assert_async().await;
        assert_eq!(results.features.len(), 1);
        assert_eq!(results.features[0].id, "test-item");
        assert_eq!(
            results.features[0].collection.as_ref().unwrap(),
            "test-collection"
        );
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/")
            .with_status(404)
            .with_body("Not found")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let result = client.get_catalog().await;

        mock.assert_async().await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::Api { status, .. } => assert_eq!(status, 404),
            _ => panic!("Expected API error"),
        }
    }

    #[test]
    fn test_search_params_to_query() {
        let client = Client::new("https://example.com").unwrap();
        let params = SearchParams {
            limit: Some(10),
            bbox: Some(vec![-180.0, -90.0, 180.0, 90.0]),
            datetime: Some("2023-01-01T00:00:00Z".to_string()),
            collections: Some(vec!["col1".to_string(), "col2".to_string()]),
            ids: Some(vec!["id1".to_string(), "id2".to_string()]),
            ..Default::default()
        };

        let query_params = client.search_params_to_query(&params).unwrap();

        // Keys are unique here, so a map is a convenient order-independent view.
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        assert_eq!(param_map.get("limit").unwrap(), "10");
        assert_eq!(param_map.get("bbox").unwrap(), "-180,-90,180,90");
        assert_eq!(param_map.get("datetime").unwrap(), "2023-01-01T00:00:00Z");
        assert_eq!(param_map.get("collections").unwrap(), "col1,col2");
        assert_eq!(param_map.get("ids").unwrap(), "id1,id2");
    }

    #[test]
    fn test_search_params_to_query_empty() {
        let client = Client::new("https://example.com").unwrap();
        let query_params = client
            .search_params_to_query(&SearchParams::default())
            .unwrap();
        assert!(query_params.is_empty());
    }

    #[test]
    fn test_search_params_to_query_with_intersects_and_query() {
        let client = Client::new("https://example.com").unwrap();
        let mut query_map = HashMap::new();
        query_map.insert("eo:cloud_cover".to_string(), json!({"lt": 5}));
        let geom = json!({
            "type": "Point",
            "coordinates": [0.0, 0.0]
        });
        let params = SearchParams {
            intersects: Some(geom),
            query: Some(query_map),
            ..Default::default()
        };

        let query_params = client.search_params_to_query(&params).unwrap();
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        assert!(param_map.contains_key("intersects"));
        assert!(param_map.get("intersects").unwrap().contains("\"Point\""));
        assert!(param_map.contains_key("query[eo:cloud_cover]"));
        assert_eq!(
            param_map.get("query[eo:cloud_cover]").unwrap(),
            &serde_json::to_string(&json!({"lt": 5})).unwrap()
        );
    }

    #[test]
    fn test_search_params_to_query_with_sortby_and_fields() {
        let client = Client::new("https://example.com").unwrap();
        let params = SearchBuilder::new()
            .sort_by("datetime", SortDirection::Asc)
            .sort_by("eo:cloud_cover", SortDirection::Desc)
            .include_fields(vec!["id".to_string(), "properties".to_string()])
            .exclude_fields(vec!["geometry".to_string()])
            .build();

        let query_params = client.search_params_to_query(&params).unwrap();
        let param_map: HashMap<String, String> = query_params.into_iter().collect();

        assert_eq!(
            param_map.get("sortby").unwrap(),
            "+datetime,-eo:cloud_cover"
        );
        assert_eq!(param_map.get("fields").unwrap(), "id,properties,-geometry");
    }

    #[test]
    fn test_search_builder_exclude_fields() {
        let params = SearchBuilder::new()
            .exclude_fields(vec!["geometry".to_string(), "assets".to_string()])
            .build();
        assert!(params.fields.is_some());
        let fields = params.fields.unwrap();
        assert!(fields.include.is_none());
        assert_eq!(
            fields.exclude.unwrap(),
            vec!["geometry".to_string(), "assets".to_string()]
        );
    }
}