1use derive_builder::Builder;
9
10use crate::{Embedding, SparseEmbedding, document::Document, util::debug_long_utf8};
11
/// A query that moves through typed states (see the `states` module): `Pending`, then
/// `Retrieved`, then `Answered`.
///
/// `STATE` is a zero-sized typestate marker. State transitions consume the query and return it
/// with a different marker, carrying every field over unchanged.
#[derive(Clone, Default, Builder, PartialEq)]
#[builder(setter(into))]
pub struct Query<STATE: QueryState> {
    /// The text the query was created with.
    original: String,
    /// The text as it currently stands; the builder falls back to `original` when unset.
    #[builder(default = "self.original.clone().unwrap_or_default()")]
    current: String,
    /// Typestate marker; the builder falls back to `STATE::default()` when unset.
    #[builder(default = STATE::default())]
    state: STATE,
    /// Transformation and retrieval events recorded on this query, in the order they happened.
    #[builder(default)]
    transformation_history: Vec<TransformationEvent>,

    /// Optional dense embedding attached to the query (presumably of the query text — confirm
    /// with the embedding step that sets it).
    #[builder(default)]
    pub embedding: Option<Embedding>,

    /// Optional sparse embedding attached to the query (presumably of the query text — confirm).
    #[builder(default)]
    pub sparse_embedding: Option<SparseEmbedding>,

    /// Documents attached to the query; `retrieved_documents` appends to this list.
    #[builder(default)]
    pub documents: Vec<Document>,
}
43
44impl<STATE: std::fmt::Debug + QueryState> std::fmt::Debug for Query<STATE> {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("Query")
47 .field(
48 "original",
49 &debug_long_utf8(&self.original, 100).lines().take(1),
50 )
51 .field(
52 "current",
53 &debug_long_utf8(&self.current, 100).lines().take(1),
54 )
55 .field("state", &self.state)
56 .field("num_transformations", &self.transformation_history.len())
57 .field("embedding", &self.embedding.is_some())
58 .field("num_documents", &self.documents.len())
59 .finish()
60 }
61}
62
63impl<STATE: Clone + QueryState> Query<STATE> {
64 pub fn builder() -> QueryBuilder<STATE> {
65 QueryBuilder::default().clone()
66 }
67
68 pub fn original(&self) -> &str {
70 &self.original
71 }
72
73 pub fn current(&self) -> &str {
75 &self.current
76 }
77
78 fn transition_to<NEWSTATE: QueryState>(self, new_state: NEWSTATE) -> Query<NEWSTATE> {
79 Query {
80 state: new_state,
81 original: self.original,
82 current: self.current,
83 transformation_history: self.transformation_history,
84 embedding: self.embedding,
85 sparse_embedding: self.sparse_embedding,
86 documents: self.documents,
87 }
88 }
89
90 #[allow(dead_code)]
91 pub fn history(&self) -> &Vec<TransformationEvent> {
92 &self.transformation_history
93 }
94
95 pub fn documents(&self) -> &[Document] {
97 &self.documents
98 }
99
100 pub fn documents_mut(&mut self) -> &mut Vec<Document> {
102 &mut self.documents
103 }
104}
105
106impl<STATE: Clone + CanRetrieve> Query<STATE> {
107 pub fn retrieved_documents(mut self, documents: Vec<Document>) -> Query<states::Retrieved> {
109 self.documents.extend(documents.clone());
110 self.transformation_history
111 .push(TransformationEvent::Retrieved {
112 before: self.current.clone(),
113 after: String::new(),
114 documents,
115 });
116
117 let state = states::Retrieved;
118
119 self.current.clear();
120 self.transition_to(state)
121 }
122}
123
124impl Query<states::Pending> {
125 pub fn new(query: impl Into<String>) -> Self {
126 Self {
127 original: query.into(),
128 ..Default::default()
129 }
130 }
131
132 pub fn transformed_query(&mut self, new_query: impl Into<String>) {
134 let new_query = new_query.into();
135
136 self.transformation_history
137 .push(TransformationEvent::Transformed {
138 before: self.current.clone(),
139 after: new_query.clone(),
140 });
141
142 self.current = new_query;
143 }
144}
145
146impl Query<states::Retrieved> {
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn transformed_response(&mut self, new_response: impl Into<String>) {
153 let new_response = new_response.into();
154
155 self.transformation_history
156 .push(TransformationEvent::Transformed {
157 before: self.current.clone(),
158 after: new_response.clone(),
159 });
160
161 self.current = new_response;
162 }
163
164 #[must_use]
166 pub fn answered(mut self, answer: impl Into<String>) -> Query<states::Answered> {
167 self.current = answer.into();
168 let state = states::Answered;
169 self.transition_to(state)
170 }
171}
172
173impl Query<states::Answered> {
174 pub fn new() -> Self {
175 Self::default()
176 }
177
178 pub fn answer(&self) -> &str {
180 &self.current
181 }
182}
183
/// Marker trait for every typestate a `Query` can be in.
///
/// Implementors must be thread-safe (`Send + Sync`) and constructible via `Default`.
pub trait QueryState
where
    Self: Send + Sync + Default,
{
}

/// Marker for the states from which documents may be retrieved (`retrieved_documents`).
pub trait CanRetrieve
where
    Self: QueryState,
{
}
188
189pub mod states {
191 use super::{CanRetrieve, QueryState};
192
193 #[derive(Debug, Default, Clone, PartialEq)]
194 pub struct Pending;
196
197 #[derive(Debug, Default, Clone, PartialEq)]
198 pub struct Retrieved;
200
201 #[derive(Debug, Default, Clone, PartialEq)]
202 pub struct Answered;
204
205 impl QueryState for Pending {}
206 impl QueryState for Retrieved {}
207 impl QueryState for Answered {}
208
209 impl CanRetrieve for Pending {}
210 impl CanRetrieve for Retrieved {}
211}
212
213impl<T: AsRef<str>> From<T> for Query<states::Pending> {
214 fn from(original: T) -> Self {
215 Self {
216 original: original.as_ref().to_string(),
217 current: original.as_ref().to_string(),
218 state: states::Pending,
219 ..Default::default()
220 }
221 }
222}
223
/// One entry in a query's transformation history.
#[derive(Clone, PartialEq)]
pub enum TransformationEvent {
    /// The current text was replaced (by `transformed_query` or `transformed_response`).
    Transformed {
        /// Text before the change.
        before: String,
        /// Text after the change.
        after: String,
    },
    /// Documents were retrieved for the query (by `retrieved_documents`).
    Retrieved {
        /// Text that was current when retrieval happened.
        before: String,
        /// Text after retrieval; `retrieved_documents` always records an empty string.
        after: String,
        /// Copy of the documents returned by this retrieval.
        documents: Vec<Document>,
    },
}
237
238impl std::fmt::Debug for TransformationEvent {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 match self {
241 TransformationEvent::Transformed { before, after } => {
242 write!(
243 f,
244 "Transformed: {} -> {}",
245 &debug_long_utf8(before, 100),
246 &debug_long_utf8(after, 100)
247 )
248 }
249 TransformationEvent::Retrieved {
250 before,
251 after,
252 documents,
253 } => {
254 write!(
255 f,
256 "Retrieved: {} -> {}\nDocuments: {:?}",
257 &debug_long_utf8(before, 100),
258 &debug_long_utf8(after, 100),
259 documents.len()
260 )
261 }
262 }
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh pending query used as the starting point of every test.
    fn pending() -> Query<states::Pending> {
        Query::<states::Pending>::from("test query")
    }

    /// Two small documents used as a retrieval result.
    fn sample_documents() -> Vec<Document> {
        vec!["doc1".into(), "doc2".into()]
    }

    #[test]
    fn test_query_initial_state() {
        let query = pending();

        assert_eq!(query.original(), "test query");
        assert_eq!(query.current(), "test query");
        assert!(query.history().is_empty());
    }

    #[test]
    fn test_query_transformed_query() {
        let mut query = pending();
        query.transformed_query("new query");

        assert_eq!(query.current(), "new query");
        assert_eq!(query.history().len(), 1);
        match &query.history()[0] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "new query");
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_retrieved_documents() {
        let expected = sample_documents();
        let query = pending().retrieved_documents(expected.clone());

        assert_eq!(query.documents(), &expected);
        assert_eq!(query.history().len(), 1);
        assert!(query.current().is_empty());
        match &query.history()[0] {
            TransformationEvent::Retrieved {
                before,
                after,
                documents,
            } => {
                assert_eq!(before, "test query");
                assert_eq!(after, "");
                assert_eq!(documents, &expected);
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_transformed_response() {
        let expected = sample_documents();
        let mut query = pending().retrieved_documents(expected.clone());
        query.transformed_response("new response");

        assert_eq!(query.current(), "new response");
        assert_eq!(query.history().len(), 2);
        assert_eq!(query.documents(), &expected);
        assert_eq!(query.original, "test query");
        match &query.history()[1] {
            TransformationEvent::Transformed { before, after } => {
                assert_eq!(before, "");
                assert_eq!(after, "new response");
            }
            _ => panic!("Unexpected event in history"),
        }
    }

    #[test]
    fn test_query_answered() {
        let query = pending()
            .retrieved_documents(sample_documents())
            .answered("the answer");

        assert_eq!(query.answer(), "the answer");
    }
}