weaviate_community/
modules.rs

1use reqwest::Url;
2use std::error::Error;
3use std::sync::Arc;
4use crate::collections::error::ModuleError;
5use crate::collections::modules::{ContextionaryConcept, ContextionaryExtension};
6
7/// All contextionary module related endpoints and functionality described in
8/// [Weaviate contextionary API documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/text2vec-contextionary)
9#[derive(Debug)]
10pub struct Modules {
11    endpoint: Url,
12    client: Arc<reqwest::Client>,
13}
14
15impl Modules {
16    /// Create a new Modules object. The modules object is intended to like inside the 
17    /// WeaviateClient and be called through the WeaviateClient.
18    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
19        let endpoint = url.join("/v1/modules/")?;
20        Ok(Modules { endpoint, client })
21    }
22
23    /// Get a concept from text2vec-contextionary.
24    ///
25    /// # Parameter
26    /// - concept: the concept to search for
27    ///
28    /// # Example
29    /// ```no_run
30    /// use weaviate_community::WeaviateClient;
31    ///
32    /// #[tokio::main]
33    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
34    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
35    ///     let res = client.modules.contextionary_get_concept("concept").await;
36    ///
37    ///     Ok(())
38    /// }
39    /// ```
40    pub async fn contextionary_get_concept(
41        &self,
42        concept: &str
43    ) -> Result<ContextionaryConcept, Box<dyn Error>> {
44        let mut endpoint = String::from("text2vec-contextionary/concepts/");
45        endpoint.push_str(concept);
46        let endpoint = self.endpoint.join(&endpoint)?;
47        let res = self.client.get(endpoint).send().await?;
48
49        match res.status() {
50            reqwest::StatusCode::OK => {
51                let res: ContextionaryConcept = res.json().await?;
52                Ok(res)
53            },
54            _ => Err(self.get_err_msg("text2vec-contextionary concepts", res).await),
55        }
56    }
57
58    /// Extend text2vec-contextionary.
59    ///
60    /// # Parameter
61    /// - concept: the concept to extend contextionary with
62    ///
63    /// # Example
64    /// ```no_run
65    /// use weaviate_community::WeaviateClient;
66    /// use weaviate_community::collections::modules::ContextionaryExtension;
67    ///
68    /// #[tokio::main]
69    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
70    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
71    ///     let ext = ContextionaryExtension::new("concept", "description", 1.0);
72    ///     let res = client.modules.contextionary_extend(ext).await;
73    ///
74    ///     Ok(())
75    /// }
76    /// ```
77    pub async fn contextionary_extend(
78        &self,
79        concept: ContextionaryExtension
80    ) -> Result<ContextionaryExtension, Box<dyn Error>> {
81        let endpoint = self.endpoint.join("text2vec-contextionary/extensions")?;
82        let res = self
83            .client
84            .post(endpoint)
85            .json(&concept)
86            .send()
87            .await?;
88        match res.status() {
89            reqwest::StatusCode::OK => {
90                let res: ContextionaryExtension = res.json().await?;
91                Ok(res)
92            },
93            _ => Err(self.get_err_msg("text2vec-contextionary extend", res).await),
94        }
95    }
96
97    /// Get the error message for the endpoint
98    ///
99    /// Made to reduce the boilerplate error message building
100    async fn get_err_msg(
101        &self,
102        endpoint: &str,
103        res: reqwest::Response
104    ) -> Box<ModuleError> {
105        let status_code = res.status();
106        let msg: Result<serde_json::Value, reqwest::Error> = res.json().await;
107        let r_str: String;
108        if let Ok(json) = msg {
109            r_str = format!(
110                "Status code `{}` received when calling {} endpoint. Response: {}",
111                status_code,
112                endpoint,
113                json,
114            );
115        } else {
116            r_str = format!(
117                "Status code `{}` received when calling {} endpoint.",
118                status_code,
119                endpoint
120            );
121        }
122        Box::new(ModuleError(r_str))
123    }
124}
125
#[cfg(test)]
mod tests {
    use crate::{
        collections::modules::{ContextionaryConcept, ContextionaryExtension, IndividualWords},
        WeaviateClient,
    };

    /// Path the mock server expects for the `"test"` concept lookup.
    const CONCEPT_PATH: &str = "/v1/modules/text2vec-contextionary/concepts/test";
    /// Path the mock server expects for contextionary extensions.
    const EXTENSION_PATH: &str = "/v1/modules/text2vec-contextionary/extensions";

    /// Start a mock server and build a `WeaviateClient` that points at it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let server = mockito::Server::new();
        let host = format!("http://{}", server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (server, client)
    }

    /// Serialized single-word `ContextionaryConcept` used as a successful response body.
    fn get_mock_concept_response() -> String {
        let word = IndividualWords {
            info: None,
            word: "test".into(),
            present: None,
            concatenated_word: None,
        };
        let concept = ContextionaryConcept {
            individual_words: vec![word],
        };
        serde_json::to_string(&concept).unwrap()
    }

    /// Register a mock answering `method` requests on `path` with a JSON `body`.
    fn mock_json(
        server: &mut mockito::ServerGuard,
        method: &str,
        path: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, path)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    #[tokio::test]
    async fn test_get_concept_ok() {
        let (mut server, client) = get_test_harness();
        let body = get_mock_concept_response();
        let mock = mock_json(&mut server, "GET", CONCEPT_PATH, 200, &body);
        let outcome = client.modules.contextionary_get_concept("test").await;
        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_get_concept_err() {
        let (mut server, client) = get_test_harness();
        let mock = mock_json(&mut server, "GET", CONCEPT_PATH, 401, "");
        let outcome = client.modules.contextionary_get_concept("test").await;
        mock.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_extend_ok() {
        let (mut server, client) = get_test_harness();
        let extension = ContextionaryExtension::new("test", "test", 1.0);
        let payload = serde_json::to_string(&extension).unwrap();
        let mock = mock_json(&mut server, "POST", EXTENSION_PATH, 200, &payload);
        let outcome = client.modules.contextionary_extend(extension).await;
        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_extend_err() {
        let (mut server, client) = get_test_harness();
        let mock = mock_json(&mut server, "POST", EXTENSION_PATH, 401, "");
        let extension = ContextionaryExtension::new("test", "test", 1.0);
        let outcome = client.modules.contextionary_extend(extension).await;
        mock.assert();
        assert!(outcome.is_err());
    }
}