1use std::collections::{HashMap, HashSet};
2
3use rand::rngs::SmallRng;
4use rand::{Rng, SeedableRng};
5use summa_proto::proto;
6use tantivy::aggregation::agg_req::Aggregations;
7use tantivy::aggregation::agg_result::AggregationResults;
8use tantivy::aggregation::AggregationLimitsGuard;
9use tantivy::collector::{FacetCounts, FruitHandle, MultiCollector, MultiFruit};
10use tantivy::query::Query;
11use tantivy::schema::Field;
12use tantivy::{Order, Searcher};
13
14use crate::components::snippet_generator::SnippetGeneratorConfig;
15use crate::components::IndexHolder;
16use crate::errors::{BuilderError, SummaResult};
17use crate::scorers::eval_scorer_tweaker::EvalScorerTweaker;
18use crate::scorers::EvalScorer;
19use crate::{collectors, validators};
20
/// Address of a matched document together with an optional score.
///
/// Top-docs extraction fills `score` from the collector's score value; reservoir sampling
/// fills it with a random `f64`.
#[derive(Clone)]
pub struct ScoredDocAddress {
    /// Location of the document inside the searched index.
    pub doc_address: tantivy::DocAddress,
    /// Score attached to the hit, if any.
    pub score: Option<proto::Score>,
}
26
27#[derive(Clone)]
28pub struct ExtractionTooling {
29 pub searcher: Searcher,
30 pub query_fields: Option<HashSet<Field>>,
31 pub multi_fields: HashSet<Field>,
32}
33
34impl ExtractionTooling {
35 pub fn new(searcher: Searcher, query_fields: Option<HashSet<Field>>, multi_fields: HashSet<Field>) -> ExtractionTooling {
36 ExtractionTooling {
37 searcher,
38 query_fields,
39 multi_fields,
40 }
41 }
42}
43
/// Document addresses produced by a collector, plus everything needed to fetch and
/// render the documents later.
#[derive(Clone)]
pub struct PreparedDocumentReferences {
    /// Alias of the index the addresses belong to.
    pub index_alias: String,
    /// Searcher and field selections used to load the referenced documents.
    pub extraction_tooling: ExtractionTooling,
    /// Snippet settings; `Some` for top-docs results, `None` for reservoir sampling.
    pub snippet_generator_config: Option<SnippetGeneratorConfig>,
    /// The (at most `limit`) hits to return for this page.
    pub scored_doc_addresses: Vec<ScoredDocAddress>,
    /// Whether more hits exist beyond this page (always `false` for reservoir sampling).
    pub has_next: bool,
    /// Page size requested by the caller.
    pub limit: u32,
    /// Offset requested by the caller (`0` for reservoir sampling).
    pub offset: u32,
}
54
/// Collector outputs that are already final once extracted from the `MultiFruit`.
#[derive(Clone)]
pub enum ReadyCollectorOutput {
    /// Aggregation results serialized as a JSON string.
    Aggregation(proto::AggregationCollectorOutput),
    /// Number of matching documents.
    Count(proto::CountCollectorOutput),
    /// Facet → count pairs.
    Facet(proto::FacetCollectorOutput),
}
61
62#[derive(Clone)]
63pub enum IntermediateExtractionResult {
64 Ready(ReadyCollectorOutput),
65 PreparedDocumentReferences(PreparedDocumentReferences),
66}
67
68impl IntermediateExtractionResult {
69 pub fn as_document_references(self) -> Option<PreparedDocumentReferences> {
70 if let IntermediateExtractionResult::PreparedDocumentReferences(document_references) = self {
71 Some(document_references)
72 } else {
73 None
74 }
75 }
76
77 pub fn as_count(&self) -> Option<&proto::CountCollectorOutput> {
78 if let IntermediateExtractionResult::Ready(ReadyCollectorOutput::Count(count)) = &self {
79 Some(count)
80 } else {
81 None
82 }
83 }
84}
85
/// Pulls the fruit of one collector out of a tantivy `MultiFruit` and converts it into an
/// [`IntermediateExtractionResult`].
///
/// Implementors hold the `FruitHandle` obtained when their collector was added to the
/// `MultiCollector` (see `build_fruit_extractor`).
#[async_trait]
pub trait FruitExtractor: Sync + Send {
    /// Consumes the extractor and takes its collector's result out of `multi_fruit`.
    fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult>;
}
91
92pub fn build_fruit_extractor(
93 index_holder: &IndexHolder,
94 index_alias: &str,
95 searcher: Searcher,
96 collector_proto: proto::Collector,
97 query: &dyn Query,
98 multi_collector: &mut MultiCollector,
99) -> SummaResult<Box<dyn FruitExtractor>> {
100 match collector_proto.collector {
101 Some(proto::collector::Collector::TopDocs(top_docs_collector_proto)) => {
102 let query_fields = validators::parse_fields(searcher.schema(), &top_docs_collector_proto.fields, &top_docs_collector_proto.excluded_fields)?;
103 let query_fields = (!query_fields.is_empty()).then(|| HashSet::from_iter(query_fields.into_iter().map(|x| x.0)));
104 Ok(match top_docs_collector_proto.scorer {
105 None | Some(proto::Scorer { scorer: None }) => Box::new(
106 TopDocsBuilder::default()
107 .handle(
108 multi_collector.add_collector(
109 tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
110 .and_offset(top_docs_collector_proto.offset as usize),
111 ),
112 )
113 .index_alias(index_alias.to_string())
114 .searcher(searcher)
115 .query(query.box_clone())
116 .limit(top_docs_collector_proto.limit)
117 .offset(top_docs_collector_proto.offset)
118 .snippet_configs(top_docs_collector_proto.snippet_configs)
119 .multi_fields(index_holder.multi_fields().clone())
120 .query_fields(query_fields)
121 .build()?,
122 ) as Box<dyn FruitExtractor>,
123 Some(proto::Scorer {
124 scorer: Some(proto::scorer::Scorer::EvalExpr(ref eval_expr)),
125 }) => {
126 let eval_scorer_seed = EvalScorer::new(eval_expr, searcher.schema())?;
127 let top_docs_collector = tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
128 .and_offset(top_docs_collector_proto.offset as usize)
129 .tweak_score(EvalScorerTweaker::new(eval_scorer_seed));
130 Box::new(
131 TopDocsBuilder::default()
132 .handle(multi_collector.add_collector(top_docs_collector))
133 .index_alias(index_alias.to_string())
134 .searcher(searcher)
135 .query(query.box_clone())
136 .limit(top_docs_collector_proto.limit)
137 .offset(top_docs_collector_proto.offset)
138 .snippet_configs(top_docs_collector_proto.snippet_configs)
139 .multi_fields(index_holder.multi_fields().clone())
140 .query_fields(query_fields)
141 .build()?,
142 ) as Box<dyn FruitExtractor>
143 }
144 Some(proto::Scorer {
145 scorer: Some(proto::scorer::Scorer::OrderBy(field_name)),
146 }) => {
147 let top_docs_collector = tantivy::collector::TopDocs::with_limit((top_docs_collector_proto.limit + 1) as usize)
148 .and_offset(top_docs_collector_proto.offset as usize)
149 .order_by_fast_field(field_name, Order::Desc);
150 Box::<TopDocs<u64>>::new(
151 TopDocsBuilder::default()
152 .handle(multi_collector.add_collector(top_docs_collector))
153 .index_alias(index_alias.to_string())
154 .searcher(searcher)
155 .query(query.box_clone())
156 .limit(top_docs_collector_proto.limit)
157 .offset(top_docs_collector_proto.offset)
158 .snippet_configs(top_docs_collector_proto.snippet_configs)
159 .multi_fields(index_holder.multi_fields().clone())
160 .query_fields(query_fields)
161 .build()?,
162 ) as Box<dyn FruitExtractor>
163 }
164 })
165 }
166 Some(proto::collector::Collector::ReservoirSampling(reservoir_sampling_collector_proto)) => {
167 let query_fields = validators::parse_fields(
168 searcher.schema(),
169 &reservoir_sampling_collector_proto.fields,
170 &reservoir_sampling_collector_proto.excluded_fields,
171 )?;
172 let query_fields = (!query_fields.is_empty()).then(|| HashSet::from_iter(query_fields.into_iter().map(|x| x.0)));
173 let reservoir_sampling_collector = collectors::ReservoirSampling::with_limit(reservoir_sampling_collector_proto.limit as usize);
174 Ok(Box::new(
175 ReservoirSamplingBuilder::default()
176 .handle(multi_collector.add_collector(reservoir_sampling_collector))
177 .index_alias(index_alias.to_string())
178 .searcher(searcher)
179 .multi_fields(index_holder.multi_fields().clone())
180 .query_fields(query_fields)
181 .limit(reservoir_sampling_collector_proto.limit)
182 .build()?,
183 ) as Box<dyn FruitExtractor>)
184 }
185 Some(proto::collector::Collector::Count(_)) => Ok(Box::new(Count(multi_collector.add_collector(tantivy::collector::Count))) as Box<dyn FruitExtractor>),
186 Some(proto::collector::Collector::Facet(facet_collector_proto)) => {
187 let mut facet_collector = tantivy::collector::FacetCollector::for_field(facet_collector_proto.field);
188 for facet in &facet_collector_proto.facets {
189 facet_collector.add_facet(facet);
190 }
191 Ok(Box::new(Facet(multi_collector.add_collector(facet_collector))) as Box<dyn FruitExtractor>)
192 }
193 Some(proto::collector::Collector::Aggregation(aggregation_collector_proto)) => {
194 let agg_req: Aggregations = serde_json::from_str(&aggregation_collector_proto.aggregations)?;
195 let aggregation_collector =
196 tantivy::aggregation::AggregationCollector::from_aggs(agg_req, AggregationLimitsGuard::new(Some(16_000_000_000), Some(100_000_000)));
197 Ok(Box::new(Aggregation(multi_collector.add_collector(aggregation_collector))) as Box<dyn FruitExtractor>)
198 }
199 None => Ok(Box::new(Count(multi_collector.add_collector(tantivy::collector::Count))) as Box<dyn FruitExtractor>),
200 }
201}
202
/// Fruit extractor for a top-docs collector whose hits carry a score of type `T`.
///
/// `T` is the score type of the registered collector: relevance scores, scores produced by
/// `EvalScorerTweaker`, or `u64` fast-field values when ordering by a field (see
/// `build_fruit_extractor`). Instances are built through the derived `TopDocsBuilder`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BuilderError"))]
pub struct TopDocs<T: 'static + Copy + Into<proto::Score> + Sync + Send> {
    /// Searcher handed to `ExtractionTooling` and `SnippetGeneratorConfig` on extraction.
    searcher: Searcher,
    /// Alias of the index the hits come from, copied into the output.
    index_alias: String,
    /// Handle to the collector's `(score, address)` hits inside the `MultiFruit`.
    handle: FruitHandle<Vec<(T, tantivy::DocAddress)>>,
    /// Page size requested by the caller; the collector itself is asked for one more hit.
    limit: u32,
    /// Offset requested by the caller, also applied to the collector via `and_offset`.
    offset: u32,
    /// Snippet settings forwarded to `SnippetGeneratorConfig::new`
    /// (presumably field name → snippet length — TODO confirm).
    snippet_configs: HashMap<String, u32>,
    /// Query forwarded to `SnippetGeneratorConfig::new`.
    query: Box<dyn Query>,
    /// Fields selected by the request; `None` when field parsing produced an empty set.
    #[builder(default = "None")]
    query_fields: Option<HashSet<Field>>,
    /// Field set taken from `IndexHolder::multi_fields`.
    multi_fields: HashSet<Field>,
}
217
218#[async_trait]
219impl<T: 'static + Copy + Into<proto::Score> + Sync + Send> FruitExtractor for TopDocs<T> {
220 fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
221 let fruit = self.handle.extract(multi_fruit);
222 let length = fruit.len();
223 let doc_addresses = fruit
224 .into_iter()
225 .take(std::cmp::min(self.limit as usize, length))
226 .map(|(score, doc_address)| ScoredDocAddress {
227 doc_address,
228 score: Some(score.into()),
229 })
230 .collect();
231 Ok(IntermediateExtractionResult::PreparedDocumentReferences(PreparedDocumentReferences {
232 index_alias: self.index_alias,
233 extraction_tooling: ExtractionTooling::new(self.searcher.clone(), self.query_fields, self.multi_fields),
234 snippet_generator_config: Some(SnippetGeneratorConfig::new(self.searcher, self.query, self.snippet_configs)),
235 scored_doc_addresses: doc_addresses,
236 has_next: length > self.limit as usize,
237 limit: self.limit,
238 offset: self.offset,
239 }))
240 }
241}
242
/// Fruit extractor for the crate's reservoir-sampling collector.
///
/// Instances are built through the derived `ReservoirSamplingBuilder`.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(error = "BuilderError"))]
pub struct ReservoirSampling {
    /// Searcher handed to `ExtractionTooling` on extraction.
    searcher: Searcher,
    /// Alias of the index the sampled documents come from.
    index_alias: String,
    /// Handle to the sampled document addresses inside the `MultiFruit`.
    handle: FruitHandle<Vec<tantivy::DocAddress>>,
    /// Fields selected by the request; `None` when field parsing produced an empty set.
    #[builder(default = "None")]
    query_fields: Option<HashSet<Field>>,
    /// Field set taken from `IndexHolder::multi_fields`.
    multi_fields: HashSet<Field>,
    /// Sample size passed to `collectors::ReservoirSampling::with_limit`; copied into the output.
    limit: u32,
}
254
255impl FruitExtractor for ReservoirSampling {
256 fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
257 let mut rng = SmallRng::from_entropy();
258 Ok(IntermediateExtractionResult::PreparedDocumentReferences(PreparedDocumentReferences {
259 scored_doc_addresses: self
260 .handle
261 .extract(multi_fruit)
262 .into_iter()
263 .map(|doc_address| ScoredDocAddress {
264 doc_address,
265 score: Some(rng.gen::<f64>().into()),
266 })
267 .collect(),
268 index_alias: self.index_alias,
269 has_next: false,
270 limit: self.limit,
271 extraction_tooling: ExtractionTooling::new(self.searcher, self.query_fields, self.multi_fields),
272 snippet_generator_config: None,
273 offset: 0,
274 }))
275 }
276}
277
278pub struct Count(pub FruitHandle<usize>);
279
280impl FruitExtractor for Count {
281 fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
282 Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Count(proto::CountCollectorOutput {
283 count: self.0.extract(multi_fruit) as u32,
284 })))
285 }
286}
287
288pub struct Facet(pub FruitHandle<FacetCounts>);
289
290impl FruitExtractor for Facet {
291 fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
292 Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Facet(proto::FacetCollectorOutput {
293 facet_counts: self.0.extract(multi_fruit).get("").map(|(facet, count)| (facet.to_string(), count)).collect(),
294 })))
295 }
296}
297
298pub struct Aggregation(pub FruitHandle<AggregationResults>);
299
300impl FruitExtractor for Aggregation {
301 fn extract(self: Box<Self>, multi_fruit: &mut MultiFruit) -> SummaResult<IntermediateExtractionResult> {
302 Ok(IntermediateExtractionResult::Ready(ReadyCollectorOutput::Aggregation(
303 proto::AggregationCollectorOutput {
304 aggregation_results: serde_json::to_string(&self.0.extract(multi_fruit).0)?,
305 },
306 )))
307 }
308}