str_distance/
modifiers.rs

1use std::cmp;
2
3use crate::utils::{count_eq, order_by_len_asc};
4use crate::{DistanceMetric, Jaro};
5
/// Parameters of the [`Winkler`] adjustment.
///
/// Build one with [`WinklerConfig::new`], which checks that
/// `scaling * max_length <= 1`, or use [`Default`] for the conventional
/// Jaro-Winkler values (`0.1`, `0.7`, `4`).
#[derive(Debug, Clone, PartialEq)]
pub struct WinklerConfig {
    /// Scaling factor applied per character of common prefix. Defaults to `0.1`.
    scaling: f64,
    /// Boost threshold on the similarity scale: the adjustment is applied only
    /// when the distance is at most `1 - threshold`. Defaults to `0.7`.
    threshold: f64,
    /// Maximum length of the common prefix taken into account. Defaults to `4`.
    max_length: usize,
}
15
16impl WinklerConfig {
17    /// # Panics
18    ///
19    /// Panics if the scaling factor times maxlength of common prefix is higher
20    /// than one.
21    pub fn new(scaling: f64, threshold: f64, max_length: usize) -> Self {
22        assert!(scaling * max_length as f64 <= 1.);
23        Self {
24            scaling,
25            threshold,
26            max_length,
27        }
28    }
29}
30
31impl Default for WinklerConfig {
32    fn default() -> Self {
33        Self {
34            scaling: 0.1,
35            threshold: 0.7,
36            max_length: 4,
37        }
38    }
39}
40
41/// `Winkler` modifies a [`DistanceMetric`]'s distance to decrease the distance
42/// between  two strings, when their original distance is below some
43/// `threshold`. The boost is equal to `min(l,  maxlength) * p * dist` where `l`
44/// denotes the length of their common prefix and `dist` denotes the original
45/// distance. The Winkler adjustment was originally defined for the [`Jaro`]
46/// similarity score but is here defined it for any distance.
47#[derive(Debug, Clone)]
48pub struct Winkler<D: DistanceMetric> {
49    /// The base distance to modify.
50    inner: D,
51    /// Coefficients for winkler modification.
52    config: WinklerConfig,
53}
54
55impl<D: DistanceMetric> Winkler<D> {
56    pub fn new(inner: D) -> Self {
57        Self {
58            inner,
59            config: Default::default(),
60        }
61    }
62
63    pub fn with_config(inner: D, config: WinklerConfig) -> Self {
64        Self { inner, config }
65    }
66}
67
68impl<D> DistanceMetric for Winkler<D>
69where
70    D: DistanceMetric,
71    <D as DistanceMetric>::Dist: Into<f64>,
72{
73    type Dist = f64;
74
75    fn distance<S, T>(&self, a: S, b: T) -> Self::Dist
76    where
77        S: IntoIterator,
78        T: IntoIterator,
79        <S as IntoIterator>::IntoIter: Clone,
80        <T as IntoIterator>::IntoIter: Clone,
81        <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
82        <T as IntoIterator>::Item: PartialEq,
83    {
84        let a = a.into_iter();
85        let b = b.into_iter();
86
87        let mut score = self.inner.distance(a.clone(), b.clone()).into();
88
89        if score <= 1. - self.config.threshold {
90            let eq_prefix = count_eq(a, b);
91            score -=
92                cmp::min(eq_prefix, self.config.max_length) as f64 * self.config.scaling * score;
93        }
94
95        score
96    }
97
98    fn str_distance<S, T>(&self, s1: S, s2: T) -> Self::Dist
99    where
100        S: AsRef<str>,
101        T: AsRef<str>,
102    {
103        let (s1, s2) = order_by_len_asc(s1.as_ref(), s2.as_ref());
104        self.distance(s1.chars(), s2.chars())
105    }
106
107    fn normalized<S, T>(&self, a: S, b: T) -> f64
108    where
109        S: IntoIterator,
110        T: IntoIterator,
111        <S as IntoIterator>::IntoIter: Clone,
112        <T as IntoIterator>::IntoIter: Clone,
113        <S as IntoIterator>::Item: PartialEq + PartialEq<<T as IntoIterator>::Item>,
114        <T as IntoIterator>::Item: PartialEq,
115    {
116        self.distance(a, b)
117    }
118}
119
120impl Default for Winkler<Jaro> {
121    fn default() -> Self {
122        Self {
123            inner: Jaro,
124            config: Default::default(),
125        }
126    }
127}