weaviate_community/
classification.rs

1use reqwest::Url;
2use std::error::Error;
3use std::sync::Arc;
4use uuid::Uuid;
5
6use crate::collections::{
7    classification::{ClassificationRequest, ClassificationResponse},
8    error::ClassificationError,
9};
10
11/// All classification related endpoints and functionality described in
12/// [Weaviate meta API documentation](https://weaviate.io/developers/weaviate/api/rest/classification)
13#[derive(Debug)]
14pub struct Classification {
15    endpoint: Url,
16    client: Arc<reqwest::Client>,
17}
18
19impl Classification {
20    /// Create a new instance of the Classification endpoint struct. Should only be done by the 
21    /// parent client.
22    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
23        let endpoint = url.join("/v1/classifications/")?;
24        Ok(Classification { endpoint, client })
25    }
26
27    /// Schedule a new classification
28    ///
29    /// # Example
30    /// ```no_run
31    /// use weaviate_community::WeaviateClient;
32    /// use weaviate_community::collections::classification::{
33    ///     ClassificationRequest,
34    ///     ClassificationType
35    /// };
36    ///
37    /// #[tokio::main]
38    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
39    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
40    ///
41    ///     let req = ClassificationRequest::builder()
42    ///         .with_type(ClassificationType::KNN)
43    ///         .with_class("Article")
44    ///         .with_based_on_properties(vec!["summary"])
45    ///         .with_classify_properties(vec!["hasPopularity"])
46    ///         .with_filters(serde_json::json!({
47    ///             "trainingSetWhere": {
48    ///                 "path": ["wordCount"],
49    ///                 "operator": "GreaterThan",
50    ///                 "valueInt": 100
51    ///             }
52    ///         }))
53    ///         .with_settings(serde_json::json!({
54    ///             "k": 3
55    ///         }))
56    ///         .build();
57    ///
58    ///     let res = client.classification.schedule(req).await?;
59    ///     Ok(())
60    /// }
61    /// ```
62    pub async fn schedule(
63        &self,
64        request: ClassificationRequest,
65    ) -> Result<ClassificationResponse, Box<dyn Error>> {
66        let res = self
67            .client
68            .post(self.endpoint.clone())
69            .json(&request)
70            .send()
71            .await?;
72        match res.status() {
73            reqwest::StatusCode::CREATED => {
74                let res: ClassificationResponse = res.json().await?;
75                Ok(res)
76            }
77            _ => Err(self.get_err_msg("schedule classification", res).await)
78        }
79    }
80
81    /// Get the status of a classification
82    ///
83    /// # Example
84    /// ```no_run
85    /// use uuid::Uuid;
86    /// use weaviate_community::WeaviateClient;
87    ///
88    /// #[tokio::main]
89    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
90    ///     let uuid = Uuid::parse_str("00037775-1432-35e5-bc59-443baaef7d80")?;
91    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
92    ///
93    ///     let res = client.classification.get(uuid).await?;
94    ///     Ok(())
95    /// }
96    /// ```
97    pub async fn get(&self, id: Uuid) -> Result<ClassificationResponse, Box<dyn Error>> {
98        let endpoint = self.endpoint.join(&id.to_string())?;
99        let res = self.client.get(endpoint).send().await?;
100        match res.status() {
101            reqwest::StatusCode::OK => {
102                let res: ClassificationResponse = res.json().await?;
103                Ok(res)
104            }
105            _ => Err(self.get_err_msg("get classification", res).await)
106        }
107    }
108
109    /// Get the error message for the endpoint
110    ///
111    /// Made to reduce the boilerplate error message building
112    async fn get_err_msg(
113        &self,
114        endpoint: &str,
115        res: reqwest::Response
116    ) -> Box<ClassificationError> {
117        let status_code = res.status();
118        let msg: Result<serde_json::Value, reqwest::Error> = res.json().await;
119        let r_str: String;
120        if let Ok(json) = msg {
121            r_str = format!(
122                "Status code `{}` received when calling {} endpoint. Response: {}",
123                status_code,
124                endpoint,
125                json,
126            );
127        } else {
128            r_str = format!(
129                "Status code `{}` received when calling {} endpoint.",
130                status_code,
131                endpoint
132            );
133        }
134        Box::new(ClassificationError(r_str))
135    }
136}
137
#[cfg(test)]
mod tests {
    use uuid::Uuid;
    use crate::{
        WeaviateClient,
        collections::classification::{ClassificationRequest, ClassificationType}
    };

    /// Start a mock server and build a `WeaviateClient` whose base URL points at it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// A KNN classification request fixture.
    ///
    /// The filters use the same `trainingSetWhere` wrapper as the `schedule` doc example.
    fn test_classification_req() -> ClassificationRequest {
        ClassificationRequest::builder()
            .with_class("Test")
            .with_type(ClassificationType::KNN)
            .with_based_on_properties(vec!["testProp"])
            .with_classify_properties(vec!["hasPopularity"])
            .with_filters(serde_json::json!({
                "trainingSetWhere": {
                    "path": ["testPropTwo"],
                    "operator": "GreaterThan",
                    "valueInt": 100
                }
            }))
            .with_settings(serde_json::json!({"k": 3}))
            .build()
    }

    /// Register a `POST` mock on `endpoint` that replies with `status_code` and a JSON `body`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock("POST", endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Register a `GET` mock on `endpoint` that replies with `status_code` and a JSON `body`.
    fn mock_get(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock("GET", endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    // An empty body would pass vacuously and report coverage that does not exist. Ignoring the
    // test keeps the gap visible until a ClassificationResponse fixture is added.
    #[tokio::test]
    #[ignore = "not implemented: needs a ClassificationResponse fixture for a 201 reply"]
    async fn test_classification_schedule_ok() {}

    #[tokio::test]
    async fn test_classification_schedule_err() {
        let req = test_classification_req();
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/classifications/", 401, "");
        let res = client.classification.schedule(req).await;
        mock.assert();
        assert!(res.is_err());
    }

    // See the note on `test_classification_schedule_ok`.
    #[tokio::test]
    #[ignore = "not implemented: needs a ClassificationResponse fixture for a 200 reply"]
    async fn test_classification_get_ok() {}

    #[tokio::test]
    async fn test_classification_get_err() {
        let uuid = Uuid::new_v4();
        let url = format!("/v1/classifications/{}", uuid);
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_get(&mut mock_server, &url, 401, "");
        let res = client.classification.get(uuid).await;
        mock.assert();
        assert!(res.is_err());
    }
}