use crate::docset::DocSet;
use crate::query::score_combiner::ScoreCombiner;
use crate::query::Scorer;
use crate::DocId;
use crate::Score;
use std::marker::PhantomData;
pub struct RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner: ScoreCombiner> {
req_scorer: TReqScorer,
opt_scorer: TOptScorer,
score_cache: Option<Score>,
_phantom: PhantomData<TScoreCombiner>,
}
impl<TReqScorer, TOptScorer, TScoreCombiner>
RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
where
TOptScorer: DocSet,
TScoreCombiner: ScoreCombiner,
{
pub fn new(
req_scorer: TReqScorer,
opt_scorer: TOptScorer,
) -> RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
RequiredOptionalScorer {
req_scorer,
opt_scorer,
score_cache: None,
_phantom: PhantomData,
}
}
}
impl<TReqScorer, TOptScorer, TScoreCombiner> DocSet
for RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
where
TReqScorer: DocSet,
TOptScorer: DocSet,
TScoreCombiner: ScoreCombiner,
{
fn advance(&mut self) -> DocId {
self.score_cache = None;
self.req_scorer.advance()
}
fn doc(&self) -> DocId {
self.req_scorer.doc()
}
fn size_hint(&self) -> u32 {
self.req_scorer.size_hint()
}
}
impl<TReqScorer, TOptScorer, TScoreCombiner> Scorer
for RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner>
where
TReqScorer: Scorer,
TOptScorer: Scorer,
TScoreCombiner: ScoreCombiner,
{
fn score(&mut self) -> Score {
if let Some(score) = self.score_cache {
return score;
}
let doc = self.doc();
let mut score_combiner = TScoreCombiner::default();
score_combiner.update(&mut self.req_scorer);
if self.opt_scorer.doc() <= doc && self.opt_scorer.seek(doc) == doc {
score_combiner.update(&mut self.opt_scorer);
}
let score = score_combiner.score();
self.score_cache = Some(score);
score
}
}
#[cfg(test)]
mod tests {
use super::RequiredOptionalScorer;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::{DoNothingCombiner, SumCombiner};
use crate::query::ConstScorer;
use crate::query::Scorer;
use crate::query::VecDocSet;
use crate::tests::sample_with_seed;
#[test]
fn test_reqopt_scorer_empty() {
let req = vec![1, 3, 7];
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new(
ConstScorer::from(VecDocSet::from(req.clone())),
ConstScorer::from(VecDocSet::from(vec![])),
);
let mut docs = vec![];
while reqoptscorer.doc() != TERMINATED {
docs.push(reqoptscorer.doc());
reqoptscorer.advance();
}
assert_eq!(docs, req);
}
#[test]
fn test_reqopt_scorer() {
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new(
ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15]), 1.0),
ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0),
);
{
assert_eq!(reqoptscorer.doc(), 1);
assert_eq!(reqoptscorer.score(), 2.0);
}
{
assert_eq!(reqoptscorer.advance(), 3);
assert_eq!(reqoptscorer.doc(), 3);
assert_eq!(reqoptscorer.score(), 1.0);
}
{
assert_eq!(reqoptscorer.advance(), 7);
assert_eq!(reqoptscorer.doc(), 7);
assert_eq!(reqoptscorer.score(), 2.0);
}
{
assert_eq!(reqoptscorer.advance(), 8);
assert_eq!(reqoptscorer.doc(), 8);
assert_eq!(reqoptscorer.score(), 1.0);
}
{
assert_eq!(reqoptscorer.advance(), 9);
assert_eq!(reqoptscorer.doc(), 9);
assert_eq!(reqoptscorer.score(), 1.0);
}
{
assert_eq!(reqoptscorer.advance(), 10);
assert_eq!(reqoptscorer.doc(), 10);
assert_eq!(reqoptscorer.score(), 1.0);
}
{
assert_eq!(reqoptscorer.advance(), 13);
assert_eq!(reqoptscorer.doc(), 13);
assert_eq!(reqoptscorer.score(), 1.0);
}
{
assert_eq!(reqoptscorer.advance(), 15);
assert_eq!(reqoptscorer.doc(), 15);
assert_eq!(reqoptscorer.score(), 2.0);
}
assert_eq!(reqoptscorer.advance(), TERMINATED);
}
#[test]
fn test_reqopt_scorer_skip() {
let req_docs = sample_with_seed(10_000, 0.02, 1);
let opt_docs = sample_with_seed(10_000, 0.02, 2);
let skip_docs = sample_with_seed(10_000, 0.001, 3);
test_skip_against_unoptimized(
|| {
Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new(
ConstScorer::from(VecDocSet::from(req_docs.clone())),
ConstScorer::from(VecDocSet::from(opt_docs.clone())),
))
},
skip_docs,
);
}
}