// summa_core/components/fruit_extractors.rs

1use std::collections::{HashMap, HashSet};
2
3use rand::rngs::SmallRng;
4use rand::{Rng, SeedableRng};
5use summa_proto::proto;
6use tantivy::aggregation::agg_req::Aggregations;
7use tantivy::aggregation::agg_result::AggregationResults;
8use tantivy::aggregation::AggregationLimitsGuard;
9use tantivy::collector::{FacetCounts, FruitHandle, MultiCollector, MultiFruit};
10use tantivy::query::Query;
11use tantivy::schema::Field;
12use tantivy::{Order, Searcher};
13
14use crate::components::snippet_generator::SnippetGeneratorConfig;
15use crate::components::IndexHolder;
16use crate::errors::{BuilderError, SummaResult};
17use crate::scorers::eval_scorer_tweaker::EvalScorerTweaker;
18use crate::scorers::EvalScorer;
19use crate::{collectors, validators};
20
21#[derive(Clone)]
22pub struct ScoredDocAddress {
23    pub doc_address: tantivy::DocAddress,
24    pub score: Option<proto::Score>,
25}
26
27#[derive(Clone)]
28pub struct ExtractionTooling {
29    pub searcher: Searcher,
30    pub query_fields: Option<HashSet<Field>>,
31    pub multi_fields: HashSet<Field>,
32}
33
34impl ExtractionTooling {
35    pub fn new(searcher: Searcher, query_fields: Option<HashSet<Field>>, multi_fields: HashSet<Field>) -> ExtractionTooling {
36        ExtractionTooling {
37            searcher,
38            query_fields,
39            multi_fields,
40        }
41    }
42}
43
44#[derive(Clone)]
45pub struct PreparedDocumentReferences {
46    pub index_alias: String,
47    pub extraction_tooling: ExtractionTooling,
48    pub snippet_generator_config: Option<SnippetGeneratorConfig>,
49    pub scored_doc_addresses: Vec<ScoredDocAddress>,
50    pub has_next: bool,
51    pub limit: u32,
52    pub offset: u32,
53}
54
55#[derive(Clone)]
56pub enum ReadyCollectorOutput {
57    Aggregation(proto::AggregationCollectorOutput),
58    Count(proto::CountCollectorOutput),
59    Facet(proto::FacetCollectorOutput),
60}
61
62#[derive(Clone)]
63pub enum IntermediateExtractionResult {
64    Ready(ReadyCollectorOutput),
65    PreparedDocumentReferences(PreparedDocumentReferences),
66}
67
68impl IntermediateExtractionResult {
69    pub fn as_document_references(self) -> Option<PreparedDocumentReferences> {
70        if let IntermediateExtractionResult::PreparedDocumentReferences(document_references) = self {
71            Some(document_references)
72        } else {
73            None
74        }
75    }
76
77    pub fn as_count(&self) -> Option<&proto::CountCollectorOutput> {
78        if let IntermediateExtractionResult::Ready(ReadyCollectorOutput::Count(count)) = &self {
79            Some(count)
80        } else {
81            None
82        }
83    }
84}
85
86/// Extracts data from `MultiFruit` and moving it to the `proto::CollectorOutput`
87#[async_trait]
88pub trait FruitExtractor: Sync + Send {
89    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult>;
90}
91
92pub fn build_fruit_extractor(
93    index_holder: &IndexHolder,
94    index_alias: &str,
95    searcher: Searcher,
96    collector_proto: proto::Collector,
97    query: &dyn Query,
98    multi_collector: &mut MultiCollector,
99) -> SummaResult<Box<dyn FruitExtractor>> {
100    match collector_proto.collector {
101        Some(proto::collector::Collector::TopDocs(top_docs_collector_proto)) => {
102            let query_fields = validators::parse_fields(searcher.schema(), &top_docs_collector_proto.fields, &top_docs_collector_proto.excluded_fields)?;
103            let query_fields = (!query_fields.is_empty()).then(|| HashSet::from_iter(query_fields.into_iter().map(|x| x.0)));
104            Ok(match top_docs_collector_proto.scorer {
105                None | Some(proto::Scorer { scorer: None }) => Box::new(
106                    TopDocsBuilder::default()
107                        .handle(
108                            multi_collector.add_collector(
109                                tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
110                                    .and_offset(top_docs_collector_proto.offset as usize),
111                            ),
112                        )
113                        .index_alias(index_alias.to_string())
114                        .searcher(searcher)
115                        .query(query.box_clone())
116                        .limit(top_docs_collector_proto.limit)
117                        .offset(top_docs_collector_proto.offset)
118                        .snippet_configs(top_docs_collector_proto.snippet_configs)
119                        .multi_fields(index_holder.multi_fields().clone())
120                        .query_fields(query_fields)
121                        .build()?,
122                ) as Box<dyn FruitExtractor>,
123                Some(proto::Scorer {
124                    scorer: Some(proto::scorer::Scorer::EvalExpr(ref eval_expr)),
125                }) => {
126                    let eval_scorer_seed = EvalScorer::new(eval_expr, searcher.schema())?;
127                    let top_docs_collector = tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
128                        .and_offset(top_docs_collector_proto.offset as usize)
129                        .tweak_score(EvalScorerTweaker::new(eval_scorer_seed));
130                    Box::new(
131                        TopDocsBuilder::default()
132                            .handle(multi_collector.add_collector(top_docs_collector))
133                            .index_alias(index_alias.to_string())
134                            .searcher(searcher)
135                            .query(query.box_clone())
136                            .limit(top_docs_collector_proto.limit)
137                            .offset(top_docs_collector_proto.offset)
138                            .snippet_configs(top_docs_collector_proto.snippet_configs)
139                            .multi_fields(index_holder.multi_fields().clone())
140                            .query_fields(query_fields)
141                            .build()?,
142                    ) as Box<dyn FruitExtractor>
143                }
144                Some(proto::Scorer {
145                    scorer: Some(proto::scorer::Scorer::OrderBy(field_name)),
146                }) => {
147                    let top_docs_collector = tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
148                        .and_offset(top_docs_collector_proto.offset as usize)
149                        .order_by_fast_field(field_name, Order::Desc);
150                    Box::<TopDocs<u64>>::new(
151                        TopDocsBuilder::default()
152                            .handle(multi_collector.add_collector(top_docs_collector))
153                            .index_alias(index_alias.to_string())
154                            .searcher(searcher)
155                            .query(query.box_clone())
156                            .limit(top_docs_collector_proto.limit)
157                            .offset(top_docs_collector_proto.offset)
158                            .snippet_configs(top_docs_collector_proto.snippet_configs)
159                            .multi_fields(index_holder.multi_fields().clone())
160                            .query_fields(query_fields)
161                            .build()?,
162                    ) as Box<dyn FruitExtractor>
163                }
164            })
165        }
166        Some(proto::collector::Collector::ReservoirSampling(reservoir_sampling_collector_proto)) => {
167            let query_fields = validators::parse_fields(
168                searcher.schema(),
169                &reservoir_sampling_collector_proto.fields,
170                &reservoir_sampling_collector_proto.excluded_fields,
171            )?;
172            let query_fields = (!query_fields.is_empty()).then(|| HashSet::from_iter(query_fields.into_iter().map(|x| x.0)));
173            let reservoir_sampling_collector = collectors::ReservoirSampling::with_limit(reservoir_sampling_collector_proto.limit as usize);
174            Ok(Box::new(
175                ReservoirSamplingBuilder::default()
176                    .handle(multi_collector.add_collector(reservoir_sampling_collector))
177                    .index_alias(index_alias.to_string())
178                    .searcher(searcher)
179                    .multi_fields(index_holder.multi_fields().clone())
180                    .query_fields(query_fields)
181                    .limit(reservoir_sampling_collector_proto.limit)
182                    .build()?,
183            ) as Box<dyn FruitExtractor>)
184        }
185        Some(proto::collector::Collector::Count(_)) => Ok(Box::new(Count(multi_collector.add_collector(tantivy::collector::Count))) as Box<dyn FruitExtractor>),
186        Some(proto::collector::Collector::Facet(facet_collector_proto)) => {
187            let mut facet_collector = tantivy::collector::FacetCollector::for_field(facet_collector_proto.field);
188            for facet in &facet_collector_proto.facets {
189                facet_collector.add_facet(facet);
190            }
191            Ok(Box::new(Facet(multi_collector.add_collector(facet_collector))) as Box<dyn FruitExtractor>)
192        }
193        Some(proto::collector::Collector::Aggregation(aggregation_collector_proto)) => {
194            let agg_req: Aggregations = serde_json::from_str(&aggregation_collector_proto.aggregations)?;
195            let aggregation_collector =
196                tantivy::aggregation::AggregationCollector::from_aggs(agg_req, AggregationLimitsGuard::new(Some(16_000_000_000), Some(100_000_000)));
197            Ok(Box::new(Aggregation(multi_collector.add_collector(aggregation_collector))) as Box<dyn FruitExtractor>)
198        }
199        None => Ok(Box::new(Count(multi_collector.add_collector(tantivy::collector::Count))) as Box<dyn FruitExtractor>),
200    }
201}
202
203#[derive(Builder)]
204#[builder(pattern = "owned", build_fn(error = "BuilderError"))]
205pub struct TopDocs<T: 'static + Copy + Into<proto::Score> + Sync + Send> {
206    searcher: Searcher,
207    index_alias: String,
208    handle: FruitHandle<Vec<(T, tantivy::DocAddress)>>,
209    limit: u32,
210    offset: u32,
211    snippet_configs: HashMap<String, u32>,
212    query: Box<dyn Query>,
213    #[builder(default = "None")]
214    query_fields: Option<HashSet<Field>>,
215    multi_fields: HashSet<Field>,
216}
217
218#[async_trait]
219impl<T: 'static + Copy + Into<proto::Score> + Sync + Send> FruitExtractor for TopDocs<T> {
220    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
221        let fruit = self.handle.extract(multi_fruit);
222        let length = fruit.len();
223        let doc_addresses = fruit
224            .into_iter()
225            .take(std::cmp::min(self.limit as usize, length))
226            .map(|(score, doc_address)| ScoredDocAddress {
227                doc_address,
228                score: Some(score.into()),
229            })
230            .collect();
231        Ok(IntermediateExtractionResult::PreparedDocumentReferences(PreparedDocumentReferences {
232            index_alias: self.index_alias,
233            extraction_tooling: ExtractionTooling::new(self.searcher.clone(), self.query_fields, self.multi_fields),
234            snippet_generator_config: Some(SnippetGeneratorConfig::new(self.searcher, self.query, self.snippet_configs)),
235            scored_doc_addresses: doc_addresses,
236            has_next: length > self.limit as usize,
237            limit: self.limit,
238            offset: self.offset,
239        }))
240    }
241}
242
243#[derive(Builder)]
244#[builder(pattern = "owned", build_fn(error = "BuilderError"))]
245pub struct ReservoirSampling {
246    searcher: Searcher,
247    index_alias: String,
248    handle: FruitHandle<Vec<tantivy::DocAddress>>,
249    #[builder(default = "None")]
250    query_fields: Option<HashSet<Field>>,
251    multi_fields: HashSet<Field>,
252    limit: u32,
253}
254
255impl FruitExtractor for ReservoirSampling {
256    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
257        let mut rng = SmallRng::from_entropy();
258        Ok(IntermediateExtractionResult::PreparedDocumentReferences(PreparedDocumentReferences {
259            scored_doc_addresses: self
260                .handle
261                .extract(multi_fruit)
262                .into_iter()
263                .map(|doc_address| ScoredDocAddress {
264                    doc_address,
265                    score: Some(rng.gen::<f64>().into()),
266                })
267                .collect(),
268            index_alias: self.index_alias,
269            has_next: false,
270            limit: self.limit,
271            extraction_tooling: ExtractionTooling::new(self.searcher, self.query_fields, self.multi_fields),
272            snippet_generator_config: None,
273            offset: 0,
274        }))
275    }
276}
277
278pub struct Count(pub FruitHandle<usize>);
279
280impl FruitExtractor for Count {
281    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
282        Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Count(proto::CountCollectorOutput {
283            count: self.0.extract(multi_fruit) as u32,
284        })))
285    }
286}
287
288pub struct Facet(pub FruitHandle<FacetCounts>);
289
290impl FruitExtractor for Facet {
291    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
292        Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Facet(proto::FacetCollectorOutput {
293            facet_counts: self.0.extract(multi_fruit).get("").map(|(facet, count)| (facet.to_string(), count)).collect(),
294        })))
295    }
296}
297
298pub struct Aggregation(pub FruitHandle<AggregationResults>);
299
300impl FruitExtractor for Aggregation {
301    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
302        Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Aggregation(
303            proto::AggregationCollectorOutput {
304                aggregation_results: serde_json::to_string(&self.0.extract(multi_fruit).0)?,
305            },
306        )))
307    }
308}