1use derive_builder::Builder;
9
10use crate::{Embedding, SparseEmbedding, document::Document, util::debug_long_utf8};
11
/// A query moving through a retrieval pipeline, with its progress tracked as a
/// type-state.
///
/// `STATE` is one of the marker types in [`states`] (`Pending`, `Retrieved`,
/// `Answered`). Methods that are only valid in a given phase are implemented on
/// `Query<that state>`, so misuse is a compile error rather than a runtime check.
///
/// Construct with [`Query::builder`], `Query::<states::Pending>::new`, or the
/// `From<impl AsRef<str>>` conversion for `Query<states::Pending>`.
#[derive(Clone, Default, Builder, PartialEq)]
#[builder(setter(into))]
pub struct Query<STATE: QueryState> {
    /// The query text as originally supplied. It has no builder default, so
    /// the builder requires it to be set.
    original: String,
    /// The working text of the query. Transformations replace it, and
    /// retrieval clears it. The builder defaults it to a copy of `original`.
    #[builder(default = "self.original.clone().unwrap_or_default()")]
    current: String,
    /// Type-state marker value; defaults to `STATE::default()`.
    #[builder(default = STATE::default())]
    state: STATE,
    /// Events appended by `transformed_query`, `transformed_response` and
    /// `retrieved_documents`, in the order they happened.
    #[builder(default)]
    transformation_history: Vec<TransformationEvent>,

    /// Optional dense embedding. Presumably this is the embedding of the query
    /// text; confirm where it is populated.
    #[builder(default)]
    pub embedding: Option<Embedding>,

    /// Optional sparse embedding. Presumably this is the sparse embedding of
    /// the query text; confirm where it is populated.
    #[builder(default)]
    pub sparse_embedding: Option<SparseEmbedding>,

    /// Documents attached to this query. `retrieved_documents` appends to this
    /// list.
    #[builder(default)]
    pub documents: Vec<Document>,
}
43
44impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("Query")
47 .field("original", &debug_long_utf8(&self.original, 100))
48 .field("current", &debug_long_utf8(&self.current, 100))
49 .field("state", &self.state)
50 .field("transformation_history", &self.transformation_history)
51 .field("embedding", &self.embedding.is_some())
52 .finish()
53 }
54}
55
56impl<STATE: Clone + QueryState> Query<STATE> {
57 pub fn builder() -> QueryBuilder<STATE> {
58 QueryBuilder::default().clone()
59 }
60
61 pub fn original(&self) -> &str {
63 &self.original
64 }
65
66 pub fn current(&self) -> &str {
68 &self.current
69 }
70
71 fn transition_to<NEWSTATE: QueryState>(self, new_state: NEWSTATE) -> Query<NEWSTATE> {
72 Query {
73 state: new_state,
74 original: self.original,
75 current: self.current,
76 transformation_history: self.transformation_history,
77 embedding: self.embedding,
78 sparse_embedding: self.sparse_embedding,
79 documents: self.documents,
80 }
81 }
82
83 #[allow(dead_code)]
84 pub fn history(&self) -> &Vec<TransformationEvent> {
85 &self.transformation_history
86 }
87
88 pub fn documents(&self) -> &[Document] {
90 &self.documents
91 }
92
93 pub fn documents_mut(&mut self) -> &mut Vec<Document> {
95 &mut self.documents
96 }
97}
98
99impl<STATE: Clone + CanRetrieve> Query<STATE> {
100 pub fn retrieved_documents(mut self, documents: Vec<Document>) -> Query<states::Retrieved> {
102 self.documents.extend(documents.clone());
103 self.transformation_history
104 .push(TransformationEvent::Retrieved {
105 before: self.current.clone(),
106 after: String::new(),
107 documents,
108 });
109
110 let state = states::Retrieved;
111
112 self.current.clear();
113 self.transition_to(state)
114 }
115}
116
117impl Query<states::Pending> {
118 pub fn new(query: impl Into<String>) -> Self {
119 Self {
120 original: query.into(),
121 ..Default::default()
122 }
123 }
124
125 pub fn transformed_query(&mut self, new_query: impl Into<String>) {
127 let new_query = new_query.into();
128
129 self.transformation_history
130 .push(TransformationEvent::Transformed {
131 before: self.current.clone(),
132 after: new_query.clone(),
133 });
134
135 self.current = new_query;
136 }
137}
138
139impl Query<states::Retrieved> {
140 pub fn new() -> Self {
141 Self::default()
142 }
143
144 pub fn transformed_response(&mut self, new_response: impl Into<String>) {
146 let new_response = new_response.into();
147
148 self.transformation_history
149 .push(TransformationEvent::Transformed {
150 before: self.current.clone(),
151 after: new_response.clone(),
152 });
153
154 self.current = new_response;
155 }
156
157 #[must_use]
159 pub fn answered(mut self, answer: impl Into<String>) -> Query<states::Answered> {
160 self.current = answer.into();
161 let state = states::Answered;
162 self.transition_to(state)
163 }
164}
165
166impl Query<states::Answered> {
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub fn answer(&self) -> &str {
173 &self.current
174 }
175}
176
/// Marker trait for the type-state parameter of [`Query`].
///
/// Implementors must be `Send + Sync`. They must also be `Default`, which lets
/// `Query<STATE>` derive `Default` and lets the builder fill in `state` with
/// `STATE::default()`.
pub trait QueryState: Send + Sync + Default {}
/// Marker for states from which `Query::retrieved_documents` may be called.
pub trait CanRetrieve: QueryState {}
181
182pub mod states {
184 use super::{CanRetrieve, QueryState};
185
186 #[derive(Debug, Default, Clone, PartialEq)]
187 pub struct Pending;
189
190 #[derive(Debug, Default, Clone, PartialEq)]
191 pub struct Retrieved;
193
194 #[derive(Debug, Default, Clone, PartialEq)]
195 pub struct Answered;
197
198 impl QueryState for Pending {}
199 impl QueryState for Retrieved {}
200 impl QueryState for Answered {}
201
202 impl CanRetrieve for Pending {}
203 impl CanRetrieve for Retrieved {}
204}
205
206impl<T: AsRef<str>> From<T> for Query<states::Pending> {
207 fn from(original: T) -> Self {
208 Self {
209 original: original.as_ref().to_string(),
210 current: original.as_ref().to_string(),
211 state: states::Pending,
212 ..Default::default()
213 }
214 }
215}
216
/// One entry in a [`Query`]'s transformation history.
#[derive(Clone, PartialEq)]
pub enum TransformationEvent {
    /// The working text was replaced, either as a query rewrite or as a
    /// response rewrite.
    Transformed {
        /// Working text before the change.
        before: String,
        /// Working text after the change.
        after: String,
    },
    /// Documents were retrieved for the query.
    Retrieved {
        /// Working text at the time of retrieval.
        before: String,
        /// Working text after retrieval; `Query::retrieved_documents` records
        /// an empty string here.
        after: String,
        /// The documents returned by this retrieval step.
        documents: Vec<Document>,
    },
}
230
231impl std::fmt::Debug for TransformationEvent {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 TransformationEvent::Transformed { before, after } => {
235 write!(
236 f,
237 "Transformed: {} -> {}",
238 &debug_long_utf8(before, 100),
239 &debug_long_utf8(after, 100)
240 )
241 }
242 TransformationEvent::Retrieved {
243 before,
244 after,
245 documents,
246 } => {
247 write!(
248 f,
249 "Retrieved: {} -> {}\nDocuments: {:?}",
250 &debug_long_utf8(before, 100),
251 &debug_long_utf8(after, 100),
252 documents.len()
253 )
254 }
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// The pending query every test starts from.
    fn pending() -> Query<states::Pending> {
        Query::from("test query")
    }

    #[test]
    fn test_query_initial_state() {
        let query = pending();

        assert_eq!(query.original(), "test query");
        assert_eq!(query.current(), "test query");
        assert!(query.history().is_empty());
    }

    #[test]
    fn test_query_transformed_query() {
        let mut query = pending();
        query.transformed_query("new query");

        assert_eq!(query.current(), "new query");
        assert_eq!(query.history().len(), 1);
        match &query.history()[0] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "new query");
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_retrieved_documents() {
        let expected: Vec<Document> = vec!["doc1".into(), "doc2".into()];
        let query = pending().retrieved_documents(expected.clone());

        assert_eq!(query.documents(), &expected);
        assert_eq!(query.history().len(), 1);
        assert!(query.current().is_empty());
        match &query.history()[0] {
            TransformationEvent::Retrieved {
                before,
                after,
                documents,
            } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "");
                assert_eq!(documents, &expected);
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_transformed_response() {
        let expected: Vec<Document> = vec!["doc1".into(), "doc2".into()];
        let mut query = pending().retrieved_documents(expected.clone());
        query.transformed_response("new response");

        assert_eq!(query.current(), "new response");
        assert_eq!(query.history().len(), 2);
        assert_eq!(query.documents(), &expected);
        assert_eq!(query.original, "test query");
        match &query.history()[1] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "");
                assert_eq!(after, "new response");
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_answered() {
        let documents: Vec<Document> = vec!["doc1".into(), "doc2".into()];
        let query = pending()
            .retrieved_documents(documents)
            .answered("the answer");

        assert_eq!(query.answer(), "the answer");
    }
}