swiftide_core/
query.rs

//! A query is the main object going through a query pipeline.
//!
//! It acts as a state machine that moves through the following states:
//!
//! - `states::Pending`: No documents have been retrieved
//! - `states::Retrieved`: Documents have been retrieved
//! - `states::Answered`: The query has been answered
8use derive_builder::Builder;
9
10use crate::{Embedding, SparseEmbedding, document::Document, util::debug_long_utf8};
11
12/// A query is the main object going through a query pipeline
13///
14/// It acts as a statemachine, with the following transitions:
15///
16/// `states::Pending`: No documents have been retrieved
17/// `states::Retrieved`: Documents have been retrieved
18/// `states::Answered`: The query has been answered
19#[derive(Clone, Default, Builder, PartialEq)]
20#[builder(setter(into))]
21pub struct Query<STATE: QueryState> {
22    original: String,
23    #[builder(default = "self.original.clone().unwrap_or_default()")]
24    current: String,
25    #[builder(default = STATE::default())]
26    state: STATE,
27    #[builder(default)]
28    transformation_history: Vec<TransformationEvent>,
29
30    // TODO: How would this work when doing a rollup query?
31    #[builder(default)]
32    pub embedding: Option<Embedding>,
33
34    #[builder(default)]
35    pub sparse_embedding: Option<SparseEmbedding>,
36
37    /// Documents the query will operate on
38    ///
39    /// A query can retrieve multiple times, accumulating documents
40    #[builder(default)]
41    pub documents: Vec<Document>,
42}
43
44impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("Query")
47            .field("original", &debug_long_utf8(&self.original, 100))
48            .field("current", &debug_long_utf8(&self.current, 100))
49            .field("state", &self.state)
50            .field("transformation_history", &self.transformation_history)
51            .field("embedding", &self.embedding.is_some())
52            .finish()
53    }
54}
55
56impl<STATE: Clone + QueryState> Query<STATE> {
57    pub fn builder() -> QueryBuilder<STATE> {
58        QueryBuilder::default().clone()
59    }
60
61    /// Return the query it started with
62    pub fn original(&self) -> &str {
63        &self.original
64    }
65
66    /// Return the current query (or after retrieval!)
67    pub fn current(&self) -> &str {
68        &self.current
69    }
70
71    fn transition_to<NEWSTATE: QueryState>(self, new_state: NEWSTATE) -> Query<NEWSTATE> {
72        Query {
73            state: new_state,
74            original: self.original,
75            current: self.current,
76            transformation_history: self.transformation_history,
77            embedding: self.embedding,
78            sparse_embedding: self.sparse_embedding,
79            documents: self.documents,
80        }
81    }
82
83    #[allow(dead_code)]
84    pub fn history(&self) -> &Vec<TransformationEvent> {
85        &self.transformation_history
86    }
87
88    /// Returns the current documents that will be used as context for answer generation
89    pub fn documents(&self) -> &[Document] {
90        &self.documents
91    }
92
93    /// Returns the current documents as mutable
94    pub fn documents_mut(&mut self) -> &mut Vec<Document> {
95        &mut self.documents
96    }
97}
98
99impl<STATE: Clone + CanRetrieve> Query<STATE> {
100    /// Add retrieved documents and transition to `states::Retrieved`
101    pub fn retrieved_documents(mut self, documents: Vec<Document>) -> Query<states::Retrieved> {
102        self.documents.extend(documents.clone());
103        self.transformation_history
104            .push(TransformationEvent::Retrieved {
105                before: self.current.clone(),
106                after: String::new(),
107                documents,
108            });
109
110        let state = states::Retrieved;
111
112        self.current.clear();
113        self.transition_to(state)
114    }
115}
116
117impl Query<states::Pending> {
118    pub fn new(query: impl Into<String>) -> Self {
119        Self {
120            original: query.into(),
121            ..Default::default()
122        }
123    }
124
125    /// Transforms the current query
126    pub fn transformed_query(&mut self, new_query: impl Into<String>) {
127        let new_query = new_query.into();
128
129        self.transformation_history
130            .push(TransformationEvent::Transformed {
131                before: self.current.clone(),
132                after: new_query.clone(),
133            });
134
135        self.current = new_query;
136    }
137}
138
139impl Query<states::Retrieved> {
140    pub fn new() -> Self {
141        Self::default()
142    }
143
144    /// Transforms the current response
145    pub fn transformed_response(&mut self, new_response: impl Into<String>) {
146        let new_response = new_response.into();
147
148        self.transformation_history
149            .push(TransformationEvent::Transformed {
150                before: self.current.clone(),
151                after: new_response.clone(),
152            });
153
154        self.current = new_response;
155    }
156
157    /// Transition the query to `states::Answered`
158    #[must_use]
159    pub fn answered(mut self, answer: impl Into<String>) -> Query<states::Answered> {
160        self.current = answer.into();
161        let state = states::Answered;
162        self.transition_to(state)
163    }
164}
165
166impl Query<states::Answered> {
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    /// Returns the answer of the query
172    pub fn answer(&self) -> &str {
173        &self.current
174    }
175}
176
/// Marker trait for query states
///
/// Every state has to be `Default`, `Send` and `Sync`.
pub trait QueryState: Default + Send + Sync {}

/// Marker trait for query states that can still retrieve
pub trait CanRetrieve: QueryState {}
181
182/// States of a query
183pub mod states {
184    use super::{CanRetrieve, QueryState};
185
186    #[derive(Debug, Default, Clone, PartialEq)]
187    /// The query is pending and has not been used
188    pub struct Pending;
189
190    #[derive(Debug, Default, Clone, PartialEq)]
191    /// Documents have been retrieved
192    pub struct Retrieved;
193
194    #[derive(Debug, Default, Clone, PartialEq)]
195    /// The query has been answered
196    pub struct Answered;
197
198    impl QueryState for Pending {}
199    impl QueryState for Retrieved {}
200    impl QueryState for Answered {}
201
202    impl CanRetrieve for Pending {}
203    impl CanRetrieve for Retrieved {}
204}
205
206impl<T: AsRef<str>> From<T> for Query<states::Pending> {
207    fn from(original: T) -> Self {
208        Self {
209            original: original.as_ref().to_string(),
210            current: original.as_ref().to_string(),
211            state: states::Pending,
212            ..Default::default()
213        }
214    }
215}
216
217#[derive(Clone, PartialEq)]
218/// Records changes to a query
219pub enum TransformationEvent {
220    Transformed {
221        before: String,
222        after: String,
223    },
224    Retrieved {
225        before: String,
226        after: String,
227        documents: Vec<Document>,
228    },
229}
230
231impl std::fmt::Debug for TransformationEvent {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        match self {
234            TransformationEvent::Transformed { before, after } => {
235                write!(
236                    f,
237                    "Transformed: {} -> {}",
238                    &debug_long_utf8(before, 100),
239                    &debug_long_utf8(after, 100)
240                )
241            }
242            TransformationEvent::Retrieved {
243                before,
244                after,
245                documents,
246            } => {
247                write!(
248                    f,
249                    "Retrieved: {} -> {}\nDocuments: {:?}",
250                    &debug_long_utf8(before, 100),
251                    &debug_long_utf8(after, 100),
252                    documents.len()
253                )
254            }
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// A pending query over the fixed text used by every test
    fn pending_query() -> Query<states::Pending> {
        Query::<states::Pending>::from("test query")
    }

    /// Two small documents shared by the retrieval tests
    fn sample_documents() -> Vec<Document> {
        vec!["doc1".into(), "doc2".into()]
    }

    #[test]
    fn test_query_initial_state() {
        let query = pending_query();

        assert_eq!(query.original(), "test query");
        assert_eq!(query.current(), "test query");
        assert_eq!(query.history().len(), 0);
    }

    #[test]
    fn test_query_transformed_query() {
        let mut query = pending_query();
        query.transformed_query("new query");

        assert_eq!(query.current(), "new query");
        assert_eq!(query.history().len(), 1);

        match &query.history()[0] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "new query");
            }
            TransformationEvent::Retrieved { .. } => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_retrieved_documents() {
        let documents = sample_documents();
        let query = pending_query().retrieved_documents(documents.clone());

        assert_eq!(query.documents(), &documents);
        assert_eq!(query.history().len(), 1);
        assert!(query.current().is_empty());

        match &query.history()[0] {
            TransformationEvent::Retrieved {
                before,
                after,
                documents: retrieved_docs,
            } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "");
                assert_eq!(retrieved_docs, &documents);
            }
            TransformationEvent::Transformed { .. } => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_transformed_response() {
        let documents = sample_documents();
        let mut query = pending_query().retrieved_documents(documents.clone());
        query.transformed_response("new response");

        assert_eq!(query.current(), "new response");
        assert_eq!(query.history().len(), 2);
        assert_eq!(query.documents(), &documents);
        assert_eq!(query.original, "test query");

        match &query.history()[1] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "");
                assert_eq!(after, "new response");
            }
            TransformationEvent::Retrieved { .. } => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_answered() {
        let answered = pending_query()
            .retrieved_documents(sample_documents())
            .answered("the answer");

        assert_eq!(answered.answer(), "the answer");
    }
}
335}