swiftide_core/
query.rs

//! A query is the main object going through a query pipeline.
//!
//! It acts as a state machine that moves through the following states:
//!
//! - `states::Pending`: no documents have been retrieved
//! - `states::Retrieved`: documents have been retrieved
//! - `states::Answered`: the query has been answered
8use derive_builder::Builder;
9
10use crate::{Embedding, SparseEmbedding, document::Document, util::debug_long_utf8};
11
12/// A query is the main object going through a query pipeline
13///
14/// It acts as a statemachine, with the following transitions:
15///
16/// `states::Pending`: No documents have been retrieved
17/// `states::Retrieved`: Documents have been retrieved
18/// `states::Answered`: The query has been answered
19#[derive(Clone, Default, Builder, PartialEq)]
20#[builder(setter(into))]
21pub struct Query<STATE: QueryState> {
22    original: String,
23    #[builder(default = "self.original.clone().unwrap_or_default()")]
24    current: String,
25    #[builder(default = STATE::default())]
26    state: STATE,
27    #[builder(default)]
28    transformation_history: Vec<TransformationEvent>,
29
30    // TODO: How would this work when doing a rollup query?
31    #[builder(default)]
32    pub embedding: Option<Embedding>,
33
34    #[builder(default)]
35    pub sparse_embedding: Option<SparseEmbedding>,
36
37    /// Documents the query will operate on
38    ///
39    /// A query can retrieve multiple times, accumulating documents
40    #[builder(default)]
41    pub documents: Vec<Document>,
42}
43
44impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("Query")
47            .field(
48                "original",
49                &debug_long_utf8(&self.original, 100).lines().take(1),
50            )
51            .field(
52                "current",
53                &debug_long_utf8(&self.current, 100).lines().take(1),
54            )
55            .field("state", &self.state)
56            .field("num_transformations", &self.transformation_history.len())
57            .field("embedding", &self.embedding.is_some())
58            .field("num_documents", &self.documents.len())
59            .finish()
60    }
61}
62
63impl<STATE: Clone + QueryState> Query<STATE> {
64    pub fn builder() -> QueryBuilder<STATE> {
65        QueryBuilder::default().clone()
66    }
67
68    /// Return the query it started with
69    pub fn original(&self) -> &str {
70        &self.original
71    }
72
73    /// Return the current query (or after retrieval!)
74    pub fn current(&self) -> &str {
75        &self.current
76    }
77
78    fn transition_to<NEWSTATE: QueryState>(self, new_state: NEWSTATE) -> Query<NEWSTATE> {
79        Query {
80            state: new_state,
81            original: self.original,
82            current: self.current,
83            transformation_history: self.transformation_history,
84            embedding: self.embedding,
85            sparse_embedding: self.sparse_embedding,
86            documents: self.documents,
87        }
88    }
89
90    #[allow(dead_code)]
91    pub fn history(&self) -> &Vec<TransformationEvent> {
92        &self.transformation_history
93    }
94
95    /// Returns the current documents that will be used as context for answer generation
96    pub fn documents(&self) -> &[Document] {
97        &self.documents
98    }
99
100    /// Returns the current documents as mutable
101    pub fn documents_mut(&mut self) -> &mut Vec<Document> {
102        &mut self.documents
103    }
104}
105
106impl<STATE: Clone + CanRetrieve> Query<STATE> {
107    /// Add retrieved documents and transition to `states::Retrieved`
108    pub fn retrieved_documents(mut self, documents: Vec<Document>) -> Query<states::Retrieved> {
109        self.documents.extend(documents.clone());
110        self.transformation_history
111            .push(TransformationEvent::Retrieved {
112                before: self.current.clone(),
113                after: String::new(),
114                documents,
115            });
116
117        let state = states::Retrieved;
118
119        self.current.clear();
120        self.transition_to(state)
121    }
122}
123
124impl Query<states::Pending> {
125    pub fn new(query: impl Into<String>) -> Self {
126        Self {
127            original: query.into(),
128            ..Default::default()
129        }
130    }
131
132    /// Transforms the current query
133    pub fn transformed_query(&mut self, new_query: impl Into<String>) {
134        let new_query = new_query.into();
135
136        self.transformation_history
137            .push(TransformationEvent::Transformed {
138                before: self.current.clone(),
139                after: new_query.clone(),
140            });
141
142        self.current = new_query;
143    }
144}
145
146impl Query<states::Retrieved> {
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Transforms the current response
152    pub fn transformed_response(&mut self, new_response: impl Into<String>) {
153        let new_response = new_response.into();
154
155        self.transformation_history
156            .push(TransformationEvent::Transformed {
157                before: self.current.clone(),
158                after: new_response.clone(),
159            });
160
161        self.current = new_response;
162    }
163
164    /// Transition the query to `states::Answered`
165    #[must_use]
166    pub fn answered(mut self, answer: impl Into<String>) -> Query<states::Answered> {
167        self.current = answer.into();
168        let state = states::Answered;
169        self.transition_to(state)
170    }
171}
172
173impl Query<states::Answered> {
174    pub fn new() -> Self {
175        Self::default()
176    }
177
178    /// Returns the answer of the query
179    pub fn answer(&self) -> &str {
180        &self.current
181    }
182}
183
/// Marker trait implemented by every state a query can be in.
///
/// A state must be `Send`, `Sync` and provide a `Default` value.
pub trait QueryState: Default + Send + Sync {}
/// Marker trait for the query states from which documents can still be retrieved.
pub trait CanRetrieve: QueryState {}
188
189/// States of a query
190pub mod states {
191    use super::{CanRetrieve, QueryState};
192
193    #[derive(Debug, Default, Clone, PartialEq)]
194    /// The query is pending and has not been used
195    pub struct Pending;
196
197    #[derive(Debug, Default, Clone, PartialEq)]
198    /// Documents have been retrieved
199    pub struct Retrieved;
200
201    #[derive(Debug, Default, Clone, PartialEq)]
202    /// The query has been answered
203    pub struct Answered;
204
205    impl QueryState for Pending {}
206    impl QueryState for Retrieved {}
207    impl QueryState for Answered {}
208
209    impl CanRetrieve for Pending {}
210    impl CanRetrieve for Retrieved {}
211}
212
213impl<T: AsRef<str>> From<T> for Query<states::Pending> {
214    fn from(original: T) -> Self {
215        Self {
216            original: original.as_ref().to_string(),
217            current: original.as_ref().to_string(),
218            state: states::Pending,
219            ..Default::default()
220        }
221    }
222}
223
224#[derive(Clone, PartialEq)]
225/// Records changes to a query
226pub enum TransformationEvent {
227    Transformed {
228        before: String,
229        after: String,
230    },
231    Retrieved {
232        before: String,
233        after: String,
234        documents: Vec<Document>,
235    },
236}
237
238impl std::fmt::Debug for TransformationEvent {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        match self {
241            TransformationEvent::Transformed { before, after } => {
242                write!(
243                    f,
244                    "Transformed: {} -> {}",
245                    &debug_long_utf8(before, 100),
246                    &debug_long_utf8(after, 100)
247                )
248            }
249            TransformationEvent::Retrieved {
250                before,
251                after,
252                documents,
253            } => {
254                write!(
255                    f,
256                    "Retrieved: {} -> {}\nDocuments: {:?}",
257                    &debug_long_utf8(before, 100),
258                    &debug_long_utf8(after, 100),
259                    documents.len()
260                )
261            }
262        }
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// The pending query every test starts from.
    fn pending_query() -> Query<states::Pending> {
        Query::<states::Pending>::from("test query")
    }

    /// Two small documents used wherever retrieval is exercised.
    fn sample_documents() -> Vec<Document> {
        vec!["doc1".into(), "doc2".into()]
    }

    /// A query built from a string starts with identical original and current
    /// text and an empty history.
    #[test]
    fn test_query_initial_state() {
        let query = pending_query();

        assert_eq!(query.original(), "test query");
        assert_eq!(query.current(), "test query");
        assert_eq!(query.history().len(), 0);
    }

    /// Transforming a pending query swaps the current text and records the change.
    #[test]
    fn test_query_transformed_query() {
        let mut query = pending_query();
        query.transformed_query("new query");

        assert_eq!(query.current(), "new query");
        assert_eq!(query.history().len(), 1);
        match &query.history()[0] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "new query");
            }
            TransformationEvent::Retrieved { .. } => panic!("Unexpected event in history"),
        }
    }

    /// Retrieval attaches the documents, clears the current text and records
    /// a `Retrieved` event.
    #[test]
    fn test_query_retrieved_documents() {
        let documents = sample_documents();
        let query = pending_query().retrieved_documents(documents.clone());

        assert_eq!(query.documents(), &documents);
        assert_eq!(query.history().len(), 1);
        assert!(query.current().is_empty());
        match &query.history()[0] {
            TransformationEvent::Retrieved {
                before,
                after,
                documents: retrieved_docs,
            } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "");
                assert_eq!(retrieved_docs, &documents);
            }
            TransformationEvent::Transformed { .. } => panic!("Unexpected event in history"),
        }
    }

    /// Transforming the response after retrieval keeps documents and original
    /// text while recording a second history entry.
    #[test]
    fn test_query_transformed_response() {
        let documents = sample_documents();
        let mut query = pending_query().retrieved_documents(documents.clone());
        query.transformed_response("new response");

        assert_eq!(query.current(), "new response");
        assert_eq!(query.history().len(), 2);
        assert_eq!(query.documents(), &documents);
        assert_eq!(query.original, "test query");
        match &query.history()[1] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "");
                assert_eq!(after, "new response");
            }
            TransformationEvent::Retrieved { .. } => panic!("Unexpected event in history"),
        }
    }

    /// Answering a retrieved query exposes the answer text.
    #[test]
    fn test_query_answered() {
        let query = pending_query()
            .retrieved_documents(sample_documents())
            .answered("the answer");

        assert_eq!(query.answer(), "the answer");
    }
}