1use crate::client::Client;
2use crate::error::{Result, ShilpError};
3use crate::models::{
4 FileReaderOptions, IngestRequest, IngestResponse, IngestSourceType,
5 ListEmbeddingModelsResponse, ListIngestionSourcesResponse, ListNLIVerticalsResponse,
6 ListStorageResponse, ReadDocumentResponse, SearchRequest, SearchResponse,
7};
8use std::collections::HashMap;
9
10impl Client {
11 pub async fn ingest_data(&self, req: &IngestRequest) -> Result<IngestResponse> {
13 self.do_request(
14 reqwest::Method::POST,
15 "/api/data/v1/ingest",
16 Some(req),
17 None,
18 )
19 .await
20 }
21
22 pub async fn search_data(&self, req: &SearchRequest) -> Result<SearchResponse> {
25 if req.collection.is_empty() {
26 return Err(ShilpError::ValidationError(
27 "collection name cannot be empty".to_string(),
28 ));
29 }
30 if req.vector_query.is_none() && req.query.as_ref().map_or(true, |q| q.is_empty()) {
31 return Err(ShilpError::ValidationError(
32 "both vector_query and query cannot be empty".to_string(),
33 ));
34 }
35 self.do_request(
36 reqwest::Method::POST,
37 "/api/data/v1/search",
38 Some(req),
39 None,
40 )
41 .await
42 }
43
44 pub async fn list_storage(
47 &self,
48 path: Option<&str>,
49 _source: IngestSourceType,
50 ) -> Result<ListStorageResponse> {
51 let mut params = HashMap::new();
52 if let Some(p) = path {
53 params.insert("path".to_string(), p.to_string());
54 }
55
56 self.do_request::<ListStorageResponse, ()>(
57 reqwest::Method::GET,
58 "/api/data/v1/storage/list",
59 None,
60 Some(¶ms),
61 )
62 .await
63 }
64
65 pub async fn list_ingest_sources(&self) -> Result<ListIngestionSourcesResponse> {
67 self.do_request::<ListIngestionSourcesResponse, ()>(
68 reqwest::Method::GET,
69 "/api/data/v1/ingest/sources",
70 None,
71 None,
72 )
73 .await
74 }
75
76 pub async fn read_document(
80 &self,
81 path: &str,
82 options: &FileReaderOptions,
83 ) -> Result<ReadDocumentResponse> {
84 if path.is_empty() {
85 return Err(ShilpError::ValidationError(
86 "path cannot be empty".to_string(),
87 ));
88 }
89
90 if let Some(rows) = options.limit {
91 if rows < 0 {
92 return Err(ShilpError::ValidationError(
93 "rows cannot be negative".to_string(),
94 ));
95 }
96 }
97
98 if let Some(skip) = options.skip {
99 if skip < 0 {
100 return Err(ShilpError::ValidationError(
101 "skip cannot be negative".to_string(),
102 ));
103 }
104 }
105
106 if let Some(ref source) = options.source {
107 if *source == IngestSourceType::MongoDB && path.split('/').count() != 2 {
108 return Err(ShilpError::ValidationError(
109 "for mongodb source, path must be in the format 'database/collection'"
110 .to_string(),
111 ));
112 }
113
114 if !source.is_valid() {
115 return Err(ShilpError::ValidationError(format!(
116 "invalid source type - {:?}",
117 source
118 )));
119 }
120 }
121
122 let mut params = HashMap::new();
123 params.insert("path".to_string(), path.to_string());
124
125 if let Some(ref source) = options.source {
126 params.insert("source".to_string(), format!("{:?}", source).to_lowercase());
127 }
128
129 if let Some(rows) = options.limit {
130 params.insert("rows".to_string(), rows.to_string());
131 }
132
133 if let Some(skip) = options.skip {
134 params.insert("skip".to_string(), skip.to_string());
135 }
136
137 if let Some(ref mongo_filter) = options.mongo_filter {
138 let filter_str = serde_json::to_string(mongo_filter)?;
139 params.insert("mongo_filter".to_string(), filter_str);
140 }
141
142 self.do_request::<ReadDocumentResponse, ()>(
143 reqwest::Method::GET,
144 "/api/data/v1/storage/read",
145 None,
146 Some(¶ms),
147 )
148 .await
149 }
150
151 pub async fn upload_data_file(&self, file_path: &std::path::Path) -> Result<()> {
153 self.do_file_request(
154 reqwest::Method::POST,
155 "/api/data/v1/storage/upload",
156 file_path,
157 )
158 .await
159 }
160
161 pub async fn list_embedding_models(&self) -> Result<ListEmbeddingModelsResponse> {
163 self.do_request::<ListEmbeddingModelsResponse, ()>(
164 reqwest::Method::GET,
165 "/api/data/v1/embedding/models",
166 None,
167 None,
168 )
169 .await
170 }
171
172 pub async fn list_nli_verticals(&self) -> Result<ListNLIVerticalsResponse> {
174 self.do_request::<ListNLIVerticalsResponse, ()>(
175 reqwest::Method::GET,
176 "/api/data/v1/nli/verticals",
177 None,
178 None,
179 )
180 .await
181 }
182}