vivaldi_nc/
network_coordinate.rs

1//! Main interface module for Vivaldi network coordinates.
2//!
3//! For usage explanation and examples, please see the main [`crate`] documentation.
4
5use core::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::height_vector::HeightVector;
10
11//
12// **** Features ****
13//
// Floating-point type used throughout this module: `f64` by default, `f32` when the crate is
// built with the "f32" feature.
#[cfg(not(feature = "f32"))]
type FloatType = f64;

#[cfg(feature = "f32")]
type FloatType = f32;

//
// **** Constants ****
//

// Vivaldi tuning parameters -- `ce` and `cc` in the paper's pseudocode (see `update`).

/// Weight used when folding a sample's relative error into the local error estimate (`ce`).
const C_ERROR: FloatType = 0.25;

/// Weight used to scale each coordinate adjustment step (`cc`).
const C_DELTA: FloatType = 0.25;

/// Error estimate assigned to a freshly created coordinate.
const DEFAULT_ERROR: FloatType = 200.0;

/// Floor for the error estimate so that it always stays strictly greater than zero.
const MIN_ERROR: FloatType = FloatType::EPSILON;
33
34//
35// **** Structs ****
36//
37
38/// A `NetworkCoordinate<N>` is the main interface to a Vivaldi network coordinate.
39///
40/// # Generic Parameters
41///
42/// - `N`: Const generic for number of dimensions. For example, `NetworkCoordinate<3>` is a
43/// 3-Dimentionsal Euclidean coordinate plus a height. Should be a positive number greater than
44/// zero.
45///
46/// **Note:** Dimensions other than 2D or 3D are usually not useful. If you want to use one of
47/// those dimensions, you can use type aliases ([`NetworkCoordinate2D`] or [`NetworkCoordinate3D`])
48/// which are a little more ergonomic than using the generic here.
49///
50/// # Examples
51///
52/// For an explanation and examples of usage, please see the main [`crate`] documentation.
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct NetworkCoordinate<const N: usize> {
55    #[serde(flatten)]
56    heightvec: HeightVector<N>,
57    error: FloatType,
58}
59
60// type aliases for convenience
61
62/// A 2D [`NetworkCoordinate`]. Includes a 2D Euclidean position and a height.
63///
64/// This type alias is just for convenience. It's functionally equivalent to
65/// `NetworkCoordinate<2>`. For more information, see [`NetworkCoordinate`].
66#[allow(clippy::module_name_repetitions)]
67pub type NetworkCoordinate2D = NetworkCoordinate<2>;
68
69/// A 3D [`NetworkCoordinate`]. Includes a 3D Euclidean position and a height.
70///
71/// This type alias is just for convenience. It's functionally equivalent to
72/// `NetworkCoordinate<3>`. For more information, see [`NetworkCoordinate`].
73#[allow(clippy::module_name_repetitions)]
74pub type NetworkCoordinate3D = NetworkCoordinate<3>;
75
76//
77// **** Implementations ****
78//
79
80impl<const N: usize> NetworkCoordinate<N> {
81    /// Creates a new random [`NetworkCoordinate`]
82    ///
83    /// # Example
84    ///
85    /// ```
86    /// use vivaldi_nc::NetworkCoordinate;
87    ///
88    /// // create a new 3-dimensional random NC
89    /// let a: NetworkCoordinate<3> = NetworkCoordinate::new();
90    ///
91    /// // print the NC
92    /// println!("Our new NC is: {:#?}", a);
93    /// ```
94    #[must_use]
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    /// Given another Vivaldi [`NetworkCoordinate`], estimate the round trip time (ie ping) between them.
100    ///
101    /// This is done by computing the height vector distance between between the two coordinates.
102    /// Vivaldi uses this distance as a representation of estimated round trip time.
103    ///
104    /// # Parameters
105    ///
106    /// - `rhs`: the other coordinate
107    ///
108    /// # Returns
109    ///
110    /// - the estimated round trip time as a `Duration`
111    ///
112    /// # Example
113    ///
114    /// ```
115    /// use vivaldi_nc::NetworkCoordinate;
116    ///
117    /// // create some 2-dimensional NCs for the sake of this example. These will just be random
118    /// // NCs. In a real usecase these would have meaningful values.
119    /// let a: NetworkCoordinate<2> = NetworkCoordinate::new();
120    /// let b: NetworkCoordinate<2> = NetworkCoordinate::new();
121    ///
122    /// // get the estimated RTT, convert to milliseconds, and print
123    /// println!("Estimated RTT: {}", a.estimated_rtt(&b).as_millis());
124    /// ```
125    ///
126    #[must_use]
127    pub fn estimated_rtt(&self, rhs: &Self) -> Duration {
128        // estimated rss is euclidean distance between the two plus the sum of the heights
129        cfg_if::cfg_if! {
130            if #[cfg(feature = "f32")] {
131                Duration::from_secs_f32((self.heightvec - rhs.heightvec).len() / 1000.0)
132            } else {
133                Duration::from_secs_f64((self.heightvec - rhs.heightvec).len() / 1000.0)
134            }
135        }
136    }
137
138    /// Given another Vivaldi [`NetworkCoordinate`], adjust our coordinateto better represent the actual round
139    /// trip time (aka distance) between us.
140    ///
141    /// # Parameters
142    ///
143    /// - `rhs`: the other coordinate
144    /// - `rtt`: the measured round trip time between `self` and `rhs`
145    ///
146    /// # Returns
147    ///
148    /// - a reference to `self`
149    ///
150    /// # Example
151    ///
152    /// ```
153    /// use core::time::Duration;
154    /// use vivaldi_nc::NetworkCoordinate;
155    ///
156    /// // We always have our own NC:
157    /// let mut local: NetworkCoordinate<2> = NetworkCoordinate::new();
158    ///
159    /// // Assume we received a NC from a remote node:
160    /// let remote: NetworkCoordinate<2> = NetworkCoordinate::new();
161    ///
162    /// // And we measured the RTT between us and the remote node:
163    /// let rtt = Duration::from_millis(100);
164    ///
165    /// // Now we can update our NC to adjust our position relative to the remote node:
166    /// local.update(&remote, rtt);
167    /// ```
168    ///
169    /// # Algorithm
170    ///
171    /// This is an implementation of Vivaldi NCs per the original paper. It implements the following
172    /// alogirthm (quoting from paper):
173    ///
174    /// ```text
175    /// // Incorporate new information: node j has been
176    /// // measured to be rtt ms away, has coordinates xj,
177    /// // and an error estimate of ej .
178    /// //
179    /// // Our own coordinates and error estimate are xi and ei.
180    /// //
181    /// // The constants ce and cc are tuning parameters.
182    ///
183    /// vivaldi(rtt, xj, ej)
184    ///     // Sample weight balances local and remote error. (1)
185    ///     w = ei /(ei + ej)
186    ///     // Compute relative error of this sample. (2)
187    ///     es = ∣∣∣‖xi − xj‖ − rtt∣∣∣/rtt
188    ///     // Update weighted moving average of local error. (3)
189    ///     ei = es × ce × w + ei × (1 − ce × w)
190    ///     // Update local coordinates. (4)
191    ///     δ = cc × w
192    ///     xi = xi + δ × (rtt − ‖xi − xj ‖) × u(xi − xj)
193    /// ```
194    ///
195    pub fn update(&mut self, rhs: &Self, rtt: Duration) -> &Self {
196        // convert Durations into FloatType as fractional milliseconds for convenience
197        cfg_if::cfg_if! {
198            if #[cfg(feature = "f32")] {
199                let rtt_ms = rtt.as_secs_f32() * 1000.0;
200                let rtt_estimated_ms = self.estimated_rtt(rhs).as_secs_f32() * 1000.0;
201            } else {
202                let rtt_ms = rtt.as_secs_f64() * 1000.0;
203                let rtt_estimated_ms = self.estimated_rtt(rhs).as_secs_f64() * 1000.0;
204            }
205        }
206
207        // rtt needs to be positive
208        if rtt_ms < 0.0 {
209            // Note: `rtt` is guaranteed to be positive because `Duration` enforces it.
210            //       If this panics, something changed where `Duration` now allows for negative
211            //       values.
212            unreachable!();
213        }
214
215        // Sample weight balances local and remote error. (1)
216        // w = ei /(ei + ej )
217        let w = self.error / (self.error + rhs.error);
218
219        // Compute relative error of this sample. (2)
220        // es = ∣∣∣‖xi − xj‖ − rtt∣∣∣/rtt
221        let error = rtt_ms - rtt_estimated_ms;
222        let es = error.abs() / rtt_ms;
223
224        // Update weighted moving average of local error. (3)
225        // ei = es × ce × w + ei × (1 − ce × w)
226        // self.error = (es * C_ERROR * w + self.error * (1.0 - C_ERROR * w)).max(MIN_ERROR);
227        // NOTE: using `mul_add()` which is a little safer (avoid overflows)
228        self.error = (es * C_ERROR)
229            .mul_add(w, self.error * C_ERROR.mul_add(-w, 1.0))
230            .max(MIN_ERROR);
231
232        // Update local coordinates. (4)
233        // δ = cc × w
234        let delta = C_DELTA * w;
235        // xi = xi + δ × (rtt − ‖xi − xj ‖) × u(xi − xj)
236        self.heightvec =
237            self.heightvec + (self.heightvec - rhs.heightvec).normalized() * delta * error;
238
239        // if we ended up with an invalid coordinate, return a new random coordinate with default
240        // error
241        if self.heightvec.is_invalid() {
242            *self = Self::new();
243
244            // We should never get here because the call to `normalized()` above should catch any
245            // invalid `heightvec`
246            unreachable!();
247        }
248
249        // return reference to updated self
250        self
251    }
252
253    /// getter for error value - useful for consumers to understand the estimated accuracty of this
254    /// `NetworkCoordinate`
255    #[must_use]
256    pub const fn error(&self) -> FloatType {
257        self.error
258    }
259}
260
261//
262// **** Trait Implementations ****
263//
264
265impl<const N: usize> Default for NetworkCoordinate<N> {
266    /// A default `NetworkCoordinate` has a random position and `DEFAULT_ERROR`
267    fn default() -> Self {
268        Self {
269            heightvec: HeightVector::<N>::random(),
270            error: DEFAULT_ERROR,
271        }
272    }
273}
274
275//
276// **** Tests ****
277//
278#[cfg(test)]
279mod tests {
280    use assert_approx_eq::assert_approx_eq;
281
282    use super::*;
283
284    #[test]
285    fn test_convergence() {
286        let mut a = NetworkCoordinate::<3>::new();
287        let mut b = NetworkCoordinate::<3>::new();
288        let t = Duration::from_millis(250);
289        (0..20).for_each(|_| {
290            a.update(&b, t);
291            b.update(&a, t);
292        });
293        let rtt = a.estimated_rtt(&b);
294        assert_approx_eq!(rtt.as_secs_f32() * 1000.0, 250.0, 1.0);
295    }
296
297    #[test]
298    fn test_mini_network() {
299        // define a little network with these nodes:
300        //
301        // slc has 80ms stem time to core entry Seattle
302        // nyc has 30ms stem time to core entry Virginia
303        // lax has 15ms stem time to core entry Los Angeles
304        // mad has 60ms stem time to core entry London
305        //
306        // we'll assume that traffic in the core moves at 50% the speed of light
307        //
308        // that gives us this grid of RTTs (in ms) for the core:
309        //
310        // |             | Seattle | Virgina | Los Angeles | London |
311        // |-------------|---------|---------|-------------|--------|
312        // | Seattle     |       - |      52 |          20 |    102 |
313        // | Virginia    |         |       - |          50 |     78 |
314        // | Los Angeles |         |         |           - |    116 |
315        // | London      |         |         |             |      - |
316        //
317        // Which gives us these routes (plus their reverse) and times (ms):
318        //
319        // SLC -> Seattle -> Virginia -> NYC = 80 + 52 + 30 = 162
320        // SLC -> Seattle -> Los Angeles -> LAX = 80 + 20 + 15 = 115
321        // SLC -> Seattle -> Londong -> MAD = 80 + 102 + 60 = 242
322        // NYC -> Virginia -> Los Angeles -> LAX = 30 + 50 + 15 = 95
323        // NYC -> Virginia -> London -> MAD = 30 + 78 + 60 = 168
324        // LAX -> Los Angeles -> London -> MAD = 15 + 116 + 60 = 192
325
326        // create the NCs for each endpoint
327        let mut slc = NetworkCoordinate::<2>::new();
328        let mut nyc = NetworkCoordinate::<2>::new();
329        let mut lax = NetworkCoordinate::<2>::new();
330        let mut mad = NetworkCoordinate::<2>::new();
331
332        // verify the initial error
333        let error = slc.error.hypot(nyc.error.hypot(lax.error.hypot(mad.error)));
334        assert_approx_eq!(error, 400.0);
335
336        // iterate plenty of times to converge and minimize error
337        (0..20).for_each(|_| {
338            slc.update(&nyc, Duration::from_millis(162));
339            nyc.update(&slc, Duration::from_millis(162));
340
341            slc.update(&lax, Duration::from_millis(115));
342            lax.update(&slc, Duration::from_millis(115));
343
344            slc.update(&mad, Duration::from_millis(242));
345            mad.update(&slc, Duration::from_millis(242));
346
347            nyc.update(&lax, Duration::from_millis(95));
348            lax.update(&nyc, Duration::from_millis(95));
349
350            nyc.update(&mad, Duration::from_millis(168));
351            mad.update(&nyc, Duration::from_millis(168));
352
353            lax.update(&mad, Duration::from_millis(192));
354            mad.update(&lax, Duration::from_millis(192));
355        });
356
357        // compute and test the root mean squared error
358        let error = slc.error + nyc.error + lax.error + mad.error;
359        assert!(error < 5.0);
360    }
361
362    #[test]
363    fn test_serde() {
364        // start with JSON, deserialize it
365        let s = "{\"position\":[1.5,0.5,2.0],\"height\":0.1,\"error\":1.0}";
366        let a: NetworkCoordinate<3> =
367            serde_json::from_str(s).expect("deserialization failed during test");
368
369        // make sure it's the right length and works like we expect a normal NC
370        assert_approx_eq!(a.heightvec.len(), 2.649_509, 0.001);
371        assert_approx_eq!(a.error, 1.0);
372        assert_eq!(a.estimated_rtt(&a).as_millis(), 0);
373
374        // serialize it into a new JSON string and make sure it matches the original
375        let t = serde_json::to_string(&a);
376        assert_eq!(t.as_ref().expect("serialization failed during test"), s);
377    }
378
379    #[test]
380    fn test_estimated_rtt() {
381        // start with JSON, deserialize it
382        let s = "{\"position\":[1.5,0.5,2.0],\"height\":25.0,\"error\":1.0}";
383        let a: NetworkCoordinate<3> =
384            serde_json::from_str(s).expect("deserialization failed during test");
385        let s = "{\"position\":[-1.5,-0.5,-2.0],\"height\":50.0,\"error\":1.0}";
386        let b: NetworkCoordinate<3> =
387            serde_json::from_str(s).expect("deserialization failed during test");
388
389        let estimate = a.estimated_rtt(&b);
390        assert_approx_eq!(estimate.as_secs_f32(), 0.080_099);
391    }
392
393    #[test]
394    fn test_error_getter() {
395        let s = "{\"position\":[1.5,0.5,2.0],\"height\":25.0,\"error\":1.0}";
396        let a: NetworkCoordinate<3> =
397            serde_json::from_str(s).expect("deserialization failed during test");
398        assert_approx_eq!(a.error(), 1.0);
399    }
400}