// weaviate_community/classification.rs
use std::error::Error;
use std::sync::Arc;

use reqwest::Url;
use uuid::Uuid;

use crate::collections::{
    classification::{ClassificationRequest, ClassificationResponse},
    error::ClassificationError,
};
10
11#[derive(Debug)]
14pub struct Classification {
15 endpoint: Url,
16 client: Arc<reqwest::Client>,
17}
18
19impl Classification {
20 pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
23 let endpoint = url.join("/v1/classifications/")?;
24 Ok(Classification { endpoint, client })
25 }
26
27 pub async fn schedule(
63 &self,
64 request: ClassificationRequest,
65 ) -> Result<ClassificationResponse, Box<dyn Error>> {
66 let res = self
67 .client
68 .post(self.endpoint.clone())
69 .json(&request)
70 .send()
71 .await?;
72 match res.status() {
73 reqwest::StatusCode::CREATED => {
74 let res: ClassificationResponse = res.json().await?;
75 Ok(res)
76 }
77 _ => Err(self.get_err_msg("schedule classification", res).await)
78 }
79 }
80
81 pub async fn get(&self, id: Uuid) -> Result<ClassificationResponse, Box<dyn Error>> {
98 let endpoint = self.endpoint.join(&id.to_string())?;
99 let res = self.client.get(endpoint).send().await?;
100 match res.status() {
101 reqwest::StatusCode::OK => {
102 let res: ClassificationResponse = res.json().await?;
103 Ok(res)
104 }
105 _ => Err(self.get_err_msg("get classification", res).await)
106 }
107 }
108
109 async fn get_err_msg(
113 &self,
114 endpoint: &str,
115 res: reqwest::Response
116 ) -> Box<ClassificationError> {
117 let status_code = res.status();
118 let msg: Result<serde_json::Value, reqwest::Error> = res.json().await;
119 let r_str: String;
120 if let Ok(json) = msg {
121 r_str = format!(
122 "Status code `{}` received when calling {} endpoint. Response: {}",
123 status_code,
124 endpoint,
125 json,
126 );
127 } else {
128 r_str = format!(
129 "Status code `{}` received when calling {} endpoint.",
130 status_code,
131 endpoint
132 );
133 }
134 Box::new(ClassificationError(r_str))
135 }
136}
137
#[cfg(test)]
mod tests {
    use uuid::Uuid;
    use crate::{
        WeaviateClient,
        collections::classification::{ClassificationRequest, ClassificationType},
    };

    /// Starts a mock HTTP server and builds a `WeaviateClient` pointed at it.
    ///
    /// The `ServerGuard` is returned so the server stays alive for the duration of the test.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// A representative KNN classification request used by the schedule tests.
    fn test_classification_req() -> ClassificationRequest {
        ClassificationRequest::builder()
            .with_class("Test")
            .with_type(ClassificationType::KNN)
            .with_based_on_properties(vec!["testProp"])
            .with_classify_properties(vec!["hasPopularity"])
            .with_filters(serde_json::json!({
                "path": ["testPropTwo"],
                "operator": "GreaterThan",
                "valueInt": 100
            }))
            .with_settings(serde_json::json!({"k": 3}))
            .build()
    }

    /// Registers a mock for `method endpoint` that replies with `status_code` and a JSON `body`.
    fn mock_json(
        server: &mut mockito::ServerGuard,
        method: &str,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Registers a JSON-returning mock for `POST endpoint`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_json(server, "POST", endpoint, status_code, body)
    }

    /// Registers a JSON-returning mock for `GET endpoint`.
    fn mock_get(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_json(server, "GET", endpoint, status_code, body)
    }

    // An empty test body would report a pass without exercising anything; keep the test visible
    // as ignored until a real response fixture exists.
    #[tokio::test]
    #[ignore = "TODO: needs a representative ClassificationResponse JSON fixture"]
    async fn test_classification_schedule_ok() {}

    #[tokio::test]
    async fn test_classification_schedule_err() {
        let req = test_classification_req();
        let (mut mock_server, client) = get_test_harness();
        // Anything other than 201 Created must be surfaced as an error.
        let mock = mock_post(&mut mock_server, "/v1/classifications/", 401, "");
        let res = client.classification.schedule(req).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    #[ignore = "TODO: needs a representative ClassificationResponse JSON fixture"]
    async fn test_classification_get_ok() {}

    #[tokio::test]
    async fn test_classification_get_err() {
        let uuid = Uuid::new_v4();
        let url = format!("/v1/classifications/{}", uuid);
        let (mut mock_server, client) = get_test_harness();
        // Anything other than 200 OK must be surfaced as an error.
        let mock = mock_get(&mut mock_server, &url, 401, "");
        let res = client.classification.get(uuid).await;
        mock.assert();
        assert!(res.is_err());
    }
}