scion_stack/path/
manager.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! # Path manager
15//!
16//! A [PathManager] provides applications with SCION paths. The method
17//! [PathManager::path_wait] is an async implementation that possibly awaits
18//! asynchronous, external path requests before returning. The sync-equivalent
19//! is [SyncPathManager::try_cached_path] which returns immediately in all
20//! cases.
21//!
22//! The main implementation provided in this module is the [CachingPathManager]
23//! which is an _active_ component: [CachingPathManager::start] will start an
24//! asynchronous background task (via `tokio::spawn`) that fetches requested
25//! paths using the provided [PathFetcher].
26//!
27//! Typically, applications require paths to fulfill specific constraints and
28//! paths are ranked; i.e. some paths are preferred over others. In the case of
29//! the [CachingPathManager] such constraints are expressed via a [PathPolicy]:
30//! only paths that fulfill [PathPolicy::predicate] are returned and paths with
31//! a lower rank according to [PathPolicy::rank] are preferred.
32
33use std::{cmp::Ordering, future::Future, io, sync::Arc};
34
35use bytes::Bytes;
36use chrono::{DateTime, Utc};
37use derive_more::Deref;
38use endhost_api_client::client::EndhostApiClient;
39use futures::{
40    FutureExt,
41    future::{self, BoxFuture},
42};
43use scc::{Guard, HashIndex, hash_index::Entry};
44use scion_proto::{
45    address::IsdAsn,
46    path::{self, Path},
47};
48use thiserror::Error;
49use tokio::sync::mpsc;
50use tokio_util::sync::CancellationToken;
51use tracing::{debug, error, trace, warn};
52
53use super::{Shortest, policy::PathPolicy};
54use crate::types::ResFut;
55
/// Path fetch errors.
///
/// NOTE(review): within this module this type is only ever converted into
/// [PathWaitError] (see the `From` impl below) — confirm external callers
/// still need it as a distinct public type.
#[derive(Debug, Error)]
pub enum PathToError {
    /// Path fetch failed. Carries a human-readable description of the cause.
    #[error("fetching paths: {0}")]
    FetchPaths(String),
    /// No path found.
    #[error("no path found")]
    NoPathFound,
}
66
/// Path wait errors, returned by [PathManager::path_wait].
///
/// `Clone` is required because a single fetch result may be fanned out to
/// several awaiters of a coalesced in-flight request (via `future::Shared`).
#[derive(Debug, Clone, Error)]
pub enum PathWaitError {
    /// Path fetch failed. Carries a human-readable description of the cause.
    #[error("path fetch failed: {0}")]
    FetchFailed(String),
    /// No path found.
    #[error("no path found")]
    NoPathFound,
}
77
78impl From<PathToError> for PathWaitError {
79    fn from(error: PathToError) -> Self {
80        match error {
81            PathToError::FetchPaths(msg) => PathWaitError::FetchFailed(msg),
82            PathToError::NoPathFound => PathWaitError::NoPathFound,
83        }
84    }
85}
86
/// Trait for active path management with async interface.
pub trait PathManager: SyncPathManager {
    /// Returns a path to the destination from the path cache or requests a new path from the SCION
    /// Control Plane.
    ///
    /// `now` is the caller-supplied current time used for expiry checks of
    /// cached paths.
    fn path_wait(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> impl ResFut<'_, Path<Bytes>, PathWaitError>;
}
98
/// Trait for active path management with sync interface. Implementors of this trait should be
/// able to be used in sync and async context. The functions must not block.
pub trait SyncPathManager {
    /// Add a path to the path cache. This can be used to register reverse paths.
    /// `now` is used to time-stamp the cache entry and evaluate path expiry.
    fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>);

    /// Returns a path to the destination from the path cache.
    /// If the path is not in the cache, it returns Ok(None)
    /// If the cache is locked an io error WouldBlock is returned.
    fn try_cached_path(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> io::Result<Option<Path<Bytes>>>;
}
115
/// Request for prefetching a path, sent to the background task over the
/// prefetch channel.
#[derive(Debug, Clone)]
struct PrefetchRequest {
    pub src: IsdAsn,
    pub dst: IsdAsn,
    // Time the request was issued; used for cache-expiry checks.
    pub now: DateTime<Utc>,
}
123
/// Registration of a new path, sent to the background task over the
/// registration channel.
#[derive(Debug, Clone)]
struct PathRegistration {
    pub src: IsdAsn,
    pub dst: IsdAsn,
    // Time the registration was issued; used to time-stamp the cache entry.
    pub now: DateTime<Utc>,
    pub path: Path<Bytes>,
}
132
/// Cached path entry with metadata
#[derive(Debug, Clone)]
struct CachedPath {
    path: scion_proto::path::Path,
    // Kept for future diagnostics/metrics; currently unread.
    #[expect(unused)]
    cached_at: DateTime<Utc>,
    // True when the entry came from an explicit registration rather than a
    // fetch; registered paths are ranked differently by the policy.
    from_registration: bool,
}
141
142impl CachedPath {
143    fn new(path: scion_proto::path::Path, now: DateTime<Utc>, from_registration: bool) -> Self {
144        Self {
145            path,
146            cached_at: now,
147            from_registration,
148        }
149    }
150
151    fn is_expired(&self, now: DateTime<Utc>) -> bool {
152        self.path
153            .expiry_time()
154            .map(|expiry| expiry < now)
155            .unwrap_or(true)
156    }
157}
158
/// Active path manager that runs as a background task.
///
/// Dropping the manager cancels the background task via the stored
/// [CancellationToken] (see the `Drop` impl below).
pub struct CachingPathManager<P: PathPolicy = Shortest, F: PathFetcher = PathFetcherImpl> {
    /// Shared state between the manager and the background task
    state: CachingPathManagerState<P, F>,
    /// Channels for communicating with the background task
    prefetch_tx: mpsc::Sender<PrefetchRequest>,
    registration_tx: mpsc::Sender<PathRegistration>,
    /// Cancellation token for the background task
    cancellation_token: CancellationToken,
}
169
/// Path fetch errors returned by [PathFetcher] implementations.
#[derive(Debug, thiserror::Error)]
pub enum PathFetchError {
    /// Segment fetch failed.
    #[error("failed to fetch segments: {0}")]
    FetchSegments(#[from] SegmentFetchError),
}
177
/// Path fetcher trait.
pub trait PathFetcher {
    /// Fetch paths between source and destination ISD-AS.
    ///
    /// Returns the unranked candidate set; policy filtering and ranking is
    /// the caller's responsibility (see [CachingPathManager]).
    fn fetch_paths(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
    ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError>;
}
187
/// Boxed future resolving to the preferred path (or error) for one
/// (src, dst) pair; wrapped in `Shared` in the in-flight map to coalesce
/// concurrent lookups.
type BoxedPathLookupResult = BoxFuture<'static, Result<Path<Bytes>, PathWaitError>>;
189
/// State shared between the [CachingPathManager] handle and the background
/// task; always accessed through an `Arc` (see [CachingPathManagerState]).
struct CachingPathManagerStateInner<P: PathPolicy, F: PathFetcher> {
    /// Policy for path selection
    policy: P,
    /// Path fetcher for requesting new paths
    fetcher: F,
    /// Cache of paths indexed by (src, dst)
    path_cache: HashIndex<(IsdAsn, IsdAsn), CachedPath>,
    /// In-flight path requests indexed by (src, dst)
    inflight: HashIndex<(IsdAsn, IsdAsn), future::Shared<BoxedPathLookupResult>>,
}
200
/// Shared state for the active path manager.
///
/// `#[deref(forward)]` makes methods of [CachingPathManagerStateInner]
/// directly reachable through the `Arc` wrapper.
#[derive(Deref)]
#[deref(forward)]
struct CachingPathManagerState<P: PathPolicy, F: PathFetcher>(
    Arc<CachingPathManagerStateInner<P, F>>,
);
207
// Manual impl: deriving Clone would require `P: Clone` and `F: Clone`,
// but only the Arc needs cloning (a cheap refcount bump).
impl<P: PathPolicy, F: PathFetcher> Clone for CachingPathManagerState<P, F> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
213
214impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
215    CachingPathManager<P, F>
216{
217    /// Create and start an active path manager with automatic task management.
218    /// The background task is spawned internally and will be cancelled when the manager is dropped.
219    /// This is the recommended method for most users.
220    pub fn start(policy: P, fetcher: F) -> Self {
221        let cancellation_token = CancellationToken::new();
222        let (manager, task_future) =
223            Self::start_future(policy, fetcher, cancellation_token.clone());
224
225        // Spawn task internally, it is stopped when the manager is dropped.
226        tokio::spawn(async move {
227            task_future.await;
228        });
229
230        manager
231    }
232
233    /// Create the manager and task future.
234    pub fn start_future(
235        policy: P,
236        fetcher: F,
237        cancellation_token: CancellationToken,
238    ) -> (Self, impl std::future::Future<Output = ()>) {
239        let (prefetch_tx, prefetch_rx) = mpsc::channel(1000);
240        let (registration_tx, registration_rx) = mpsc::channel(1000);
241
242        let state = CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
243            policy,
244            fetcher,
245            path_cache: HashIndex::new(),
246            inflight: HashIndex::new(),
247        }));
248
249        let manager = Self {
250            state: state.clone(),
251            prefetch_tx,
252            registration_tx,
253            cancellation_token: cancellation_token.clone(),
254        };
255
256        let task_future = async move {
257            let task =
258                PathManagerTask::new(state, prefetch_rx, registration_rx, cancellation_token);
259            task.run().await
260        };
261
262        (manager, task_future)
263    }
264
265    /// Returns a cached path if it is not expired.
266    pub fn try_cached_path(
267        &self,
268        src: IsdAsn,
269        dst: IsdAsn,
270        now: DateTime<Utc>,
271    ) -> io::Result<Option<Path<Bytes>>> {
272        self.state.try_cached_path(src, dst, now)
273    }
274
275    fn prefetch_path_internal(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>) {
276        if let Err(e) = self.prefetch_tx.try_send(PrefetchRequest { src, dst, now }) {
277            trace!(
278                "Failed to send prefetch request - background task may be stopped: {}",
279                e
280            );
281        }
282    }
283
284    fn register_path_internal(
285        &self,
286        src: IsdAsn,
287        dst: IsdAsn,
288        now: DateTime<Utc>,
289        path: Path<Bytes>,
290    ) {
291        if let Err(e) = self.registration_tx.try_send(PathRegistration {
292            src,
293            dst,
294            now,
295            path,
296        }) {
297            warn!(
298                "Failed to send path registration - background task may be stopped: {}",
299                e
300            );
301        }
302    }
303}
304
impl<P: PathPolicy, F: PathFetcher> Drop for CachingPathManager<P, F> {
    fn drop(&mut self) {
        // Signal the background task (spawned in `start`, or driven by the
        // caller of `start_future`) to shut down.
        self.cancellation_token.cancel();
        // Background task will be cleaned up automatically
        trace!("PathManager dropped, background task cancelled");
    }
}
312
313impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> SyncPathManager
314    for CachingPathManager<P, F>
315{
316    fn register_path(&self, src: IsdAsn, dst: IsdAsn, now: DateTime<Utc>, path: Path<Bytes>) {
317        self.register_path_internal(src, dst, now, path);
318    }
319
320    /// Returns a cached path if it is not expired or prefetches it if it is not in the cache.
321    /// If the path is not in the cache, it returns Ok(None).
322    /// If the cache is locked an io error WouldBlock is returned.
323    fn try_cached_path(
324        &self,
325        src: IsdAsn,
326        dst: IsdAsn,
327        now: DateTime<Utc>,
328    ) -> io::Result<Option<Path<Bytes>>> {
329        match self.state.try_cached_path(src, dst, now)? {
330            Some(path) => Ok(Some(path)),
331            None => {
332                // If the path is not found in the cache, we issue a prefetch request.
333                self.prefetch_path_internal(src, dst, now);
334                Ok(None)
335            }
336        }
337    }
338}
339
340impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> PathManager
341    for CachingPathManager<P, F>
342{
343    fn path_wait(
344        &self,
345        src: IsdAsn,
346        dst: IsdAsn,
347        now: DateTime<Utc>,
348    ) -> impl ResFut<'_, Path<Bytes>, PathWaitError> {
349        async move {
350            // First check if we have a cached path
351            if let Some(cached) = self.state.cached_path_wait(src, dst, now).await {
352                return Ok(cached);
353            }
354
355            // Fetch new path
356            self.state.fetch_and_cache_path(src, dst, now).await
357        }
358    }
359}
360
/// Trait for prefetching paths in the path manager.
pub trait PathPrefetcher {
    /// Prefetch paths for the given source and destination. Best-effort and
    /// non-blocking; results become observable through the path cache.
    fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn);
}
366
impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static> PathPrefetcher
    for CachingPathManager<P, F>
{
    fn prefetch_path(&self, src: IsdAsn, dst: IsdAsn) {
        // Uses the wall clock: the trait offers no caller-supplied time.
        self.prefetch_path_internal(src, dst, Utc::now());
    }
}
374
impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
    CachingPathManagerState<P, F>
{
    /// Returns a cached path if it is not expired.
    ///
    /// This implementation never returns `Err`; the `io::Result` keeps the
    /// signature aligned with [SyncPathManager::try_cached_path].
    pub fn try_cached_path(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> io::Result<Option<Path<Bytes>>> {
        let guard = Guard::new();
        match self.path_cache.peek(&(src, dst), &guard) {
            Some(cached) => {
                if !cached.is_expired(now) {
                    Ok(Some(cached.path.clone()))
                } else {
                    // Expired entries are treated as absent; they are
                    // replaced on the next successful fetch or registration.
                    Ok(None)
                }
            }
            None => Ok(None),
        }
    }

    /// Returns a cached path if it is not expired, for use from async code.
    ///
    /// NOTE(review): despite being `async`, the body never awaits — it
    /// performs the same non-blocking lookup as [Self::try_cached_path].
    /// Confirm whether the async signature is still required.
    async fn cached_path_wait(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> Option<Path<Bytes>> {
        let guard = Guard::new();
        match self.path_cache.peek(&(src, dst), &guard) {
            Some(cached) => {
                if !cached.is_expired(now) {
                    Some(cached.path.clone())
                } else {
                    None
                }
            }
            None => None,
        }
    }

    /// Fetches a path, coalescing concurrent requests for the same source and destination.
    ///
    /// The first caller for a (src, dst) pair inserts a `Shared` future into
    /// `inflight`; later callers clone and await the same future, so the
    /// fetcher is invoked at most once per pair at a time. Both the success
    /// and the error value are cloned out to every awaiter.
    async fn fetch_and_cache_path(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> Result<Path<Bytes>, PathWaitError> {
        let fut = match self.inflight.entry_sync((src, dst)) {
            Entry::Occupied(entry) => entry.get().clone(),
            Entry::Vacant(entry) => {
                let self_c = self.clone();
                entry
                    .insert_entry(
                        async move {
                            // The future removes its own inflight entry once
                            // resolved, so a later request starts a fresh
                            // fetch instead of reusing a completed one.
                            let result = self_c.do_fetch_and_cache(src, dst, now).await;
                            self_c.inflight.remove_sync(&(src, dst));
                            result
                        }
                        .boxed()
                        .shared(),
                    )
                    .clone()
            }
        };

        fut.await
    }

    /// Helper to do the actual fetching and caching of paths between source and destination.
    ///
    /// Applies the policy in two steps: `predicate` filters out unacceptable
    /// paths, then `rank` sorts ascending so the first (lowest-ranked) path
    /// is the preferred one.
    async fn do_fetch_and_cache(
        &self,
        src: IsdAsn,
        dst: IsdAsn,
        now: DateTime<Utc>,
    ) -> Result<Path<Bytes>, PathWaitError> {
        let mut paths = self
            .fetcher
            .fetch_paths(src, dst)
            .await
            .map_err(|e| PathWaitError::FetchFailed(e.to_string()))?;

        // Keep only paths that the policy accepts (not from registration).
        paths.retain(|path| {
            self.policy
                .predicate(&super::policy::PolicyPath::new(path, false))
        });
        // Lower rank is preferred; sort ascending and take the head.
        paths.sort_by(|a, b| {
            self.policy.rank(
                &super::policy::PolicyPath::new(a, false),
                &super::policy::PolicyPath::new(b, false),
            )
        });
        let preferred_path = paths.into_iter().next().ok_or(PathWaitError::NoPathFound)?;
        let cached_path = CachedPath::new(preferred_path.clone(), now, false);

        // Unconditionally overwrite any existing entry: a freshly fetched
        // path replaces whatever (possibly expired) entry is present.
        let entry = self.path_cache.entry_sync((src, dst));
        match entry {
            Entry::Occupied(mut entry) => {
                entry.update(cached_path);
            }
            Entry::Vacant(entry) => {
                entry.insert_entry(cached_path);
            }
        }

        debug!(src = %src, dst = %dst, "Cached new path");
        Ok(preferred_path)
    }

    /// Check if there is an in-flight request for the given source and destination.
    fn request_inflight(&self, src: IsdAsn, dst: IsdAsn) -> bool {
        let guard = Guard::new();
        self.inflight.peek(&(src, dst), &guard).is_some()
    }
}
493
/// Background task that handles prefetch requests and path registrations.
/// Owns the receiving ends of the channels held by [CachingPathManager].
struct PathManagerTask<P: PathPolicy, F: PathFetcher> {
    state: CachingPathManagerState<P, F>,
    prefetch_rx: mpsc::Receiver<PrefetchRequest>,
    registration_rx: mpsc::Receiver<PathRegistration>,
    // Cancelled by the manager's Drop impl to stop the task.
    cancellation_token: CancellationToken,
}
501
502impl<P: PathPolicy + Send + Sync + 'static, F: PathFetcher + Send + Sync + 'static>
503    PathManagerTask<P, F>
504{
505    fn new(
506        state: CachingPathManagerState<P, F>,
507        prefetch_rx: mpsc::Receiver<PrefetchRequest>,
508        registration_rx: mpsc::Receiver<PathRegistration>,
509        cancellation_token: CancellationToken,
510    ) -> Self {
511        Self {
512            state,
513            prefetch_rx,
514            registration_rx,
515            cancellation_token,
516        }
517    }
518
519    async fn run(mut self) {
520        trace!("Starting active path manager task");
521
522        loop {
523            tokio::select! {
524                // Handle cancellation with highest priority
525                _ = self.cancellation_token.cancelled() => {
526                    debug!("Path manager task cancelled");
527                    break;
528                }
529
530                // Handle path registrations (higher priority than prefetch)
531                registration = self.registration_rx.recv() => {
532                    match registration {
533                        Some(reg) => {
534                            self.handle_registration(reg).await;
535                        }
536                        None => {
537                            debug!("Registration channel closed");
538                            break;
539                        }
540                    }
541                }
542
543                // Handle prefetch requests
544                prefetch = self.prefetch_rx.recv() => {
545                    match prefetch {
546                        Some(req) => {
547                            self.handle_prefetch(req).await;
548                        }
549                        None => {
550                            debug!("Prefetch channel closed");
551                            break;
552                        }
553                    }
554                }
555            }
556        }
557
558        trace!("Path manager task finished");
559    }
560
561    async fn handle_registration(&self, registration: PathRegistration) {
562        trace!(
563            src = %registration.src,
564            dst = %registration.dst,
565            "Handling path registration"
566        );
567
568        // Check if the path is accepted by the policy
569        let policy_path = super::policy::PolicyPath::new(&registration.path, true);
570        if !self.state.policy.predicate(&policy_path) {
571            debug!(
572                src = %registration.src,
573                dst = %registration.dst,
574                "Registered path rejected by policy"
575            );
576            return;
577        }
578        let entry = self
579            .state
580            .path_cache
581            .entry_sync((registration.src, registration.dst));
582        match entry {
583            Entry::Occupied(mut entry) => {
584                let cached = entry.get();
585                let should_update = if cached.is_expired(registration.now) {
586                    true
587                } else {
588                    let existing_policy_path =
589                        super::policy::PolicyPath::new(&cached.path, cached.from_registration);
590                    matches!(
591                        self.state.policy.rank(&policy_path, &existing_policy_path),
592                        Ordering::Less
593                    )
594                };
595                if should_update {
596                    entry.update(CachedPath::new(registration.path, registration.now, true));
597                }
598            }
599            Entry::Vacant(entry) => {
600                entry.insert_entry(CachedPath::new(registration.path, registration.now, true));
601            }
602        }
603    }
604
605    /// Handle a prefetch request by checking the cache and fetching the path if necessary.
606    /// If the path is already cached or there is an in-flight request, it skips fetching.
607    /// Otherwise, it fetches the path and caches it.
608    async fn handle_prefetch(&self, request: PrefetchRequest) {
609        debug!(
610            src = %request.src,
611            dst = %request.dst,
612            "Handling prefetch request"
613        );
614
615        // Check if we already have a valid cached path
616        if self
617            .state
618            .cached_path_wait(request.src, request.dst, request.now)
619            .await
620            .is_some()
621        {
622            debug!("Path already cached, skipping prefetch");
623            return;
624        }
625
626        // Check if there is an in-flight request for the same source and destination
627        if self.state.request_inflight(request.src, request.dst) {
628            debug!(
629                src = %request.src,
630                dst = %request.dst,
631                "Path request already in flight, skipping prefetch"
632            );
633            return;
634        }
635
636        // Perform the actual fetching and caching of the path. It might be that in the mean time
637        // another request for the same path has been made, but in that case the path will be cached
638        // by the other request or the prefetch will be coalesced with it.
639        match self
640            .state
641            .fetch_and_cache_path(request.src, request.dst, request.now)
642            .await
643        {
644            Ok(_) => {
645                debug!(
646                    src = %request.src,
647                    dst = %request.dst,
648                    "Successfully prefetched path"
649                );
650            }
651            Err(e) => {
652                error!(
653                    src = %request.src,
654                    dst = %request.dst,
655                    error = %e,
656                    "Failed to prefetch path"
657                );
658            }
659        }
660    }
661}
662
/// Segment fetch error. Boxed trait object so implementors can surface any
/// error type; `Send + Sync` keeps it usable across task boundaries.
pub type SegmentFetchError = Box<dyn std::error::Error + Send + Sync>;
665
/// Path segments, grouped as expected by the path combinator.
pub struct Segments {
    /// Core segments.
    pub core_segments: Vec<path::PathSegment>,
    /// Non-core segments (up and down segments).
    pub non_core_segments: Vec<path::PathSegment>,
}
673
/// Segment fetcher trait.
pub trait SegmentFetcher {
    /// Fetch path segments between src and dst.
    ///
    /// The explicit `Send` bound on the returned future keeps implementors
    /// usable from spawned (multi-threaded) tasks.
    fn fetch_segments<'a>(
        &'a self,
        src: IsdAsn,
        dst: IsdAsn,
    ) -> impl Future<Output = Result<Segments, SegmentFetchError>> + Send + 'a;
}
683
/// Connect RPC segment fetcher backed by an [EndhostApiClient].
pub struct ConnectRpcSegmentFetcher {
    client: Arc<dyn EndhostApiClient>,
}
688
impl ConnectRpcSegmentFetcher {
    /// Creates a new connect RPC segment fetcher wrapping the given client.
    pub fn new(client: Arc<dyn EndhostApiClient>) -> Self {
        Self { client }
    }
}
695
696impl SegmentFetcher for ConnectRpcSegmentFetcher {
697    async fn fetch_segments(
698        &self,
699        src: IsdAsn,
700        dst: IsdAsn,
701    ) -> Result<Segments, SegmentFetchError> {
702        let resp = self
703            .client
704            .list_segments(src, dst, 128, "".to_string())
705            .await?;
706        tracing::debug!(
707            n_core=resp.core_segments.len(),
708            n_up=resp.up_segments.len(),
709            n_down=resp.down_segments.len(),
710            src = %src,
711            dst = %dst,
712            "Received segments from control plane"
713        );
714        let (core_segments, non_core_segments) = resp.split_parts();
715        Ok(Segments {
716            core_segments,
717            non_core_segments,
718        })
719    }
720}
721
/// Path fetcher that builds end-to-end paths from fetched segments.
pub struct PathFetcherImpl<F: SegmentFetcher = ConnectRpcSegmentFetcher> {
    segment_fetcher: F,
}
726
impl<F: SegmentFetcher> PathFetcherImpl<F> {
    /// Creates a new path fetcher wrapping the given segment fetcher.
    pub fn new(segment_fetcher: F) -> Self {
        Self { segment_fetcher }
    }
}
733
734impl<L: SegmentFetcher + Send + Sync> PathFetcher for PathFetcherImpl<L> {
735    async fn fetch_paths(
736        &self,
737        src: IsdAsn,
738        dst: IsdAsn,
739    ) -> Result<Vec<path::Path>, PathFetchError> {
740        let Segments {
741            core_segments,
742            non_core_segments,
743        } = self.segment_fetcher.fetch_segments(src, dst).await?;
744        trace!(
745            n_core_segments = core_segments.len(),
746            n_non_core_segments = non_core_segments.len(),
747            src = %src,
748            dst = %dst,
749            "Fetched segments"
750        );
751        let paths = path::combinator::combine(src, dst, core_segments, non_core_segments);
752        Ok(paths)
753    }
754}
755
756#[cfg(test)]
757mod tests {
758    use std::{
759        collections::HashMap,
760        sync::{
761            Arc, Mutex,
762            atomic::{AtomicUsize, Ordering},
763        },
764    };
765
766    use bytes::{BufMut, BytesMut};
767    use scion_proto::{
768        address::IsdAsn,
769        packet::ByEndpoint,
770        path::{self, DataPlanePath, EncodedStandardPath, Path},
771        wire_encoding::WireDecode,
772    };
773    use tokio::{sync::Barrier, task::yield_now};
774
775    use super::*;
776    use crate::path::policy;
777
    /// Canned fetch results keyed by (src, dst).
    type PathMap = HashMap<(IsdAsn, IsdAsn), Result<Vec<Path>, PathFetchError>>;
    /// Test double for [PathFetcher]: configurable results, an invocation
    /// counter, and optional delay/barrier hooks for concurrency tests.
    #[derive(Default)]
    struct MockPathFetcher {
        paths: Mutex<PathMap>,
        // Number of fetch_paths invocations.
        call_count: AtomicUsize,
        // When set, fetch_paths spins (yielding) until call_count reaches this value.
        call_delay: Option<usize>,
        // When set, fetch_paths waits on this barrier before returning.
        barrier: Option<Arc<Barrier>>,
    }
786
    impl MockPathFetcher {
        /// Fetcher that returns exactly `path` for (src, dst).
        fn with_path(src: IsdAsn, dst: IsdAsn, path: Path) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Ok(vec![path]));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Fetcher that fails for (src, dst).
        /// NOTE(review): fetch_paths discards the stored message and reports
        /// a generic "other error" instead of `error` — confirm intended.
        fn with_error(src: IsdAsn, dst: IsdAsn, error: &'static str) -> Self {
            let mut paths = HashMap::new();
            paths.insert((src, dst), Err(PathFetchError::FetchSegments(error.into())));
            Self {
                paths: Mutex::new(paths),
                call_count: AtomicUsize::new(0),
                call_delay: None,
                barrier: None,
            }
        }

        /// Builder-style helper to attach a barrier to the fetch path.
        fn with_barrier(mut self, barrier: Arc<Barrier>) -> Self {
            self.barrier = Some(barrier);
            self
        }
    }
815
    impl PathFetcher for MockPathFetcher {
        fn fetch_paths(
            &self,
            src: IsdAsn,
            dst: IsdAsn,
        ) -> impl ResFut<'_, Vec<path::Path>, PathFetchError> {
            async move {
                self.call_count.fetch_add(1, Ordering::Relaxed);
                // Optional hooks used by concurrency tests to control when
                // the fetch is allowed to proceed/complete.
                if let Some(delay) = self.call_delay {
                    while self.call_count.load(Ordering::SeqCst) < delay {
                        yield_now().await;
                    }
                }
                if let Some(barrier) = &self.barrier {
                    barrier.wait().await;
                }
                // Unknown pairs yield an empty path list (-> NoPathFound).
                match self.paths.lock().unwrap().get(&(src, dst)) {
                    Some(Ok(paths)) => Ok(paths.clone()),
                    None => Ok(vec![]),
                    Some(Err(_)) => Err(PathFetchError::FetchSegments("other error".into())),
                }
            }
        }
    }
840
    /// Builds a minimal decodable standard path between `src` and `dst`.
    fn test_path(src: IsdAsn, dst: IsdAsn) -> Path {
        let mut path_raw = BytesMut::with_capacity(36);
        // Minimal path-meta header word followed by zeroed info/hop bytes;
        // just enough for EncodedStandardPath::decode to succeed.
        path_raw.put_u32(0x0000_2000);
        path_raw.put_slice(&[0_u8; 32]);
        let dp_path =
            DataPlanePath::Standard(EncodedStandardPath::decode(&mut path_raw.freeze()).unwrap());

        Path::new(
            dp_path,
            ByEndpoint {
                source: src,
                destination: dst,
            },
            // No metadata: such a path has no expiry time.
            None,
        )
    }
857
    /// Builds a bare manager state (no background task) around `fetcher`,
    /// using the Shortest policy and empty caches.
    fn setup_pm(fetcher: MockPathFetcher) -> CachingPathManagerState<Shortest, MockPathFetcher> {
        CachingPathManagerState(Arc::new(CachingPathManagerStateInner {
            policy: policy::Shortest {},
            fetcher,
            path_cache: HashIndex::new(),
            inflight: HashIndex::new(),
        }))
    }
866
    #[tokio::test]
    async fn fetch_and_cache_path_single_request_success() {
        // A single successful fetch caches the path and clears its
        // in-flight entry.
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        let fetcher = MockPathFetcher::with_path(src, dst, path.clone());
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(result.is_ok());
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_some());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }
883
    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_coalesced() {
        // Two concurrent requests for the same (src, dst) must share a
        // single fetch; the barrier holds the first fetch open until the
        // second request has joined it.
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let path = test_path(src, dst);
        let barrier = Arc::new(Barrier::new(2));
        let fetcher =
            MockPathFetcher::with_path(src, dst, path.clone()).with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone = state.clone();
        let task1 =
            tokio::spawn(
                async move { state_clone.fetch_and_cache_path(src, dst, Utc::now()).await },
            );
        // Wait for the first task to start the fetch operation.
        while state.fetcher.call_count.load(Ordering::SeqCst) < 1 {
            yield_now().await;
        }

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src, dst, Utc::now())
                .await
        });

        // Unblock the fetcher
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        // Exactly one fetch happened, both callers got a result, and the
        // in-flight entry was cleaned up.
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        res1.unwrap().unwrap();
        res2.unwrap().unwrap();
        let guard = Guard::new();
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }
922
    #[tokio::test]
    async fn fetch_and_cache_path_fetch_error() {
        // A failing fetch surfaces as FetchFailed and leaves neither a
        // cache entry nor an in-flight entry behind.
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::with_error(src, dst, "error");
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::FetchFailed(_))));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
        let guard = Guard::new();
        assert!(state.path_cache.peek(&(src, dst), &guard).is_none());
        assert!(state.inflight.peek(&(src, dst), &guard).is_none());
    }
938
    #[tokio::test]
    async fn fetch_and_cache_path_no_path_found() {
        // An empty fetch result maps to NoPathFound.
        let src = IsdAsn(0x1_ff00_0000_0110);
        let dst = IsdAsn(0x1_ff00_0000_0111);
        let fetcher = MockPathFetcher::default();
        let state = setup_pm(fetcher);

        let result = state.fetch_and_cache_path(src, dst, Utc::now()).await;

        assert!(matches!(result, Err(PathWaitError::NoPathFound)));
        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 1);
    }
951
    #[tokio::test]
    async fn fetch_and_cache_path_concurrent_requests_different_keys() {
        // Requests for different (src, dst) pairs are NOT coalesced: each
        // pair triggers its own fetch and receives its own path.
        let src1 = IsdAsn(0x1_ff00_0000_0110);
        let dst1 = IsdAsn(0x1_ff00_0000_0111);
        let src2 = IsdAsn(0x1_ff00_0000_0120);
        let dst2 = IsdAsn(0x1_ff00_0000_0121);
        let path1 = test_path(src1, dst1);
        let path2 = test_path(src2, dst2);

        let mut paths = HashMap::new();
        paths.insert((src1, dst1), Ok(vec![path1.clone()]));
        paths.insert((src2, dst2), Ok(vec![path2.clone()]));

        // Three parties: the two fetches plus the test body.
        let barrier = Arc::new(Barrier::new(3));

        let fetcher = MockPathFetcher {
            paths: Mutex::new(paths),
            ..Default::default()
        }
        .with_barrier(barrier.clone());
        let state = setup_pm(fetcher);

        let state_clone1 = state.clone();
        let task1 = tokio::spawn(async move {
            state_clone1
                .fetch_and_cache_path(src1, dst1, Utc::now())
                .await
        });

        let state_clone2 = state.clone();
        let task2 = tokio::spawn(async move {
            state_clone2
                .fetch_and_cache_path(src2, dst2, Utc::now())
                .await
        });

        // Unblock the fetcher
        barrier.wait().await;

        let (res1, res2) = future::join(task1, task2).await;

        assert_eq!(state.fetcher.call_count.load(Ordering::SeqCst), 2);
        let got1 = res1.unwrap().unwrap();
        let got2 = res2.unwrap().unwrap();
        assert_eq!(got1.source(), path1.source());
        assert_eq!(got1.destination(), path1.destination());
        assert_eq!(got2.source(), path2.source());
        assert_eq!(got2.destination(), path2.destination());
    }
1001}