1use crate::client::XaiClient;
4use crate::models::collection::{
5 AddDocumentsResponse, BatchGetDocumentsRequest, Collection, CollectionListResponse,
6 CreateCollectionRequest, Document, DocumentListResponse, SearchRequest, SearchResponse,
7 UpdateCollectionRequest,
8};
9use crate::{Error, Result};
10
11#[derive(Debug, Clone)]
13pub struct CollectionsApi {
14 client: XaiClient,
15}
16
17impl CollectionsApi {
18 pub(crate) fn new(client: XaiClient) -> Self {
19 Self { client }
20 }
21
22 pub async fn create(&self, request: CreateCollectionRequest) -> Result<Collection> {
41 let url = format!("{}/collections", self.client.base_url());
42
43 let response = self
44 .client
45 .send(self.client.http().post(&url).json(&request))
46 .await?;
47
48 if !response.status().is_success() {
49 return Err(Error::from_response(response).await);
50 }
51
52 Ok(response.json().await?)
53 }
54
55 pub async fn create_named(&self, name: impl Into<String>) -> Result<Collection> {
70 self.create(CreateCollectionRequest::new(name)).await
71 }
72
73 pub async fn get(&self, collection_id: impl AsRef<str>) -> Result<Collection> {
75 let id = XaiClient::encode_path(collection_id.as_ref());
76 let url = format!("{}/collections/{}", self.client.base_url(), id);
77
78 let response = self.client.send(self.client.http().get(&url)).await?;
79
80 if !response.status().is_success() {
81 return Err(Error::from_response(response).await);
82 }
83
84 Ok(response.json().await?)
85 }
86
87 pub async fn update(
89 &self,
90 collection_id: impl AsRef<str>,
91 request: UpdateCollectionRequest,
92 ) -> Result<Collection> {
93 let id = XaiClient::encode_path(collection_id.as_ref());
94 let url = format!("{}/collections/{}", self.client.base_url(), id);
95
96 let response = self
97 .client
98 .send(self.client.http().put(&url).json(&request))
99 .await?;
100
101 if !response.status().is_success() {
102 return Err(Error::from_response(response).await);
103 }
104
105 Ok(response.json().await?)
106 }
107
108 pub async fn upsert_document(
110 &self,
111 collection_id: impl AsRef<str>,
112 document: Document,
113 ) -> Result<Document> {
114 let document_id = document
115 .id
116 .as_ref()
117 .ok_or_else(|| Error::InvalidRequest("document.id is required".to_string()))?;
118 let cid = XaiClient::encode_path(collection_id.as_ref());
119 let did = XaiClient::encode_path(document_id);
120 let url = format!(
121 "{}/collections/{}/documents/{}",
122 self.client.base_url(),
123 cid,
124 did
125 );
126
127 let response = self
128 .client
129 .send(self.client.http().patch(&url).json(&document))
130 .await?;
131
132 if !response.status().is_success() {
133 return Err(Error::from_response(response).await);
134 }
135
136 Ok(response.json().await?)
137 }
138
139 pub async fn add_document_by_id(
141 &self,
142 collection_id: impl AsRef<str>,
143 document_id: impl AsRef<str>,
144 document: Document,
145 ) -> Result<Document> {
146 let cid = XaiClient::encode_path(collection_id.as_ref());
147 let did = XaiClient::encode_path(document_id.as_ref());
148 let url = format!(
149 "{}/collections/{}/documents/{}",
150 self.client.base_url(),
151 cid,
152 did
153 );
154
155 let response = self
156 .client
157 .send(self.client.http().post(&url).json(&document))
158 .await?;
159
160 if !response.status().is_success() {
161 return Err(Error::from_response(response).await);
162 }
163
164 Ok(response.json().await?)
165 }
166
167 pub async fn batch_get_documents(
169 &self,
170 collection_id: impl AsRef<str>,
171 request: BatchGetDocumentsRequest,
172 ) -> Result<DocumentListResponse> {
173 let id = XaiClient::encode_path(collection_id.as_ref());
174 let mut url = url::Url::parse(&format!(
175 "{}/collections/{}/documents:batchGet",
176 self.client.base_url(),
177 id
178 ))?;
179
180 for document_id in request.ids {
181 url.query_pairs_mut().append_pair("ids", &document_id);
182 }
183
184 let response = self
185 .client
186 .send(self.client.http().get(url.as_str()))
187 .await?;
188
189 if !response.status().is_success() {
190 return Err(Error::from_response(response).await);
191 }
192
193 Ok(response.json().await?)
194 }
195
196 pub async fn search_documents(&self, request: SearchRequest) -> Result<SearchResponse> {
198 let url = format!("{}/documents/search", self.client.base_url());
199
200 let response = self
201 .client
202 .send(self.client.http().post(&url).json(&request))
203 .await?;
204
205 if !response.status().is_success() {
206 return Err(Error::from_response(response).await);
207 }
208
209 Ok(response.json().await?)
210 }
211
212 pub async fn list(&self) -> Result<CollectionListResponse> {
230 self.list_with_options(None, None).await
231 }
232
233 pub async fn list_with_options(
235 &self,
236 limit: Option<u32>,
237 next_token: Option<&str>,
238 ) -> Result<CollectionListResponse> {
239 let mut url = url::Url::parse(&format!("{}/collections", self.client.base_url()))?;
240
241 if let Some(l) = limit {
242 url.query_pairs_mut().append_pair("limit", &l.to_string());
243 }
244 if let Some(token) = next_token {
245 url.query_pairs_mut().append_pair("next_token", token);
246 }
247
248 let response = self
249 .client
250 .send(self.client.http().get(url.as_str()))
251 .await?;
252
253 if !response.status().is_success() {
254 return Err(Error::from_response(response).await);
255 }
256
257 Ok(response.json().await?)
258 }
259
260 pub async fn delete(&self, collection_id: impl AsRef<str>) -> Result<()> {
262 let id = XaiClient::encode_path(collection_id.as_ref());
263 let url = format!("{}/collections/{}", self.client.base_url(), id);
264
265 let response = self.client.send(self.client.http().delete(&url)).await?;
266
267 if !response.status().is_success() {
268 return Err(Error::from_response(response).await);
269 }
270
271 Ok(())
272 }
273
274 pub async fn add_documents(
300 &self,
301 collection_id: impl AsRef<str>,
302 documents: Vec<Document>,
303 ) -> Result<AddDocumentsResponse> {
304 let id = XaiClient::encode_path(collection_id.as_ref());
305 let url = format!("{}/collections/{}/documents", self.client.base_url(), id);
306
307 let body = serde_json::json!({ "documents": documents });
308
309 let response = self
310 .client
311 .send(self.client.http().post(&url).json(&body))
312 .await?;
313
314 if !response.status().is_success() {
315 return Err(Error::from_response(response).await);
316 }
317
318 Ok(response.json().await?)
319 }
320
321 pub async fn add_document(
323 &self,
324 collection_id: impl AsRef<str>,
325 document: Document,
326 ) -> Result<String> {
327 let response = self.add_documents(collection_id, vec![document]).await?;
328 Ok(response.ids.into_iter().next().unwrap_or_default())
329 }
330
331 pub async fn list_documents(
333 &self,
334 collection_id: impl AsRef<str>,
335 ) -> Result<DocumentListResponse> {
336 self.list_documents_with_options(collection_id, None, None)
337 .await
338 }
339
340 pub async fn list_documents_with_options(
342 &self,
343 collection_id: impl AsRef<str>,
344 limit: Option<u32>,
345 next_token: Option<&str>,
346 ) -> Result<DocumentListResponse> {
347 let id = XaiClient::encode_path(collection_id.as_ref());
348 let mut url = url::Url::parse(&format!(
349 "{}/collections/{}/documents",
350 self.client.base_url(),
351 id
352 ))?;
353
354 if let Some(l) = limit {
355 url.query_pairs_mut().append_pair("limit", &l.to_string());
356 }
357 if let Some(token) = next_token {
358 url.query_pairs_mut().append_pair("next_token", token);
359 }
360
361 let response = self
362 .client
363 .send(self.client.http().get(url.as_str()))
364 .await?;
365
366 if !response.status().is_success() {
367 return Err(Error::from_response(response).await);
368 }
369
370 Ok(response.json().await?)
371 }
372
373 pub async fn get_document(
375 &self,
376 collection_id: impl AsRef<str>,
377 document_id: impl AsRef<str>,
378 ) -> Result<Document> {
379 let cid = XaiClient::encode_path(collection_id.as_ref());
380 let did = XaiClient::encode_path(document_id.as_ref());
381 let url = format!(
382 "{}/collections/{}/documents/{}",
383 self.client.base_url(),
384 cid,
385 did
386 );
387
388 let response = self.client.send(self.client.http().get(&url)).await?;
389
390 if !response.status().is_success() {
391 return Err(Error::from_response(response).await);
392 }
393
394 Ok(response.json().await?)
395 }
396
397 pub async fn delete_document(
399 &self,
400 collection_id: impl AsRef<str>,
401 document_id: impl AsRef<str>,
402 ) -> Result<()> {
403 let cid = XaiClient::encode_path(collection_id.as_ref());
404 let did = XaiClient::encode_path(document_id.as_ref());
405 let url = format!(
406 "{}/collections/{}/documents/{}",
407 self.client.base_url(),
408 cid,
409 did
410 );
411
412 let response = self.client.send(self.client.http().delete(&url)).await?;
413
414 if !response.status().is_success() {
415 return Err(Error::from_response(response).await);
416 }
417
418 Ok(())
419 }
420
421 pub async fn search(
446 &self,
447 collection_id: impl AsRef<str>,
448 request: SearchRequest,
449 ) -> Result<SearchResponse> {
450 let id = XaiClient::encode_path(collection_id.as_ref());
451 let url = format!("{}/collections/{}/search", self.client.base_url(), id);
452
453 let response = self
454 .client
455 .send(self.client.http().post(&url).json(&request))
456 .await?;
457
458 if !response.status().is_success() {
459 return Err(Error::from_response(response).await);
460 }
461
462 Ok(response.json().await?)
463 }
464
465 pub async fn search_query(
467 &self,
468 collection_id: impl AsRef<str>,
469 query: impl Into<String>,
470 ) -> Result<SearchResponse> {
471 self.search(collection_id, SearchRequest::new(query)).await
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client whose base URL points at `server`.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn list_with_options_forwards_query_params() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/collections"))
            .respond_with(move |req: &wiremock::Request| {
                assert_eq!(req.url.query(), Some("limit=4&next_token=tok_col"));
                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [{"id": "col_1", "name": "a", "document_count": 1}],
                    "next_token": "tok_col_2"
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let listed = client
            .collections()
            .list_with_options(Some(4), Some("tok_col"))
            .await
            .unwrap();

        assert_eq!(listed.data.len(), 1);
        assert_eq!(listed.next_token.as_deref(), Some("tok_col_2"));
    }

    #[tokio::test]
    async fn delete_document_encodes_collection_and_document_ids() {
        let server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/collections/col%2Fsync/documents/doc%201"))
            .respond_with(ResponseTemplate::new(204))
            .mount(&server)
            .await;

        let client = client_for(&server);

        client
            .collections()
            .delete_document("col/sync", "doc 1")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn search_query_forwards_query_in_request_body() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/collections/col%2Fsync/search"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["query"], "needle");
                ResponseTemplate::new(200).set_body_json(json!({
                    "results": [{
                        "document": {"id": "doc_1", "content": "needle"},
                        "score": 0.9
                    }]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .collections()
            .search_query("col/sync", "needle")
            .await
            .unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results[0].document.content, "needle");
    }

    #[tokio::test]
    async fn create_forwards_collection_request() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/collections"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["name"], "research");
                assert_eq!(body["description"], "private");
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "col_1",
                    "name": "research",
                    "description": "private",
                    "document_count": 0
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let created = client
            .collections()
            .create(CreateCollectionRequest::new("research").description("private"))
            .await
            .unwrap();

        assert_eq!(created.id, "col_1");
        assert_eq!(created.name, "research");
        assert_eq!(created.description.as_deref(), Some("private"));
    }

    #[tokio::test]
    async fn create_named_and_get_encodes_ids() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/collections"))
            .respond_with(move |req: &wiremock::Request| {
                // create_named must forward the raw (unencoded) name in the body;
                // only path segments are percent-encoded.
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["name"], "col/new");
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "col%2Fnew",
                    "name": "new",
                    "document_count": 0
                }))
            })
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/collections/col%2Fnew"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "col/new",
                "name": "new",
                "document_count": 2
            })))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let created = client.collections().create_named("col/new").await.unwrap();
        let got = client.collections().get("col/new").await.unwrap();
        assert_eq!(created.id, "col%2Fnew");
        assert_eq!(got.id, "col/new");
    }

    #[tokio::test]
    async fn update_collection_forwards_payload_and_uses_put() {
        let server = MockServer::start().await;

        Mock::given(method("PUT"))
            .and(path("/collections/col%2Fsync"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["name"], "updated");
                assert_eq!(body["description"], "sync docs");
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "col/sync",
                    "name": "updated",
                    "description": "sync docs",
                    "document_count": 4
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let updated = client
            .collections()
            .update(
                "col/sync",
                UpdateCollectionRequest::new()
                    .name("updated")
                    .description("sync docs"),
            )
            .await
            .unwrap();

        assert_eq!(updated.name, "updated");
        assert_eq!(updated.document_count, 4);
    }

    #[tokio::test]
    async fn upsert_document_encodes_ids_and_forwards_payload() {
        let server = MockServer::start().await;

        Mock::given(method("PATCH"))
            .and(path("/collections/col%2Fsync/documents/doc%201"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["id"], "doc 1");
                assert_eq!(body["content"], "new");
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "doc 1",
                    "content": "new",
                    "metadata": null
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let doc = client
            .collections()
            .upsert_document(
                "col/sync",
                Document {
                    id: Some("doc 1".to_string()),
                    content: "new".to_string(),
                    metadata: None,
                },
            )
            .await
            .unwrap();

        assert_eq!(doc.id.as_deref(), Some("doc 1"));
        assert_eq!(doc.content, "new");
    }

    #[tokio::test]
    async fn batch_get_documents_encodes_ids_as_repeated_query_params() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/collections/col%2Fsync/documents:batchGet"))
            .respond_with(move |req: &wiremock::Request| {
                assert_eq!(req.url.query(), Some("ids=d1&ids=d2"));
                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [
                        {"id": "d1", "content": "one"},
                        {"id": "d2", "content": "two"}
                    ],
                    "next_token": null
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let docs = client
            .collections()
            .batch_get_documents(
                "col/sync",
                BatchGetDocumentsRequest::new(vec!["d1".into(), "d2".into()]),
            )
            .await
            .unwrap();

        assert_eq!(docs.data.len(), 2);
        assert_eq!(docs.data[0].id.as_deref(), Some("d1"));
    }

    #[tokio::test]
    async fn list_documents_and_get_document_paths_encode_ids() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/collections/col%2Fsync/documents"))
            .respond_with(move |req: &wiremock::Request| {
                // The pagination options passed below must reach the server.
                assert_eq!(req.url.query(), Some("limit=4&next_token=tok"));
                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [{"id": "doc 1", "content": "one"}],
                    "next_token": "tok"
                }))
            })
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/collections/col%2Fsync/documents"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "ids": ["doc 2"]
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/collections/col%2Fsync/documents/doc%201"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "doc 1",
                "content": "one"
            })))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let docs = client
            .collections()
            .list_documents_with_options("col/sync", Some(4), Some("tok"))
            .await
            .unwrap();
        assert_eq!(docs.data.len(), 1);

        let add_response = client
            .collections()
            .add_document("col/sync", Document::with_id("doc 2", "second"))
            .await
            .unwrap();
        assert_eq!(add_response, "doc 2");

        let document = client
            .collections()
            .get_document("col/sync", "doc 1")
            .await
            .unwrap();
        assert_eq!(document.content, "one");
    }

    #[tokio::test]
    async fn list_and_delete_with_request_paths() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/collections"))
            .respond_with(move |req: &wiremock::Request| {
                // `list()` passes no pagination options, so no query string is expected.
                assert_eq!(req.url.query(), None);
                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [{"id": "col_1", "name": "alpha", "document_count": 1}]
                }))
            })
            .mount(&server)
            .await;

        Mock::given(method("DELETE"))
            .and(path("/collections/col%2Fsync"))
            .respond_with(ResponseTemplate::new(204))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let listed = client.collections().list().await.unwrap();
        assert_eq!(listed.data.len(), 1);

        client.collections().delete("col/sync").await.unwrap();
    }

    #[tokio::test]
    async fn search_documents_posts_to_global_documents_search() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/documents/search"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["query"], "search-term");
                ResponseTemplate::new(200).set_body_json(json!({
                    "results": [{
                        "document": {
                            "id": "doc-search-1",
                            "content": "result text"
                        },
                        "score": 0.92
                    }]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .collections()
            .search_documents(SearchRequest::new("search-term"))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(
            response.results[0].document.id.as_deref(),
            Some("doc-search-1")
        );
    }

    #[tokio::test]
    async fn add_document_by_id_posts_to_id_scoped_route() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/collections/col%2Fsync/documents/doc%201"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["id"], "doc 1");
                assert_eq!(body["content"], "content");
                ResponseTemplate::new(200).set_body_json(json!({
                    "id": "doc 1",
                    "content": "content",
                    "metadata": null
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let doc = client
            .collections()
            .add_document_by_id("col/sync", "doc 1", Document::with_id("doc 1", "content"))
            .await
            .unwrap();

        assert_eq!(doc.id.as_deref(), Some("doc 1"));
        assert_eq!(doc.content, "content");
    }
}