weaviate_community/collections/classification.rs
//! All classification associated type components
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5/// A new ClassificationRequest used to make classification requests
6#[derive(Serialize, Deserialize, Debug)]
7#[serde(rename_all = "camelCase")]
8pub struct ClassificationRequest {
9 #[serde(rename = "type")]
10 pub classification_type: ClassificationType,
11 pub class: String,
12 pub classify_properties: Vec<String>,
13 #[serde(skip_serializing_if = "Option::is_none")]
14 pub based_on_properties: Option<Vec<String>>,
15 pub filters: serde_json::Value,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 #[serde(default)]
18 pub settings: Option<serde_json::Value>,
19}
20
21impl ClassificationRequest {
22 /// Create a new builder for the ClassificationRequest.
23 ///
24 /// This is the same as `ClassificationRequestBuilder::new()`.
25 ///
26 /// # Example
27 /// ```rust
28 /// use weaviate_community::collections::classification::ClassificationRequest;
29 ///
30 /// let builder = ClassificationRequest::builder();
31 /// ```
32 pub fn builder() -> ClassificationRequestBuilder {
33 ClassificationRequestBuilder::default()
34 }
35}
36
37/// Builder for the ClassificationRequest
38#[derive(Debug, Default)]
39pub struct ClassificationRequestBuilder {
40 pub classification_type: ClassificationType,
41 pub class: String,
42 pub classify_properties: Vec<String>,
43 pub based_on_properties: Option<Vec<String>>,
44 pub filters: serde_json::Value,
45 pub settings: Option<serde_json::Value>,
46}
47
48impl ClassificationRequestBuilder {
49 /// Create a new builder for the ClassificationRequest.
50 ///
51 /// This is the same as `ClassificationRequestBuilder::new()`.
52 ///
53 /// The resulting object will have no populated fields. These need filling out in accordance
54 /// with the required classification type using the builder methods.
55 ///
56 /// # Example
57 /// ```rust
58 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
59 ///
60 /// let builder = ClassificationRequestBuilder::new();
61 /// ```
62 pub fn new() -> ClassificationRequestBuilder {
63 ClassificationRequestBuilder::default()
64 }
65
66 /// Add a value to the `classification_type` property of the ClassificationRequest.
67 ///
68 /// # Parameters
69 /// - classification_type: the classification_type to use for the property
70 ///
71 /// # Example
72 /// ```rust
73 /// use weaviate_community::collections::classification::{
74 /// ClassificationRequestBuilder,
75 /// ClassificationType
76 /// };
77 ///
78 /// let builder = ClassificationRequestBuilder::new()
79 /// .with_type(ClassificationType::KNN);
80 /// ```
81 pub fn with_type(
82 mut self,
83 classification_type: ClassificationType,
84 ) -> ClassificationRequestBuilder {
85 self.classification_type = classification_type;
86 self
87 }
88
89 /// Add a value to the `class` property of the ClassificationRequest.
90 ///
91 /// # Parameters
92 /// - class_name: the name of the class to run the classification on
93 ///
94 /// # Example
95 /// ```rust
96 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
97 ///
98 /// let builder = ClassificationRequestBuilder::new()
99 /// .with_class("Article");
100 /// ```
101 pub fn with_class(mut self, class_name: &str) -> ClassificationRequestBuilder {
102 self.class = class_name.into();
103 self
104 }
105
106 /// Add a value to the `classify_properties` property of the ClassificationRequest.
107 ///
108 /// # Parameters
109 /// - classify_properties: the properties to classify
110 ///
111 /// # Example
112 /// ```rust
113 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
114 ///
115 /// let builder = ClassificationRequestBuilder::new()
116 /// .with_classify_properties(vec!["hasPopularity"]);
117 /// ```
118 pub fn with_classify_properties(
119 mut self,
120 classify_properties: Vec<&str>,
121 ) -> ClassificationRequestBuilder {
122 let classify_properties = classify_properties
123 .iter()
124 .map(|field| field.to_string())
125 .collect();
126 self.classify_properties = classify_properties;
127 self
128 }
129
130 /// Add a value to the `based_on_properties` property of the ClassificationRequest.
131 ///
132 /// # Parameters
133 /// - based_on_properties: the 'based on' properties to classify against
134 ///
135 /// # Example
136 /// ```rust
137 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
138 ///
139 /// let builder = ClassificationRequestBuilder::new()
140 /// .with_based_on_properties(vec!["summary"]);
141 /// ```
142 pub fn with_based_on_properties(
143 mut self,
144 based_on_properties: Vec<&str>,
145 ) -> ClassificationRequestBuilder {
146 let based_on_properties = based_on_properties
147 .iter()
148 .map(|field| field.to_string())
149 .collect();
150 self.based_on_properties = Some(based_on_properties);
151 self
152 }
153
154 /// Add a value to the `filters` property of the ClassificationRequest.
155 ///
156 /// # Parameters
157 /// - filters: the filters for the classifier to use when retrieving results
158 ///
159 /// # Example
160 /// ```rust
161 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
162 ///
163 /// let builder = ClassificationRequestBuilder::new()
164 /// .with_filters(serde_json::json!(
165 /// {"path": ["wordCount"], "operator": "GreaterThan", "valueInt": 100}
166 /// ));
167 /// ```
168 pub fn with_filters(mut self, filters: serde_json::Value) -> ClassificationRequestBuilder {
169 self.filters = filters;
170 self
171 }
172
173 /// Add a value to the `settings` property of the ClassificationRequest.
174 ///
175 /// # Parameters
176 /// - settings: the settings for the classifier
177 ///
178 /// # Example
179 /// ```rust
180 /// use weaviate_community::collections::classification::ClassificationRequestBuilder;
181 ///
182 /// let builder = ClassificationRequestBuilder::new()
183 /// .with_settings(serde_json::json!({"k": 3}));
184 /// ```
185 pub fn with_settings(mut self, settings: serde_json::Value) -> ClassificationRequestBuilder {
186 self.settings = Some(settings);
187 self
188 }
189
190 /// Build the ClassificationRequest from the ClassificationRequestBuilder
191 ///
192 /// # Example
193 /// Using ClassificationRequestBuilder
194 /// ```rust
195 /// use weaviate_community::collections::classification::{
196 /// ClassificationRequestBuilder,
197 /// ClassificationType
198 /// };
199 ///
200 /// let builder = ClassificationRequestBuilder::new()
201 /// .with_type(ClassificationType::KNN)
202 /// .with_class("Article")
203 /// .with_classify_properties(vec!["hasPopularity"])
204 /// .with_based_on_properties(vec!["summary"])
205 /// .with_filters(serde_json::json!(
206 /// {"path": ["wordCount"], "operator": "GreaterThan", "valueInt": 100}
207 /// ))
208 /// .with_settings(serde_json::json!({"k": 3}))
209 /// .build();
210 /// ```
211 ///
212 /// Using ClassificationRequest
213 /// ```rust
214 /// use weaviate_community::collections::classification::{
215 /// ClassificationRequest,
216 /// ClassificationType
217 /// };
218 ///
219 /// let builder = ClassificationRequest::builder()
220 /// .with_type(ClassificationType::KNN)
221 /// .with_class("Article")
222 /// .with_classify_properties(vec!["hasPopularity"])
223 /// .with_based_on_properties(vec!["summary"])
224 /// .with_filters(serde_json::json!(
225 /// {"path": ["wordCount"], "operator": "GreaterThan", "valueInt": 100}
226 /// ))
227 /// .with_settings(serde_json::json!({"k": 3}))
228 /// .build();
229 /// ```
230 pub fn build(self) -> ClassificationRequest {
231 ClassificationRequest {
232 classification_type: self.classification_type,
233 class: self.class,
234 classify_properties: self.classify_properties,
235 based_on_properties: self.based_on_properties,
236 filters: self.filters,
237 settings: self.settings,
238 }
239 }
240}
241
242/// Types of classification available
243#[derive(Serialize, Deserialize, Debug, Default)]
244pub enum ClassificationType {
245 #[default]
246 #[serde(rename = "knn")]
247 KNN,
248 #[serde(rename = "zeroshot")]
249 ZEROSHOT,
250}
251
252/// Response received from the classification
253#[derive(Serialize, Deserialize, Debug)]
254#[serde(rename_all = "camelCase")]
255pub struct ClassificationResponse {
256 pub id: Uuid,
257 pub class: String,
258 pub classify_properties: Vec<String>,
259 #[serde(skip_serializing_if = "Option::is_none")]
260 pub based_on_properties: Option<Vec<String>>,
261 pub status: String,
262 pub meta: ClassificationMetadata,
263 #[serde(rename = "type")]
264 pub classification_type: String,
265 #[serde(skip_serializing_if = "Option::is_none")]
266 #[serde(default)]
267 pub settings: Option<serde_json::Value>,
268 pub filters: serde_json::Value,
269}
270
271/// Metadata for the Classification
272#[derive(Serialize, Deserialize, Debug)]
273#[serde(rename_all = "camelCase")]
274pub struct ClassificationMetadata {
275 pub started: String,
276 pub completed: String,
277 pub count: Option<u64>,
278 pub count_succeeded: Option<u64>,
279 pub count_failed: Option<u64>,
280}