// weaviate_community/modules.rs
use std::error::Error;
use std::sync::Arc;

use reqwest::Url;

use crate::collections::error::ModuleError;
use crate::collections::modules::{ContextionaryConcept, ContextionaryExtension};

7#[derive(Debug)]
10pub struct Modules {
11 endpoint: Url,
12 client: Arc<reqwest::Client>,
13}
14
15impl Modules {
16 pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
19 let endpoint = url.join("/v1/modules/")?;
20 Ok(Modules { endpoint, client })
21 }
22
23 pub async fn contextionary_get_concept(
41 &self,
42 concept: &str
43 ) -> Result<ContextionaryConcept, Box<dyn Error>> {
44 let mut endpoint = String::from("text2vec-contextionary/concepts/");
45 endpoint.push_str(concept);
46 let endpoint = self.endpoint.join(&endpoint)?;
47 let res = self.client.get(endpoint).send().await?;
48
49 match res.status() {
50 reqwest::StatusCode::OK => {
51 let res: ContextionaryConcept = res.json().await?;
52 Ok(res)
53 },
54 _ => Err(self.get_err_msg("text2vec-contextionary concepts", res).await),
55 }
56 }
57
58 pub async fn contextionary_extend(
78 &self,
79 concept: ContextionaryExtension
80 ) -> Result<ContextionaryExtension, Box<dyn Error>> {
81 let endpoint = self.endpoint.join("text2vec-contextionary/extensions")?;
82 let res = self
83 .client
84 .post(endpoint)
85 .json(&concept)
86 .send()
87 .await?;
88 match res.status() {
89 reqwest::StatusCode::OK => {
90 let res: ContextionaryExtension = res.json().await?;
91 Ok(res)
92 },
93 _ => Err(self.get_err_msg("text2vec-contextionary extend", res).await),
94 }
95 }
96
97 async fn get_err_msg(
101 &self,
102 endpoint: &str,
103 res: reqwest::Response
104 ) -> Box<ModuleError> {
105 let status_code = res.status();
106 let msg: Result<serde_json::Value, reqwest::Error> = res.json().await;
107 let r_str: String;
108 if let Ok(json) = msg {
109 r_str = format!(
110 "Status code `{}` received when calling {} endpoint. Response: {}",
111 status_code,
112 endpoint,
113 json,
114 );
115 } else {
116 r_str = format!(
117 "Status code `{}` received when calling {} endpoint.",
118 status_code,
119 endpoint
120 );
121 }
122 Box::new(ModuleError(r_str))
123 }
124}
125
#[cfg(test)]
mod tests {
    use crate::{
        collections::{
            error::ModuleError,
            modules::{ContextionaryConcept, ContextionaryExtension, IndividualWords},
        },
        WeaviateClient,
    };

    const CONCEPT_PATH: &str = "/v1/modules/text2vec-contextionary/concepts/test";
    const EXTENSIONS_PATH: &str = "/v1/modules/text2vec-contextionary/extensions";

    /// Starts a mock server and builds a `WeaviateClient` pointed at it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// JSON for a concept whose only individual word is `"test"`.
    fn get_mock_concept_response() -> String {
        serde_json::to_string(&ContextionaryConcept {
            individual_words: vec![IndividualWords {
                info: None,
                word: "test".into(),
                present: None,
                concatenated_word: None,
            }],
        })
        .unwrap()
    }

    /// Registers a `method` mock on `endpoint` that replies with `status_code` and a JSON `body`.
    fn mock_request(
        server: &mut mockito::ServerGuard,
        method: &str,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Registers a `POST` mock; see `mock_request`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_request(server, "POST", endpoint, status_code, body)
    }

    /// Registers a `GET` mock; see `mock_request`.
    fn mock_get(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_request(server, "GET", endpoint, status_code, body)
    }

    /// Returns the message of the `ModuleError` carried by `err`.
    ///
    /// Panics if `err` is any other error type.
    fn module_error_message(err: Box<dyn std::error::Error>) -> String {
        err.downcast_ref::<ModuleError>()
            .expect("non-200 responses should be reported as a ModuleError")
            .0
            .clone()
    }

    #[tokio::test]
    async fn test_get_concept_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_get(&mut mock_server, CONCEPT_PATH, 200, &get_mock_concept_response());
        let res = client.modules.contextionary_get_concept("test").await;
        mock.assert();
        let concept = res.expect("a 200 response should decode into a ContextionaryConcept");
        assert_eq!(concept.individual_words.len(), 1);
        assert_eq!(concept.individual_words[0].word, "test");
    }

    #[tokio::test]
    async fn test_get_concept_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_get(&mut mock_server, CONCEPT_PATH, 401, "");
        let res = client.modules.contextionary_get_concept("test").await;
        mock.assert();
        let err = res.err().expect("a 401 response should be reported as an error");
        assert!(module_error_message(err).contains("401"));
    }

    #[tokio::test]
    async fn test_get_concept_err_includes_json_body() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_get(
            &mut mock_server,
            CONCEPT_PATH,
            422,
            r#"{"error":[{"message":"concept not found"}]}"#,
        );
        let res = client.modules.contextionary_get_concept("test").await;
        mock.assert();
        let err = res.err().expect("a 422 response should be reported as an error");
        let msg = module_error_message(err);
        assert!(msg.contains("422"));
        assert!(msg.contains("concept not found"));
    }

    #[tokio::test]
    async fn test_extend_ok() {
        let ext = ContextionaryExtension::new("test", "test", 1.0);
        let ext_str = serde_json::to_string(&ext).unwrap();
        let (mut mock_server, client) = get_test_harness();
        // Only match when the request body is the serialized extension, so the payload sent by
        // `contextionary_extend` is verified as well as the response handling.
        let mock = mock_server
            .mock("POST", EXTENSIONS_PATH)
            .match_body(mockito::Matcher::JsonString(ext_str.clone()))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(ext_str.as_str())
            .create();
        let res = client.modules.contextionary_extend(ext).await;
        mock.assert();
        let returned = res.expect("a 200 response should decode into a ContextionaryExtension");
        assert_eq!(serde_json::to_string(&returned).unwrap(), ext_str);
    }

    #[tokio::test]
    async fn test_extend_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, EXTENSIONS_PATH, 401, "");
        let res = client
            .modules
            .contextionary_extend(ContextionaryExtension::new("test", "test", 1.0))
            .await;
        mock.assert();
        let err = res.err().expect("a 401 response should be reported as an error");
        assert!(module_error_message(err).contains("401"));
    }
}