use crate::collector::Collector;
use crate::collector::SegmentCollector;
use crate::docset::SkipResult;
use crate::fastfield::FacetReader;
use crate::schema::Facet;
use crate::schema::Field;
use crate::DocId;
use crate::Result;
use crate::Score;
use crate::SegmentLocalId;
use crate::SegmentReader;
use crate::TantivyError;
use std::cmp::Ordering;
use std::collections::btree_map;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::BinaryHeap;
use std::collections::Bound;
use std::iter::Peekable;
use std::{u64, usize};
/// A heap entry used by `FacetCounts::top_k`: a facet together with its count.
///
/// The ordering is deliberately *reversed* on `count`:
/// - the greatest element of a max-`BinaryHeap<Hit>` is the hit with the
///   lowest count, i.e. the next one to evict;
/// - an ascending sort yields hits by decreasing count.
///
/// Hits with the same count are ordered by facet, so the relative order of
/// tied facets is well-defined instead of depending on the heap's internal
/// layout.
struct Hit<'a> {
    count: u64,
    facet: &'a Facet,
}

impl<'a> Eq for Hit<'a> {}

impl<'a> PartialEq<Hit<'a>> for Hit<'a> {
    fn eq(&self, other: &Hit<'_>) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl<'a> PartialOrd<Hit<'a>> for Hit<'a> {
    fn partial_cmp(&self, other: &Hit<'_>) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> Ord for Hit<'a> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on count (higher count sorts first), then ascending on facet
        // to break ties deterministically.
        other
            .count
            .cmp(&self.count)
            .then_with(|| self.facet.cmp(other.facet))
    }
}
/// Returns the depth of a facet given its encoded bytes.
///
/// An empty slice has depth `0`. Otherwise the depth is the number of `0u8`
/// bytes in the slice, plus one.
fn facet_depth(facet_bytes: &[u8]) -> usize {
    if facet_bytes.is_empty() {
        return 0;
    }
    let separator_count = facet_bytes.iter().filter(|&&byte| byte == 0u8).count();
    separator_count + 1
}
/// Collector that counts documents per facet for a given facet field.
///
/// The facets of interest are registered with `add_facet`. The collected
/// result (`Fruit`) is a `FacetCounts`.
pub struct FacetCollector {
// Field whose facet reader is opened for every segment.
field: Field,
// Facets registered through `add_facet`. `add_facet` panics when a new facet
// is an ancestor or a descendant of a facet already in this set.
facets: BTreeSet<Facet>,
}
/// Per-segment state of a `FacetCollector`.
///
/// Facet ordinals read for a document are translated through
/// `collapse_mapping` into a "collapsed" bucket index, and `counts` is
/// incremented for that bucket.
pub struct FacetSegmentCollector {
// Source of the facet ordinals of each document and of the facet dictionary.
reader: FacetReader,
// Scratch buffer that `reader.facet_ords` fills for each collected document.
facet_ords_buf: Vec<u64>,
// Indexed by facet ordinal; yields the index of the bucket in `counts`.
// Bucket 0 receives every facet that was not collapsed under a requested facet.
collapse_mapping: Vec<usize>,
// One counter per bucket; sized to `collapse_facet_ords.len()`.
counts: Vec<u64>,
// Indexed by bucket; yields the dictionary ordinal of the facet that the
// bucket is reported as in `harvest`.
collapse_facet_ords: Vec<u64>,
}
/// Advances `collapse_it` until its head is no longer strictly smaller than
/// `target`, comparing the encoded bytes of each facet with `target`.
///
/// Returns:
/// - `SkipResult::Reached` if the head facet's encoded bytes equal `target`,
/// - `SkipResult::OverStep` if the head facet is greater than `target`,
/// - `SkipResult::End` if the iterator was exhausted.
///
/// The facet that stopped the scan is only peeked: it is left in the iterator.
fn skip<'a, I: Iterator<Item = &'a Facet>>(
    target: &[u8],
    collapse_it: &mut Peekable<I>,
) -> SkipResult {
    loop {
        let ordering = match collapse_it.peek() {
            Some(candidate) => candidate.encoded_str().as_bytes().cmp(target),
            None => return SkipResult::End,
        };
        match ordering {
            Ordering::Equal => return SkipResult::Reached,
            Ordering::Greater => return SkipResult::OverStep,
            Ordering::Less => {
                // Still behind the target: drop this facet and look at the next one.
                collapse_it.next();
            }
        }
    }
}
impl FacetCollector {
    /// Creates a facet collector for `field`, with no facet registered yet.
    pub fn for_field(field: Field) -> FacetCollector {
        let facets = BTreeSet::default();
        FacetCollector { field, facets }
    }

    /// Registers a facet on this collector.
    ///
    /// Anything convertible into a `Facet` is accepted.
    ///
    /// # Panics
    ///
    /// Panics if the new facet is a descendant or an ancestor of a facet that
    /// was already added.
    pub fn add_facet<T>(&mut self, facet_from: T)
    where
        Facet: From<T>,
    {
        let new_facet = Facet::from(facet_from);
        for registered in self.facets.iter() {
            if registered.is_prefix_of(&new_facet) {
                panic!("Tried to add a facet which is a descendant of an already added facet.");
            }
            if new_facet.is_prefix_of(registered) {
                panic!("Tried to add a facet which is an ancestor of an already added facet.");
            }
        }
        self.facets.insert(new_facet);
    }
}
/// `Collector` implementation: builds one `FacetSegmentCollector` per segment
/// and merges the per-segment `FacetCounts` by summing counts per facet.
impl Collector for FacetCollector {
type Fruit = FacetCounts;
type Child = FacetSegmentCollector;
// Builds the per-segment collector.
//
// Walks the segment's facet dictionary and the registered facets side by
// side to build:
// - `collapse_mapping`: one entry pushed per dictionary term visited, giving
//   the bucket that term is counted in;
// - `collapse_facet_ords`: for each bucket, the term ordinal it is reported as.
//
// Returns a `SchemaError` if the field has no facet reader.
//
// NOTE(review): `collect` indexes `collapse_mapping` by facet ordinal, so this
// presumably relies on the streamer visiting terms in ordinal order, and on
// the dictionary being sorted consistently with the `BTreeSet` of registered
// facets -- confirm against the term dictionary implementation.
fn for_segment(
&self,
_: SegmentLocalId,
reader: &SegmentReader,
) -> Result<FacetSegmentCollector> {
let field_name = reader.schema().get_field_name(self.field);
let facet_reader = reader.facet_reader(self.field).ok_or_else(|| {
TantivyError::SchemaError(format!("Field {:?} is not a facet field.", field_name))
})?;
let mut collapse_mapping = Vec::new();
let mut counts = Vec::new();
let mut collapse_facet_ords = Vec::new();
let mut collapse_facet_it = self.facets.iter().peekable();
// Bucket 0 is reserved: every term that is not collapsed under a registered
// facet is mapped to it, and it is associated with term ordinal 0.
collapse_facet_ords.push(0);
{
let mut facet_streamer = facet_reader.facet_dict().range().into_stream();
if facet_streamer.advance() {
'outer: loop {
// Move the registered-facet cursor up to the current dictionary term.
let skip_result = skip(facet_streamer.key(), &mut collapse_facet_it);
match skip_result {
SkipResult::Reached => {
// The current term is a registered facet. The term itself goes to
// bucket 0; the terms that follow at a greater depth are collapsed
// onto its direct children.
let collapse_depth = facet_depth(facet_streamer.key());
let mut collapsed_id = 0;
collapse_mapping.push(0);
while facet_streamer.advance() {
let depth = facet_depth(facet_streamer.key());
if depth <= collapse_depth {
// Not deeper than the registered facet: its run of deeper terms is
// over. Nothing was pushed for this term yet; the outer loop
// handles it.
continue 'outer;
}
if depth == collapse_depth + 1 {
// Direct child of the registered facet: open a new bucket for it.
collapsed_id = collapse_facet_ords.len();
collapse_facet_ords.push(facet_streamer.term_ord());
collapse_mapping.push(collapsed_id);
} else {
// Deeper term: counted in the bucket of the most recent direct child.
collapse_mapping.push(collapsed_id);
}
}
// The dictionary stream is exhausted.
break;
}
SkipResult::End | SkipResult::OverStep => {
// The current term is not a registered facet: map it to bucket 0.
collapse_mapping.push(0);
if !facet_streamer.advance() {
break;
}
}
}
}
}
}
// One zero-initialized counter per bucket.
counts.resize(collapse_facet_ords.len(), 0);
Ok(FacetSegmentCollector {
reader: facet_reader,
facet_ords_buf: Vec::with_capacity(255),
collapse_mapping,
counts,
collapse_facet_ords,
})
}
// Facet counting ignores the score passed to `collect`.
fn requires_scoring(&self) -> bool {
false
}
// Merges the per-segment results by summing, for each facet, the counts
// reported by every segment.
fn merge_fruits(&self, segments_facet_counts: Vec<FacetCounts>) -> Result<FacetCounts> {
let mut facet_counts: BTreeMap<Facet, u64> = BTreeMap::new();
for segment_facet_counts in segments_facet_counts {
for (facet, count) in segment_facet_counts.facet_counts {
*(facet_counts.entry(facet).or_insert(0)) += count;
}
}
Ok(FacetCounts { facet_counts })
}
}
impl SegmentCollector for FacetSegmentCollector {
    type Fruit = FacetCounts;

    /// Counts `doc` in the bucket of each of its facets.
    ///
    /// The facet ordinals of the document are read into the scratch buffer and
    /// mapped to buckets through `collapse_mapping`. When two consecutive
    /// ordinals map to the same bucket, the document is counted only once for
    /// that bucket; repeats that are not adjacent in the buffer are not
    /// de-duplicated here.
    fn collect(&mut self, doc: DocId, _: Score) {
        self.reader.facet_ords(doc, &mut self.facet_ords_buf);
        let mut last_bucket: usize = usize::MAX;
        for &facet_ord in &self.facet_ords_buf {
            let bucket = self.collapse_mapping[facet_ord as usize];
            let increment = u64::from(bucket != last_bucket);
            self.counts[bucket] += increment;
            last_bucket = bucket;
        }
    }

    /// Converts the bucket counters into a `FacetCounts`.
    ///
    /// Buckets with a count of zero are skipped. For the others, the bucket's
    /// term ordinal is resolved to its encoded facet through the facet
    /// dictionary.
    ///
    /// # Panics
    ///
    /// Panics if the bytes read from the dictionary are rejected by
    /// `Facet::from_encoded`.
    fn harvest(self) -> FacetCounts {
        let facet_dict = self.reader.facet_dict();
        let mut facet_counts = BTreeMap::new();
        let non_empty_buckets = self
            .counts
            .iter()
            .enumerate()
            .filter(|&(_, &count)| count != 0);
        for (bucket, &count) in non_empty_buckets {
            let facet_ord = self.collapse_facet_ords[bucket];
            let mut encoded_facet = Vec::new();
            facet_dict.ord_to_term(facet_ord, &mut encoded_facet);
            let facet = Facet::from_encoded(encoded_facet).unwrap();
            facet_counts.insert(facet, count);
        }
        FacetCounts { facet_counts }
    }
}
/// Result of a facet collection: a count for each collected facet.
///
/// Query it with `get` (children of a facet, in facet order) or `top_k`
/// (children with the highest counts).
pub struct FacetCounts {
// Count per facet, kept sorted by facet so that `get` can use a range query.
facet_counts: BTreeMap<Facet, u64>,
}
/// Iterator over `(facet, count)` pairs, as returned by `FacetCounts::get`.
pub struct FacetChildIterator<'a> {
// Range over the underlying `BTreeMap` of facet counts.
underlying: btree_map::Range<'a, Facet, u64>,
}
impl<'a> Iterator for FacetChildIterator<'a> {
    type Item = (&'a Facet, u64);

    /// Yields the next `(facet, count)` pair of the underlying range, copying
    /// the count out of the map.
    fn next(&mut self) -> Option<Self::Item> {
        let (facet, count) = self.underlying.next()?;
        Some((facet, *count))
    }
}
impl FacetCounts {
    /// Returns an iterator over the collected facets located strictly below
    /// the given facet, in facet order.
    ///
    /// For the root facet, every collected facet other than the root itself is
    /// returned. Otherwise the range stops before the facet whose encoding is
    /// the given facet's encoding followed by `'\u{1}'`.
    pub fn get<T>(&self, facet_from: T) -> FacetChildIterator<'_>
    where
        Facet: From<T>,
    {
        let parent = Facet::from(facet_from);
        let upper_bound = if parent.is_root() {
            Bound::Unbounded
        } else {
            let mut encoded_after: String = parent.encoded_str().to_owned();
            encoded_after.push('\u{1}');
            Bound::Excluded(Facet::from_encoded_string(encoded_after))
        };
        // The parent itself is excluded from the range.
        let lower_bound = Bound::Excluded(parent);
        FacetChildIterator {
            underlying: self.facet_counts.range((lower_bound, upper_bound)),
        }
    }

    /// Returns up to `k` facets below `facet` with the highest counts, sorted
    /// by decreasing count.
    ///
    /// A heap of at most `k` hits is maintained: once it is full, a new facet
    /// only replaces the current lowest hit if its count is strictly greater.
    pub fn top_k<T>(&self, facet: T, k: usize) -> Vec<(&Facet, u64)>
    where
        Facet: From<T>,
    {
        let mut children = self.get(facet);
        let mut heap = BinaryHeap::with_capacity(k);
        // Seed the heap with the first `k` children.
        for (facet, count) in children.by_ref().take(k) {
            heap.push(Hit { count, facet });
        }
        let mut lowest_count: u64 = heap.peek().map_or(u64::MIN, |hit| hit.count);
        for (facet, count) in children {
            if count <= lowest_count {
                continue;
            }
            // `Hit` ordering is reversed, so the heap's head is the lowest count.
            if let Some(mut lowest_hit) = heap.peek_mut() {
                *lowest_hit = Hit { count, facet };
            }
            if let Some(lowest_hit) = heap.peek() {
                lowest_count = lowest_hit.count;
            }
        }
        let mut top_facets = Vec::with_capacity(heap.len());
        for Hit { count, facet } in heap.into_sorted_vec() {
            top_facets.push((facet, count));
        }
        top_facets
    }
}
// Unit tests for the facet collector.
#[cfg(test)]
mod tests {
use super::{FacetCollector, FacetCounts};
use crate::core::Index;
use crate::query::AllQuery;
use crate::schema::{Document, Facet, Field, Schema};
use rand::distributions::Uniform;
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use std::iter;
// Indexes 60 distinct facets of the form `/top{0..3}/mid{0..4}/leaf{0..5}`,
// each carried by 10 documents, then drills down into `/top1`.
// Each `/top1/midX` covers 5 leaves x 10 documents = 50 documents.
#[test]
fn test_facet_collector_drilldown() {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet");
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let num_facets: usize = 3 * 4 * 5;
// Decompose `n` in mixed radix (3, 4, 5) to enumerate every combination once.
let facets: Vec<Facet> = (0..num_facets)
.map(|mut n| {
let top = n % 3;
n /= 3;
let mid = n % 4;
n /= 4;
let leaf = n % 5;
Facet::from(&format!("/top{}/mid{}/leaf{}", top, mid, leaf))
})
.collect();
for i in 0..num_facets * 10 {
let mut doc = Document::new();
doc.add_facet(facet_field, facets[i % num_facets].clone());
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let mut facet_collector = FacetCollector::for_field(facet_field);
facet_collector.add_facet(Facet::from("/top1"));
let counts = searcher.search(&AllQuery, &facet_collector).unwrap();
{
let facets: Vec<(String, u64)> = counts
.get("/top1")
.map(|(facet, count)| (facet.to_string(), count))
.collect();
assert_eq!(
facets,
[
("/top1/mid0", 50),
("/top1/mid1", 50),
("/top1/mid2", 50),
("/top1/mid3", 50),
]
.iter()
.map(|&(facet_str, count)| (String::from(facet_str), count))
.collect::<Vec<_>>()
);
}
}
// Adding a descendant of an already registered facet must panic.
#[test]
#[should_panic(expected = "Tried to add a facet which is a descendant of \
an already added facet.")]
fn test_misused_facet_collector() {
let mut facet_collector = FacetCollector::for_field(Field(0));
facet_collector.add_facet(Facet::from("/country"));
facet_collector.add_facet(Facet::from("/country/europe"));
}
// A single document carries four facets, alternating between `/subjects/A`
// and `/subjects/B` in insertion order. The first child returned under
// `/subjects` must still count that document only once.
#[test]
fn test_doc_unsorted_multifacet() {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facets");
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(
facet_field => Facet::from_text(&"/subjects/A/a"),
facet_field => Facet::from_text(&"/subjects/B/a"),
facet_field => Facet::from_text(&"/subjects/A/b"),
facet_field => Facet::from_text(&"/subjects/B/b"),
));
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.num_docs(), 1);
let mut facet_collector = FacetCollector::for_field(facet_field);
facet_collector.add_facet("/subjects");
let counts = searcher.search(&AllQuery, &facet_collector).unwrap();
let facets: Vec<(&Facet, u64)> = counts.get("/subjects").collect();
assert_eq!(facets[0].1, 1);
}
// `/countryeurope` shares a string prefix with `/country` but is not one of
// its descendants, so registering both must not panic.
#[test]
fn test_non_used_facet_collector() {
let mut facet_collector = FacetCollector::for_field(Field(0));
facet_collector.add_facet(Facet::from("/country"));
facet_collector.add_facet(Facet::from("/countryeurope"));
}
// Indexes 150 documents spread over five facets with known counts
// (a: 10, b: 100, c: 7, d: 12, e: 21). Every document also receives a second
// facet `/facet/<n>` with `n` drawn uniformly from 1..=100_000.
// The expected top 3 assumes that no random facet reaches 12 documents.
#[test]
fn test_facet_collector_topk() {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet");
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let uniform = Uniform::new_inclusive(1, 100_000);
let mut docs: Vec<Document> = vec![("a", 10), ("b", 100), ("c", 7), ("d", 12), ("e", 21)]
.into_iter()
.flat_map(|(c, count)| {
let facet = Facet::from(&format!("/facet/{}", c));
let doc = doc!(facet_field => facet);
iter::repeat(doc).take(count)
})
.map(|mut doc| {
doc.add_facet(
facet_field,
&format!("/facet/{}", thread_rng().sample(&uniform)),
);
doc
})
.collect();
// Shuffle so that the result does not depend on insertion order.
docs[..].shuffle(&mut thread_rng());
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
for doc in docs {
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let mut facet_collector = FacetCollector::for_field(facet_field);
facet_collector.add_facet("/facet");
let counts: FacetCounts = searcher.search(&AllQuery, &facet_collector).unwrap();
{
let facets: Vec<(&Facet, u64)> = counts.top_k("/facet", 3);
assert_eq!(
facets,
vec![
(&Facet::from("/facet/b"), 100),
(&Facet::from("/facet/e"), 21),
(&Facet::from("/facet/d"), 12),
]
);
}
}
}
// Benchmarks for the facet collector; only built with the `unstable` feature.
#[cfg(all(test, feature = "unstable"))]
mod bench {
    // `crate::`-prefixed paths, consistent with the rest of this file.
    use crate::collector::FacetCollector;
    use crate::query::AllQuery;
    use crate::schema::Facet;
    use crate::schema::Schema;
    use crate::Index;
    use rand::prelude::SliceRandom;
    use rand::thread_rng;
    use test::Bencher;

    /// Measures a full facet collection over an index holding 50 facets
    /// `/facet_{val}`, where facet `val` is carried by `val * val` documents
    /// inserted in random order.
    #[bench]
    fn bench_facet_collector(b: &mut Bencher) {
        let mut schema_builder = Schema::builder();
        let facet_field = schema_builder.add_facet_field("facet");
        let schema = schema_builder.build();
        let index = Index::create_in_ram(schema);

        let mut docs = vec![];
        for val in 0..50 {
            let facet = Facet::from(&format!("/facet_{}", val));
            for _ in 0..val * val {
                docs.push(doc!(facet_field=>facet.clone()));
            }
        }
        // Same shuffling API as the unit tests of this file.
        docs[..].shuffle(&mut thread_rng());

        let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
        for doc in docs {
            index_writer.add_document(doc);
        }
        index_writer.commit().unwrap();
        let reader = index.reader().unwrap();
        b.iter(|| {
            // Searchers are obtained from the reader, as in the unit tests.
            let searcher = reader.searcher();
            let facet_collector = FacetCollector::for_field(facet_field);
            searcher.search(&AllQuery, &facet_collector).unwrap();
        });
    }
}