vivaldi_nc/network_coordinate.rs
1//! Main interface module for Vivaldi network coordinates.
2//!
3//! For usage explanation and examples, please see the main [`crate`] documentation.
4
5use core::time::Duration;
6
7use serde::{Deserialize, Serialize};
8
9use crate::height_vector::HeightVector;
10
11//
12// **** Features ****
13//
/// Floating-point type used for all coordinate math: `f64` unless the `f32` feature is enabled.
#[cfg(not(feature = "f32"))]
type FloatType = f64;

/// Floating-point type used for all coordinate math: `f32` because the `f32` feature is enabled.
#[cfg(feature = "f32")]
type FloatType = f32;
19
20//
21// **** Constants ****
22//
23
24// Vivaldi tuning parameters
25const C_ERROR: FloatType = 0.25;
26const C_DELTA: FloatType = 0.25;
27
28// initial error value
29const DEFAULT_ERROR: FloatType = 200.0;
30
31// error should always be greater than zero
32const MIN_ERROR: FloatType = FloatType::EPSILON;
33
34//
35// **** Structs ****
36//
37
/// A `NetworkCoordinate<N>` is the main interface to a Vivaldi network coordinate.
///
/// It pairs a position (an `N`-dimensional Euclidean position plus a height) with an estimate
/// of how inaccurate that position currently is.
///
/// # Generic Parameters
///
/// - `N`: Const generic for number of dimensions. For example, `NetworkCoordinate<3>` is a
/// 3-dimensional Euclidean coordinate plus a height. Should be a positive number greater than
/// zero.
///
/// **Note:** Dimensions other than 2D or 3D are usually not useful. For those two common cases
/// you can use the type aliases [`NetworkCoordinate2D`] and [`NetworkCoordinate3D`], which are a
/// little more ergonomic than using the generic here.
///
/// # Serialization
///
/// The position's fields are flattened into the coordinate itself and followed by `error`.
/// With `serde_json` a 3D coordinate looks like
/// `{"position":[1.5,0.5,2.0],"height":0.1,"error":1.0}`.
///
/// # Examples
///
/// For an explanation and examples of usage, please see the main [`crate`] documentation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NetworkCoordinate<const N: usize> {
    // Position in the coordinate space (Euclidean part plus height). Flattened so its fields
    // are serialized next to `error` rather than as a nested object.
    #[serde(flatten)]
    heightvec: HeightVector<N>,
    // Current error estimate of this coordinate. Starts at `DEFAULT_ERROR`; `update` clamps it
    // to at least `MIN_ERROR`.
    error: FloatType,
}
58}
59
60// type aliases for convenience
61
62/// A 2D [`NetworkCoordinate`]. Includes a 2D Euclidean position and a height.
63///
64/// This type alias is just for convenience. It's functionally equivalent to
65/// `NetworkCoordinate<2>`. For more information, see [`NetworkCoordinate`].
66#[allow(clippy::module_name_repetitions)]
67pub type NetworkCoordinate2D = NetworkCoordinate<2>;
68
69/// A 3D [`NetworkCoordinate`]. Includes a 3D Euclidean position and a height.
70///
71/// This type alias is just for convenience. It's functionally equivalent to
72/// `NetworkCoordinate<3>`. For more information, see [`NetworkCoordinate`].
73#[allow(clippy::module_name_repetitions)]
74pub type NetworkCoordinate3D = NetworkCoordinate<3>;
75
76//
77// **** Implementations ****
78//
79
80impl<const N: usize> NetworkCoordinate<N> {
81 /// Creates a new random [`NetworkCoordinate`]
82 ///
83 /// # Example
84 ///
85 /// ```
86 /// use vivaldi_nc::NetworkCoordinate;
87 ///
88 /// // create a new 3-dimensional random NC
89 /// let a: NetworkCoordinate<3> = NetworkCoordinate::new();
90 ///
91 /// // print the NC
92 /// println!("Our new NC is: {:#?}", a);
93 /// ```
94 #[must_use]
95 pub fn new() -> Self {
96 Self::default()
97 }
98
99 /// Given another Vivaldi [`NetworkCoordinate`], estimate the round trip time (ie ping) between them.
100 ///
101 /// This is done by computing the height vector distance between between the two coordinates.
102 /// Vivaldi uses this distance as a representation of estimated round trip time.
103 ///
104 /// # Parameters
105 ///
106 /// - `rhs`: the other coordinate
107 ///
108 /// # Returns
109 ///
110 /// - the estimated round trip time as a `Duration`
111 ///
112 /// # Example
113 ///
114 /// ```
115 /// use vivaldi_nc::NetworkCoordinate;
116 ///
117 /// // create some 2-dimensional NCs for the sake of this example. These will just be random
118 /// // NCs. In a real usecase these would have meaningful values.
119 /// let a: NetworkCoordinate<2> = NetworkCoordinate::new();
120 /// let b: NetworkCoordinate<2> = NetworkCoordinate::new();
121 ///
122 /// // get the estimated RTT, convert to milliseconds, and print
123 /// println!("Estimated RTT: {}", a.estimated_rtt(&b).as_millis());
124 /// ```
125 ///
126 #[must_use]
127 pub fn estimated_rtt(&self, rhs: &Self) -> Duration {
128 // estimated rss is euclidean distance between the two plus the sum of the heights
129 cfg_if::cfg_if! {
130 if #[cfg(feature = "f32")] {
131 Duration::from_secs_f32((self.heightvec - rhs.heightvec).len() / 1000.0)
132 } else {
133 Duration::from_secs_f64((self.heightvec - rhs.heightvec).len() / 1000.0)
134 }
135 }
136 }
137
138 /// Given another Vivaldi [`NetworkCoordinate`], adjust our coordinateto better represent the actual round
139 /// trip time (aka distance) between us.
140 ///
141 /// # Parameters
142 ///
143 /// - `rhs`: the other coordinate
144 /// - `rtt`: the measured round trip time between `self` and `rhs`
145 ///
146 /// # Returns
147 ///
148 /// - a reference to `self`
149 ///
150 /// # Example
151 ///
152 /// ```
153 /// use core::time::Duration;
154 /// use vivaldi_nc::NetworkCoordinate;
155 ///
156 /// // We always have our own NC:
157 /// let mut local: NetworkCoordinate<2> = NetworkCoordinate::new();
158 ///
159 /// // Assume we received a NC from a remote node:
160 /// let remote: NetworkCoordinate<2> = NetworkCoordinate::new();
161 ///
162 /// // And we measured the RTT between us and the remote node:
163 /// let rtt = Duration::from_millis(100);
164 ///
165 /// // Now we can update our NC to adjust our position relative to the remote node:
166 /// local.update(&remote, rtt);
167 /// ```
168 ///
169 /// # Algorithm
170 ///
171 /// This is an implementation of Vivaldi NCs per the original paper. It implements the following
172 /// alogirthm (quoting from paper):
173 ///
174 /// ```text
175 /// // Incorporate new information: node j has been
176 /// // measured to be rtt ms away, has coordinates xj,
177 /// // and an error estimate of ej .
178 /// //
179 /// // Our own coordinates and error estimate are xi and ei.
180 /// //
181 /// // The constants ce and cc are tuning parameters.
182 ///
183 /// vivaldi(rtt, xj, ej)
184 /// // Sample weight balances local and remote error. (1)
185 /// w = ei /(ei + ej)
186 /// // Compute relative error of this sample. (2)
187 /// es = ∣∣∣‖xi − xj‖ − rtt∣∣∣/rtt
188 /// // Update weighted moving average of local error. (3)
189 /// ei = es × ce × w + ei × (1 − ce × w)
190 /// // Update local coordinates. (4)
191 /// δ = cc × w
192 /// xi = xi + δ × (rtt − ‖xi − xj ‖) × u(xi − xj)
193 /// ```
194 ///
195 pub fn update(&mut self, rhs: &Self, rtt: Duration) -> &Self {
196 // convert Durations into FloatType as fractional milliseconds for convenience
197 cfg_if::cfg_if! {
198 if #[cfg(feature = "f32")] {
199 let rtt_ms = rtt.as_secs_f32() * 1000.0;
200 let rtt_estimated_ms = self.estimated_rtt(rhs).as_secs_f32() * 1000.0;
201 } else {
202 let rtt_ms = rtt.as_secs_f64() * 1000.0;
203 let rtt_estimated_ms = self.estimated_rtt(rhs).as_secs_f64() * 1000.0;
204 }
205 }
206
207 // rtt needs to be positive
208 if rtt_ms < 0.0 {
209 // Note: `rtt` is guaranteed to be positive because `Duration` enforces it.
210 // If this panics, something changed where `Duration` now allows for negative
211 // values.
212 unreachable!();
213 }
214
215 // Sample weight balances local and remote error. (1)
216 // w = ei /(ei + ej )
217 let w = self.error / (self.error + rhs.error);
218
219 // Compute relative error of this sample. (2)
220 // es = ∣∣∣‖xi − xj‖ − rtt∣∣∣/rtt
221 let error = rtt_ms - rtt_estimated_ms;
222 let es = error.abs() / rtt_ms;
223
224 // Update weighted moving average of local error. (3)
225 // ei = es × ce × w + ei × (1 − ce × w)
226 // self.error = (es * C_ERROR * w + self.error * (1.0 - C_ERROR * w)).max(MIN_ERROR);
227 // NOTE: using `mul_add()` which is a little safer (avoid overflows)
228 self.error = (es * C_ERROR)
229 .mul_add(w, self.error * C_ERROR.mul_add(-w, 1.0))
230 .max(MIN_ERROR);
231
232 // Update local coordinates. (4)
233 // δ = cc × w
234 let delta = C_DELTA * w;
235 // xi = xi + δ × (rtt − ‖xi − xj ‖) × u(xi − xj)
236 self.heightvec =
237 self.heightvec + (self.heightvec - rhs.heightvec).normalized() * delta * error;
238
239 // if we ended up with an invalid coordinate, return a new random coordinate with default
240 // error
241 if self.heightvec.is_invalid() {
242 *self = Self::new();
243
244 // We should never get here because the call to `normalized()` above should catch any
245 // invalid `heightvec`
246 unreachable!();
247 }
248
249 // return reference to updated self
250 self
251 }
252
253 /// getter for error value - useful for consumers to understand the estimated accuracty of this
254 /// `NetworkCoordinate`
255 #[must_use]
256 pub const fn error(&self) -> FloatType {
257 self.error
258 }
259}
260
261//
262// **** Trait Implementations ****
263//
264
265impl<const N: usize> Default for NetworkCoordinate<N> {
266 /// A default `NetworkCoordinate` has a random position and `DEFAULT_ERROR`
267 fn default() -> Self {
268 Self {
269 heightvec: HeightVector::<N>::random(),
270 error: DEFAULT_ERROR,
271 }
272 }
273}
274
275//
276// **** Tests ****
277//
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    use super::*;

    // Two nodes repeatedly exchanging the same measured RTT should converge so that their
    // estimated RTT matches the measured one to within 1ms.
    #[test]
    fn test_convergence() {
        let mut a = NetworkCoordinate::<3>::new();
        let mut b = NetworkCoordinate::<3>::new();
        let t = Duration::from_millis(250);
        (0..20).for_each(|_| {
            a.update(&b, t);
            b.update(&a, t);
        });
        let rtt = a.estimated_rtt(&b);
        assert_approx_eq!(rtt.as_secs_f32() * 1000.0, 250.0, 1.0);
    }

    #[test]
    fn test_mini_network() {
        // define a little network with these nodes:
        //
        // slc has 80ms stem time to core entry Seattle
        // nyc has 30ms stem time to core entry Virginia
        // lax has 15ms stem time to core entry Los Angeles
        // mad has 60ms stem time to core entry London
        //
        // we'll assume that traffic in the core moves at 50% the speed of light
        //
        // that gives us this grid of RTTs (in ms) for the core:
        //
        // |             | Seattle | Virginia | Los Angeles | London |
        // |-------------|---------|----------|-------------|--------|
        // | Seattle     | -       | 52       | 20          | 102    |
        // | Virginia    |         | -        | 50          | 78     |
        // | Los Angeles |         |          | -           | 116    |
        // | London      |         |          |             | -      |
        //
        // Which gives us these routes (plus their reverse) and times (ms):
        //
        // SLC -> Seattle -> Virginia -> NYC = 80 + 52 + 30 = 162
        // SLC -> Seattle -> Los Angeles -> LAX = 80 + 20 + 15 = 115
        // SLC -> Seattle -> London -> MAD = 80 + 102 + 60 = 242
        // NYC -> Virginia -> Los Angeles -> LAX = 30 + 50 + 15 = 95
        // NYC -> Virginia -> London -> MAD = 30 + 78 + 60 = 168
        // LAX -> Los Angeles -> London -> MAD = 15 + 116 + 60 = 192

        // create the NCs for each endpoint
        let mut slc = NetworkCoordinate::<2>::new();
        let mut nyc = NetworkCoordinate::<2>::new();
        let mut lax = NetworkCoordinate::<2>::new();
        let mut mad = NetworkCoordinate::<2>::new();

        // verify the initial error: every node starts at DEFAULT_ERROR (200), so the Euclidean
        // norm of the four errors is sqrt(4 * 200^2) = 400
        let error = slc.error.hypot(nyc.error.hypot(lax.error.hypot(mad.error)));
        assert_approx_eq!(error, 400.0);

        // iterate plenty of times to converge and minimize error
        (0..20).for_each(|_| {
            slc.update(&nyc, Duration::from_millis(162));
            nyc.update(&slc, Duration::from_millis(162));

            slc.update(&lax, Duration::from_millis(115));
            lax.update(&slc, Duration::from_millis(115));

            slc.update(&mad, Duration::from_millis(242));
            mad.update(&slc, Duration::from_millis(242));

            nyc.update(&lax, Duration::from_millis(95));
            lax.update(&nyc, Duration::from_millis(95));

            nyc.update(&mad, Duration::from_millis(168));
            mad.update(&nyc, Duration::from_millis(168));

            lax.update(&mad, Duration::from_millis(192));
            mad.update(&lax, Duration::from_millis(192));
        });

        // the plain sum of the four error estimates must be small; since every error is
        // positive (clamped to at least MIN_ERROR), this also bounds each individual error
        let error = slc.error + nyc.error + lax.error + mad.error;
        assert!(error < 5.0);
    }

    #[test]
    fn test_serde() {
        // start with JSON, deserialize it
        let s = "{\"position\":[1.5,0.5,2.0],\"height\":0.1,\"error\":1.0}";
        let a: NetworkCoordinate<3> =
            serde_json::from_str(s).expect("deserialization failed during test");

        // make sure it's the right length and works like we expect a normal NC:
        // 2.649_509 = sqrt(1.5^2 + 0.5^2 + 2.0^2) + 0.1 (Euclidean norm plus height)
        assert_approx_eq!(a.heightvec.len(), 2.649_509, 0.001);
        assert_approx_eq!(a.error, 1.0);
        assert_eq!(a.estimated_rtt(&a).as_millis(), 0);

        // serialize it into a new JSON string and make sure it matches the original
        let t = serde_json::to_string(&a);
        assert_eq!(t.as_ref().expect("serialization failed during test"), s);
    }

    #[test]
    fn test_estimated_rtt() {
        // start with JSON, deserialize it
        let s = "{\"position\":[1.5,0.5,2.0],\"height\":25.0,\"error\":1.0}";
        let a: NetworkCoordinate<3> =
            serde_json::from_str(s).expect("deserialization failed during test");
        let s = "{\"position\":[-1.5,-0.5,-2.0],\"height\":50.0,\"error\":1.0}";
        let b: NetworkCoordinate<3> =
            serde_json::from_str(s).expect("deserialization failed during test");

        // positions differ by (3, 1, 4): sqrt(26) ≈ 5.099, plus heights 25 + 50 = 75,
        // gives ≈ 80.099 ms = 0.080_099 s
        let estimate = a.estimated_rtt(&b);
        assert_approx_eq!(estimate.as_secs_f32(), 0.080_099);
    }

    // the public getter must report the stored error estimate
    #[test]
    fn test_error_getter() {
        let s = "{\"position\":[1.5,0.5,2.0],\"height\":25.0,\"error\":1.0}";
        let a: NetworkCoordinate<3> =
            serde_json::from_str(s).expect("deserialization failed during test");
        assert_approx_eq!(a.error(), 1.0);
    }
}