1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector};
use crate::Result;
use crate::{DocAddress, DocId, Score, SegmentReader};

/// Collector that pairs a user-supplied custom scorer with a `TopCollector`.
///
/// `TScore` is the type of the custom score; it defaults to `Score`.
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
    // Used in `for_segment` to build one segment-level scorer per segment.
    custom_scorer: TCustomScorer,
    // Created from the requested limit; builds the per-segment collectors
    // and merges their results.
    collector: TopCollector<TScore>,
}

impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where
    TScore: Clone + PartialOrd,
{
    /// Creates a collector that scores documents with `custom_scorer`.
    ///
    /// `limit` is forwarded as-is to `TopCollector::with_limit`.
    pub fn new(custom_scorer: TCustomScorer, limit: usize) -> Self {
        let collector = TopCollector::with_limit(limit);
        Self {
            custom_scorer,
            collector,
        }
    }
}

/// A custom segment scorer makes it possible to define any kind of score
/// for a given document belonging to a specific segment.
///
/// It is the segment-local counterpart of the
/// [`CustomScorer`](./trait.CustomScorer.html).
///
/// Any `'static + Send + Sync` closure of type `Fn(DocId) -> TScore`
/// implements this trait.
pub trait CustomSegmentScorer<TScore>: 'static {
    /// Computes the score of a specific `doc`.
    fn score(&self, doc: DocId) -> TScore;
}

/// `CustomScorer` makes it possible to define any kind of score.
///
/// The `CustomScorer` itself does not do much of the computation.
/// Instead, it helps construct `Self::Child` instances that will compute
/// the score at the segment scale.
pub trait CustomScorer<TScore>: Sync {
    /// Type of the associated [`CustomSegmentScorer`](./trait.CustomSegmentScorer.html).
    type Child: CustomSegmentScorer<TScore>;
    /// Builds a child scorer for a specific segment. The child scorer is associated with
    /// that segment only.
    fn segment_scorer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
}

impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
where
    TCustomScorer: CustomScorer<TScore>,
    TScore: 'static + PartialOrd + Clone + Send + Sync,
{
    /// Pairs of custom score and document address.
    type Fruit = Vec<(TScore, DocAddress)>;

    /// Segment-level collector, built around the segment scorer type.
    type Child = CustomScoreTopSegmentCollector<TCustomScorer::Child, TScore>;

    /// Builds the segment-level collector for one segment.
    ///
    /// The segment scorer is created first, then the segment collector of the
    /// wrapped `TopCollector`; an error from either step is returned to the
    /// caller.
    fn for_segment(
        &self,
        segment_local_id: u32,
        segment_reader: &SegmentReader,
    ) -> Result<Self::Child> {
        // Struct-literal fields are evaluated in the order written, so the
        // scorer is still built before the segment collector.
        Ok(CustomScoreTopSegmentCollector {
            segment_scorer: self.custom_scorer.segment_scorer(segment_reader)?,
            segment_collector: self.collector.for_segment(segment_local_id, segment_reader)?,
        })
    }

    /// Always `false`: the score passed to the segment collector's `collect`
    /// is ignored in favor of the custom score.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Merges the per-segment results by delegating to the wrapped `TopCollector`.
    fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
        self.collector.merge_fruits(segment_fruits)
    }
}

/// Segment-level collector built by `CustomScoreTopCollector::for_segment`.
///
/// For every collected document it asks `segment_scorer` for a score and
/// hands the document, together with that score, to `segment_collector`.
pub struct CustomScoreTopSegmentCollector<T, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync + Sized,
    T: CustomSegmentScorer<TScore>,
{
    // Receives each document with its custom score.
    segment_collector: TopSegmentCollector<TScore>,
    // Computes the custom score of a document.
    segment_scorer: T,
}

impl<T, TScore> SegmentCollector for CustomScoreTopSegmentCollector<T, TScore>
where
    TScore: 'static + PartialOrd + Clone + Send + Sync,
    T: 'static + CustomSegmentScorer<TScore>,
{
    /// Pairs of custom score and document address.
    type Fruit = Vec<(TScore, DocAddress)>;

    /// Scores `doc` with the segment scorer and forwards it to the wrapped
    /// segment collector. The incoming score is deliberately ignored.
    fn collect(&mut self, doc: DocId, _score: Score) {
        self.segment_collector
            .collect(doc, self.segment_scorer.score(doc));
    }

    /// Consumes the collector and returns what the wrapped segment collector
    /// harvested.
    fn harvest(self) -> Self::Fruit {
        self.segment_collector.harvest()
    }
}

/// Lets any `'static + Send + Sync` closure of type `Fn(&SegmentReader) -> T`
/// act as a `CustomScorer`, provided `T` is a `CustomSegmentScorer<TScore>`.
impl<F, TScore, T> CustomScorer<TScore> for F
where
    F: 'static + Send + Sync + Fn(&SegmentReader) -> T,
    T: CustomSegmentScorer<TScore>,
{
    type Child = T;

    /// Calls the closure with `segment_reader`; this implementation never
    /// returns an error.
    fn segment_scorer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
        let child = self(segment_reader);
        Ok(child)
    }
}

/// Lets any `'static + Send + Sync` closure of type `Fn(DocId) -> TScore`
/// act as a `CustomSegmentScorer`.
impl<F, TScore> CustomSegmentScorer<TScore> for F
where
    F: 'static + Sync + Send + Fn(DocId) -> TScore,
{
    /// Returns whatever the closure returns for `doc`.
    fn score(&self, doc: DocId) -> TScore {
        self(doc)
    }
}