use std::collections::HashMap;
use std::future::Future;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard};
use tokio::time::sleep;
use tracing::{debug, info, warn};

use super::bbox::BoundingBox;
use super::cache::{
    cache, in_flight_loads, record_disk_hit, record_in_flight_wait, record_load_request,
    record_memory_hit, record_network_fetch, CachedEdge, CachedNetwork, CachedNode, NetworkRef,
    CACHE_VERSION,
};
use super::config::{ConnectivityPolicy, NetworkConfig};
use super::coord::Coord;
use super::error::RoutingError;
use super::network::{EdgeData, RoadNetwork};
use super::osm::OverpassResponse;
use super::progress::RoutingProgress;
25
26impl RoadNetwork {
27 pub async fn load_or_fetch(
28 bbox: &BoundingBox,
29 config: &NetworkConfig,
30 progress: Option<&Sender<RoutingProgress>>,
31 ) -> Result<NetworkRef, RoutingError> {
32 let cache_key = bbox.cache_key();
33 record_load_request();
34
35 if let Some(tx) = progress {
36 let _ = tx.send(RoutingProgress::CheckingCache { percent: 0 }).await;
37 }
38
39 {
40 let cache_guard = cache().read().await;
41 if cache_guard.contains_key(&cache_key) {
42 record_memory_hit();
43 info!("Using in-memory cached road network for {}", cache_key);
44 if let Some(tx) = progress {
45 let _ = tx
46 .send(RoutingProgress::CheckingCache { percent: 10 })
47 .await;
48 }
49 return Ok(NetworkRef::new(cache_guard, cache_key));
50 }
51 }
52 if let Some(tx) = progress {
53 let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await;
54 }
55
56 Self::load_or_insert(cache_key, async {
57 tokio::fs::create_dir_all(&config.cache_dir).await?;
58 let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key()));
59
60 if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
61 info!("Loading road network from file cache: {:?}", cache_path);
62 if let Some(tx) = progress {
63 let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
64 }
65 match Self::load_from_file(&cache_path, config).await {
66 Ok(network) => {
67 record_disk_hit();
68 if let Some(tx) = progress {
69 let _ = tx
70 .send(RoutingProgress::BuildingGraph { percent: 50 })
71 .await;
72 }
73 return Ok(network);
74 }
75 Err(e) => info!("File cache invalid ({}), downloading fresh", e),
76 }
77 } else {
78 info!("Downloading road network from Overpass API");
79 }
80
81 record_network_fetch();
82 let network = Self::fetch_from_api(bbox, config, progress).await?;
83 network.save_to_file(&cache_path).await?;
84 info!("Saved road network to file cache: {:?}", cache_path);
85 Ok(network)
86 })
87 .await
88 }
89
90 pub async fn fetch(
91 bbox: &BoundingBox,
92 config: &NetworkConfig,
93 progress: Option<&Sender<RoutingProgress>>,
94 ) -> Result<Self, RoutingError> {
95 Self::fetch_from_api(bbox, config, progress).await
96 }
97
98 async fn fetch_from_api(
99 bbox: &BoundingBox,
100 config: &NetworkConfig,
101 progress: Option<&Sender<RoutingProgress>>,
102 ) -> Result<Self, RoutingError> {
103 let highway_regex = config.highway_regex();
104 let query = format!(
105 r#"[out:json][timeout:120];
106(
107 way["highway"~"{}"]
108 ({},{},{},{});
109);
110(._;>;);
111out body;"#,
112 highway_regex, bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
113 );
114
115 debug!("Overpass query:\n{}", query);
116 info!(
117 "Preparing Overpass query for bbox: {:.4},{:.4} to {:.4},{:.4}",
118 bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
119 );
120
121 if let Some(tx) = progress {
122 let _ = tx
123 .send(RoutingProgress::DownloadingNetwork {
124 percent: 10,
125 bytes: 0,
126 })
127 .await;
128 }
129
130 let client = reqwest::Client::builder()
131 .connect_timeout(config.connect_timeout)
132 .read_timeout(config.read_timeout)
133 .timeout(config.read_timeout)
134 .user_agent("SolverForge/0.5.0")
135 .build()
136 .map_err(|e| RoutingError::Network(e.to_string()))?;
137
138 if let Some(tx) = progress {
139 let _ = tx
140 .send(RoutingProgress::DownloadingNetwork {
141 percent: 15,
142 bytes: 0,
143 })
144 .await;
145 }
146
147 let bytes = fetch_overpass_bytes(&client, &query, config, progress).await?;
148
149 let bytes_len = bytes.len();
150 if let Some(tx) = progress {
151 let _ = tx
152 .send(RoutingProgress::DownloadingNetwork {
153 percent: 30,
154 bytes: bytes_len,
155 })
156 .await;
157 }
158
159 if let Some(tx) = progress {
160 let _ = tx
161 .send(RoutingProgress::ParsingOsm {
162 percent: 32,
163 nodes: 0,
164 edges: 0,
165 })
166 .await;
167 }
168
169 let osm_data: OverpassResponse =
170 serde_json::from_slice(&bytes).map_err(|e| RoutingError::Parse(e.to_string()))?;
171
172 info!("Downloaded {} OSM elements", osm_data.elements.len());
173
174 if let Some(tx) = progress {
175 let _ = tx
176 .send(RoutingProgress::ParsingOsm {
177 percent: 35,
178 nodes: osm_data.elements.len(),
179 edges: 0,
180 })
181 .await;
182 }
183
184 if let Some(tx) = progress {
185 let _ = tx
186 .send(RoutingProgress::BuildingGraph { percent: 40 })
187 .await;
188 }
189
190 let network = Self::build_from_osm(&osm_data, config)?;
191
192 if let Some(tx) = progress {
193 let _ = tx
194 .send(RoutingProgress::BuildingGraph { percent: 50 })
195 .await;
196 let _ = tx.send(RoutingProgress::Complete).await;
197 }
198
199 Ok(network)
200 }
201
202 pub(super) fn build_from_osm(
203 osm: &OverpassResponse,
204 config: &NetworkConfig,
205 ) -> Result<Self, RoutingError> {
206 let mut network = Self::new();
207
208 let mut nodes: HashMap<i64, (f64, f64)> = HashMap::new();
209 for elem in &osm.elements {
210 if elem.elem_type == "node" {
211 if let (Some(lat), Some(lon)) = (elem.lat, elem.lon) {
212 nodes.insert(elem.id, (lat, lon));
213 }
214 }
215 }
216
217 info!("Parsed {} nodes", nodes.len());
218
219 let mut way_count = 0;
220 for elem in &osm.elements {
221 if elem.elem_type == "way" {
222 if let Some(ref node_ids) = elem.nodes {
223 let highway = elem.tags.as_ref().and_then(|t| t.highway.as_deref());
224 let oneway = elem.tags.as_ref().and_then(|t| t.oneway.as_deref());
225 let maxspeed = elem.tags.as_ref().and_then(|t| t.maxspeed.as_deref());
226 let speed = config
227 .speed_profile
228 .speed_mps(maxspeed, highway.unwrap_or("residential"));
229 let is_oneway_forward = matches!(oneway, Some("yes") | Some("1"));
230 let is_oneway_reverse = matches!(oneway, Some("-1"));
231
232 for window in node_ids.windows(2) {
233 let n1_id = window[0];
234 let n2_id = window[1];
235
236 let Some(&(lat1, lng1)) = nodes.get(&n1_id) else {
237 continue;
238 };
239 let Some(&(lat2, lng2)) = nodes.get(&n2_id) else {
240 continue;
241 };
242
243 let idx1 = network.get_or_create_node(lat1, lng1);
244 let idx2 = network.get_or_create_node(lat2, lng2);
245
246 let coord1 = Coord::new(lat1, lng1);
247 let coord2 = Coord::new(lat2, lng2);
248 let distance = super::geo::haversine_distance(coord1, coord2);
249 let travel_time = distance / speed;
250
251 let edge_data = EdgeData {
252 travel_time_s: travel_time,
253 distance_m: distance,
254 };
255
256 if is_oneway_reverse {
257 network.add_edge(idx2, idx1, edge_data);
259 } else {
260 network.add_edge(idx1, idx2, edge_data.clone());
262 if !is_oneway_forward {
263 network.add_edge(idx2, idx1, edge_data);
265 }
266 }
267 }
268
269 way_count += 1;
270 }
271 }
272 }
273
274 info!(
275 "Built graph with {} nodes and {} edges from {} ways",
276 network.node_count(),
277 network.edge_count(),
278 way_count
279 );
280
281 let scc_count = network.strongly_connected_components();
282 match config.connectivity_policy {
283 ConnectivityPolicy::KeepAll => {
284 if scc_count > 1 {
285 info!(
286 "Road network has {} SCCs, preserving all components by configuration",
287 scc_count
288 );
289 }
290 }
291 ConnectivityPolicy::LargestStronglyConnectedComponent => {
292 if scc_count > 1 {
293 info!(
294 "Road network has {} SCCs, filtering to largest component",
295 scc_count
296 );
297 network.filter_to_largest_scc();
298 info!(
299 "After SCC filter: {} nodes, {} edges",
300 network.node_count(),
301 network.edge_count()
302 );
303 }
304 }
305 }
306
307 network.build_spatial_index();
308
309 Ok(network)
310 }
311
312 async fn load_from_file(path: &Path, config: &NetworkConfig) -> Result<Self, RoutingError> {
313 let data = tokio::fs::read_to_string(path).await?;
314
315 let cached: CachedNetwork = match serde_json::from_str(&data) {
316 Ok(c) => c,
317 Err(e) => {
318 info!("Cache file corrupted, will re-download: {}", e);
319 let _ = tokio::fs::remove_file(path).await;
320 return Err(RoutingError::Parse(e.to_string()));
321 }
322 };
323
324 if cached.version != CACHE_VERSION {
325 info!(
326 "Cache version mismatch (got {}, need {}), will re-download",
327 cached.version, CACHE_VERSION
328 );
329 let _ = tokio::fs::remove_file(path).await;
330 return Err(RoutingError::Parse("cache version mismatch".into()));
331 }
332
333 let mut network = Self::new();
334
335 for node in &cached.nodes {
336 network.add_node_at(node.lat, node.lng);
337 }
338
339 for edge in &cached.edges {
340 network.add_edge_by_index(edge.from, edge.to, edge.travel_time_s, edge.distance_m);
341 }
342
343 let scc_count = network.strongly_connected_components();
344 match config.connectivity_policy {
345 ConnectivityPolicy::KeepAll => {
346 if scc_count > 1 {
347 info!(
348 "Cached network has {} SCCs, preserving all components by configuration",
349 scc_count
350 );
351 }
352 }
353 ConnectivityPolicy::LargestStronglyConnectedComponent => {
354 if scc_count > 1 {
355 info!(
356 "Cached network has {} SCCs, filtering to largest component",
357 scc_count
358 );
359 network.filter_to_largest_scc();
360 info!(
361 "After SCC filter: {} nodes, {} edges",
362 network.node_count(),
363 network.edge_count()
364 );
365 }
366 }
367 }
368
369 network.build_spatial_index();
370
371 Ok(network)
372 }
373
374 async fn save_to_file(&self, path: &Path) -> Result<(), RoutingError> {
375 let nodes: Vec<CachedNode> = self
376 .nodes_iter()
377 .map(|(lat, lng)| CachedNode { lat, lng })
378 .collect();
379
380 let edges: Vec<CachedEdge> = self
381 .edges_iter()
382 .map(|(from, to, travel_time_s, distance_m)| CachedEdge {
383 from,
384 to,
385 travel_time_s,
386 distance_m,
387 })
388 .collect();
389
390 let cached = CachedNetwork {
391 version: CACHE_VERSION,
392 nodes,
393 edges,
394 };
395 let data =
396 serde_json::to_string(&cached).map_err(|e| RoutingError::Parse(e.to_string()))?;
397 tokio::fs::write(path, data).await?;
398
399 Ok(())
400 }
401
402 async fn load_or_insert<F>(cache_key: String, load: F) -> Result<NetworkRef, RoutingError>
403 where
404 F: Future<Output = Result<RoadNetwork, RoutingError>>,
405 {
406 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
407 record_memory_hit();
408 return Ok(cached);
409 }
410
411 let (slot, _slot_guard, waited) = acquire_in_flight_slot(&cache_key).await;
412 if waited {
413 record_in_flight_wait();
414 }
415
416 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
417 record_memory_hit();
418 cleanup_in_flight_slot(&cache_key, &slot).await;
419 return Ok(cached);
420 }
421
422 let network = load.await?;
423
424 {
425 let mut cache_guard = cache().write().await;
426 cache_guard.entry(cache_key.clone()).or_insert(network);
427 }
428
429 cleanup_in_flight_slot(&cache_key, &slot).await;
430
431 Self::get_cached_network(cache_key).await.ok_or_else(|| {
432 RoutingError::Network("cached network disappeared after insertion".to_string())
433 })
434 }
435
436 async fn get_cached_network(cache_key: String) -> Option<NetworkRef> {
437 let cache_guard = cache().read().await;
438 if cache_guard.contains_key(&cache_key) {
439 Some(NetworkRef::new(cache_guard, cache_key))
440 } else {
441 None
442 }
443 }
444}
445
446async fn fetch_overpass_bytes(
447 client: &reqwest::Client,
448 query: &str,
449 config: &NetworkConfig,
450 progress: Option<&Sender<RoutingProgress>>,
451) -> Result<Vec<u8>, RoutingError> {
452 let endpoints = overpass_endpoints(config);
453 let mut failures = Vec::new();
454
455 for (endpoint_index, endpoint) in endpoints.iter().enumerate() {
456 for attempt in 0..=config.overpass_max_retries {
457 info!(
458 "Sending request to Overpass API endpoint {} attempt {}: {}",
459 endpoint_index + 1,
460 attempt + 1,
461 endpoint
462 );
463
464 let response = client
465 .post(endpoint)
466 .body(query.to_owned())
467 .header("Content-Type", "text/plain")
468 .send()
469 .await;
470
471 match response {
472 Ok(response) if response.status().is_success() => {
473 info!(
474 "Received successful Overpass response from {} with status {}",
475 endpoint,
476 response.status()
477 );
478
479 if let Some(tx) = progress {
480 let _ = tx
481 .send(RoutingProgress::DownloadingNetwork {
482 percent: 25,
483 bytes: 0,
484 })
485 .await;
486 }
487
488 return response
489 .bytes()
490 .await
491 .map(|bytes| bytes.to_vec())
492 .map_err(|error| {
493 RoutingError::Network(format!(
494 "Overpass response body read failed from {} on attempt {}: {}",
495 endpoint,
496 attempt + 1,
497 error
498 ))
499 });
500 }
501 Ok(response) => {
502 let status = response.status();
503 failures.push(format!(
504 "{} attempt {} returned HTTP {}",
505 endpoint,
506 attempt + 1,
507 status
508 ));
509
510 if is_retryable_status(status) && attempt < config.overpass_max_retries {
511 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
512 continue;
513 }
514
515 break;
516 }
517 Err(error) => {
518 failures.push(format!(
519 "{} attempt {} failed: {}",
520 endpoint,
521 attempt + 1,
522 error
523 ));
524
525 if is_retryable_error(&error) && attempt < config.overpass_max_retries {
526 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
527 continue;
528 }
529
530 break;
531 }
532 }
533 }
534 }
535
536 Err(RoutingError::Network(format!(
537 "Overpass fetch failed after trying {} endpoint(s): {}",
538 endpoints.len(),
539 failures.join("; ")
540 )))
541}
542
543fn overpass_endpoints(config: &NetworkConfig) -> Vec<String> {
544 if config.overpass_endpoints.is_empty() {
545 vec![config.overpass_url.clone()]
546 } else {
547 config.overpass_endpoints.clone()
548 }
549}
550
/// Linear backoff before the retry that follows zero-based `attempt`: `base * (attempt + 1)`.
///
/// The multiplier saturates at `u32::MAX` instead of going through a truncating `as`
/// cast, so huge attempt counts can never wrap around to a zero (or tiny) delay. The
/// multiplication itself saturates at `Duration::MAX`.
fn retry_backoff(base: Duration, attempt: usize) -> Duration {
    let multiplier = u32::try_from(attempt.saturating_add(1)).unwrap_or(u32::MAX);
    base.saturating_mul(multiplier)
}
554
555fn is_retryable_status(status: reqwest::StatusCode) -> bool {
556 status.is_server_error()
557 || status == reqwest::StatusCode::TOO_MANY_REQUESTS
558 || status == reqwest::StatusCode::REQUEST_TIMEOUT
559}
560
561fn is_retryable_error(error: &reqwest::Error) -> bool {
562 error.is_timeout() || error.is_connect() || error.is_request()
563}
564
565async fn acquire_in_flight_slot(cache_key: &str) -> (Arc<Mutex<()>>, OwnedMutexGuard<()>, bool) {
566 let slot = {
567 let mut in_flight = in_flight_loads().lock().await;
568 match in_flight.get(cache_key) {
569 Some(slot) => slot.clone(),
570 None => {
571 let slot = Arc::new(Mutex::new(()));
572 in_flight.insert(cache_key.to_string(), slot.clone());
573 slot
574 }
575 }
576 };
577
578 match slot.clone().try_lock_owned() {
579 Ok(guard) => (slot, guard, false),
580 Err(_) => {
581 let guard = slot.clone().lock_owned().await;
582 (slot, guard, true)
583 }
584 }
585}
586
587async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc<Mutex<()>>) {
588 let mut in_flight = in_flight_loads().lock().await;
589 let should_remove = in_flight
590 .get(cache_key)
591 .map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2)
592 .unwrap_or(false);
593
594 if should_remove {
595 in_flight.remove(cache_key);
596 }
597}
598
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::sync::OnceLock;
    use std::thread;
    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

    use tokio::sync::Mutex;
    use tokio::time::sleep;

    use super::*;
    use crate::routing::cache::{reset_cache_metrics, CacheStats};
    use crate::routing::BoundingBox;

    /// Serialises tests that touch the global network cache, the in-flight registry,
    /// or the cache metrics counters.
    static FETCH_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    fn fetch_test_lock() -> &'static Mutex<()> {
        FETCH_TEST_LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Small two-node network used wherever the loaded graph's contents do not matter.
    fn test_network() -> RoadNetwork {
        RoadNetwork::from_test_data(&[(0.0, 0.0), (0.0, 0.01)], &[(0, 1, 60.0, 1_000.0)])
    }

    /// Clears the memory cache, the in-flight registry, and the cache metrics.
    async fn reset_test_state() {
        RoadNetwork::clear_cache().await;
        in_flight_loads().lock().await.clear();
        reset_cache_metrics();
    }

    /// Temp-dir path unique to this process and moment, so tests never share a file cache.
    fn unique_cache_dir(prefix: &str) -> std::path::PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time before unix epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "solverforge-maps-{prefix}-{}-{suffix}",
            std::process::id()
        ))
    }

    /// Compares the counter fields of the current cache stats with `expected`
    /// (size-related fields are ignored).
    async fn assert_cache_stats(expected: CacheStats) {
        let stats = RoadNetwork::cache_stats().await;
        assert_eq!(stats.networks_cached, expected.networks_cached);
        assert_eq!(stats.load_requests, expected.load_requests);
        assert_eq!(stats.memory_hits, expected.memory_hits);
        assert_eq!(stats.disk_hits, expected.disk_hits);
        assert_eq!(stats.network_fetches, expected.network_fetches);
        assert_eq!(stats.in_flight_waits, expected.in_flight_waits);
    }

    #[tokio::test]
    async fn load_or_insert_allows_different_keys_to_progress_concurrently() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let start = Instant::now();
        let first = async {
            RoadNetwork::load_or_insert("region-a".to_string(), async {
                sleep(Duration::from_millis(100)).await;
                Ok(test_network())
            })
            .await
            .map(|network| network.node_count())
        };
        let second = async {
            RoadNetwork::load_or_insert("region-b".to_string(), async {
                sleep(Duration::from_millis(100)).await;
                Ok(test_network())
            })
            .await
            .map(|network| network.node_count())
        };
        let (left, right) = tokio::join!(first, second);
        left.expect("first load should succeed");
        right.expect("second load should succeed");

        // Two serialized 100 ms loads would take at least 200 ms.
        assert!(
            start.elapsed() < Duration::from_millis(180),
            "different keys should not serialize slow loads"
        );
    }

    #[tokio::test]
    async fn load_or_insert_deduplicates_same_key_work() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let loads = Arc::new(AtomicUsize::new(0));

        let first = {
            let loads = loads.clone();
            async move {
                RoadNetwork::load_or_insert("region-a".to_string(), async move {
                    loads.fetch_add(1, Ordering::Relaxed);
                    sleep(Duration::from_millis(50)).await;
                    Ok(test_network())
                })
                .await
                .map(|network| network.node_count())
            }
        };
        let second = {
            let loads = loads.clone();
            async move {
                RoadNetwork::load_or_insert("region-a".to_string(), async move {
                    loads.fetch_add(1, Ordering::Relaxed);
                    sleep(Duration::from_millis(50)).await;
                    Ok(test_network())
                })
                .await
                .map(|network| network.node_count())
            }
        };

        let (left, right) = tokio::join!(first, second);
        left.expect("first load should succeed");
        right.expect("second load should succeed");

        assert_eq!(loads.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn load_or_fetch_records_network_then_memory_hit() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let bbox = BoundingBox::new(39.95, -75.17, 39.96, -75.16);
        let cache_dir = unique_cache_dir("network-memory");
        let (endpoint, requests, handle) =
            spawn_overpass_server(vec![("200 OK", overpass_fixture_json())]);
        let config = NetworkConfig::new()
            .overpass_endpoints(vec![endpoint])
            .cache_dir(&cache_dir)
            .overpass_max_retries(0);

        let first = RoadNetwork::load_or_fetch(&bbox, &config, None).await;
        assert!(first.is_ok(), "first load should succeed");
        let second = RoadNetwork::load_or_fetch(&bbox, &config, None).await;
        assert!(second.is_ok(), "second load should hit memory cache");

        handle.join().expect("server thread should finish");
        assert_eq!(requests.load(Ordering::Relaxed), 1);

        assert_cache_stats(CacheStats {
            networks_cached: 1,
            total_nodes: 0,
            total_edges: 0,
            memory_bytes: 0,
            load_requests: 2,
            memory_hits: 1,
            disk_hits: 0,
            network_fetches: 1,
            in_flight_waits: 0,
        })
        .await;

        let _ = tokio::fs::remove_dir_all(&cache_dir).await;
    }

    #[tokio::test]
    async fn load_or_fetch_records_disk_hit_without_network_fetch() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let bbox = BoundingBox::new(39.95, -75.17, 39.96, -75.16);
        let cache_dir = unique_cache_dir("disk-hit");
        tokio::fs::create_dir_all(&cache_dir)
            .await
            .expect("cache dir should be created");
        let cache_path = cache_dir.join(format!("{}.json", bbox.cache_key()));
        let cached = CachedNetwork {
            version: CACHE_VERSION,
            nodes: vec![
                CachedNode {
                    lat: 39.95,
                    lng: -75.16,
                },
                CachedNode {
                    lat: 39.96,
                    lng: -75.17,
                },
            ],
            edges: vec![CachedEdge {
                from: 0,
                to: 1,
                travel_time_s: 60.0,
                distance_m: 1_000.0,
            }],
        };
        let data = serde_json::to_string(&cached).expect("cached network should serialize");
        tokio::fs::write(&cache_path, data)
            .await
            .expect("cache file should be written");

        let config = NetworkConfig::new().cache_dir(&cache_dir);
        let network = RoadNetwork::load_or_fetch(&bbox, &config, None).await;
        assert!(network.is_ok(), "disk cache load should succeed");

        assert_cache_stats(CacheStats {
            networks_cached: 1,
            total_nodes: 0,
            total_edges: 0,
            memory_bytes: 0,
            load_requests: 1,
            memory_hits: 0,
            disk_hits: 1,
            network_fetches: 0,
            in_flight_waits: 0,
        })
        .await;

        let _ = tokio::fs::remove_dir_all(&cache_dir).await;
    }

    #[tokio::test]
    async fn load_or_fetch_records_waiter_for_same_key_contention() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let bbox = BoundingBox::new(39.95, -75.17, 39.96, -75.16);
        let cache_dir = unique_cache_dir("waiter");
        let (endpoint, requests, handle) =
            spawn_overpass_server(vec![("200 OK", overpass_fixture_json())]);
        let config = NetworkConfig::new()
            .overpass_endpoints(vec![endpoint])
            .cache_dir(&cache_dir)
            .overpass_max_retries(0);

        let first = RoadNetwork::load_or_fetch(&bbox, &config, None);
        let second = RoadNetwork::load_or_fetch(&bbox, &config, None);
        let (left, right) = tokio::join!(first, second);
        assert!(left.is_ok(), "first concurrent load should succeed");
        assert!(right.is_ok(), "second concurrent load should succeed");

        handle.join().expect("server thread should finish");
        assert_eq!(requests.load(Ordering::Relaxed), 1);

        assert_cache_stats(CacheStats {
            networks_cached: 1,
            total_nodes: 0,
            total_edges: 0,
            memory_bytes: 0,
            load_requests: 2,
            memory_hits: 1,
            disk_hits: 0,
            network_fetches: 1,
            in_flight_waits: 1,
        })
        .await;

        let _ = tokio::fs::remove_dir_all(&cache_dir).await;
    }

    #[tokio::test]
    async fn acquire_in_flight_slot_does_not_count_existing_unlocked_slot_as_wait() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let key = "burst-window";
        let slot = Arc::new(Mutex::new(()));
        in_flight_loads()
            .lock()
            .await
            .insert(key.to_string(), slot.clone());

        let (_slot, acquired_guard, waited) = acquire_in_flight_slot(key).await;
        assert!(
            !waited,
            "existing slot without lock contention should not count as a wait"
        );
        drop(acquired_guard);

        cleanup_in_flight_slot(key, &slot).await;
    }

    #[tokio::test]
    async fn acquire_in_flight_slot_reports_wait_when_lock_is_held() {
        let _guard = fetch_test_lock().lock().await;
        reset_test_state().await;

        let key = "held-slot";
        let slot = Arc::new(Mutex::new(()));
        let held_guard = slot.clone().lock_owned().await;
        in_flight_loads()
            .lock()
            .await
            .insert(key.to_string(), slot.clone());

        let waiter = tokio::spawn(async move {
            let (_slot, guard, waited) = acquire_in_flight_slot(key).await;
            (guard, waited)
        });

        // Let the waiter reach the contended lock before releasing it.
        tokio::task::yield_now().await;
        drop(held_guard);

        let (acquired_guard, waited) = waiter
            .await
            .expect("waiter task should complete after lock release");
        assert!(waited, "blocked acquisition should count as a wait");
        drop(acquired_guard);

        cleanup_in_flight_slot(key, &slot).await;
    }

    /// Two nodes joined by one residential way.
    fn overpass_fixture_json() -> &'static str {
        r#"{
            "elements": [
                {"type": "node", "id": 1, "lat": 39.95, "lon": -75.16},
                {"type": "node", "id": 2, "lat": 39.96, "lon": -75.17},
                {"type": "way", "id": 10, "nodes": [1, 2], "tags": {"highway": "residential"}}
            ]
        }"#
    }

    /// Starts a blocking HTTP server on an ephemeral localhost port that answers one
    /// connection per entry in `responses`, in order, then exits.
    ///
    /// Returns the endpoint URL, a counter of responses served, and the server thread's
    /// join handle.
    fn spawn_overpass_server(
        responses: Vec<(&'static str, &'static str)>,
    ) -> (String, Arc<AtomicUsize>, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = format!(
            "http://{}/api/interpreter",
            listener.local_addr().expect("listener addr")
        );
        let requests = Arc::new(AtomicUsize::new(0));
        let served = requests.clone();

        let handle = thread::spawn(move || {
            for (status, body) in responses {
                let (mut stream, _) = listener.accept().expect("connection should arrive");
                read_http_request(&mut stream);
                let response = format!(
                    "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
                    status,
                    body.len(),
                    body
                );
                stream
                    .write_all(response.as_bytes())
                    .expect("response should write");
                served.fetch_add(1, Ordering::Relaxed);
            }
        });

        (address, requests, handle)
    }

    /// Consumes one complete HTTP request (headers plus `Content-Length` body) from
    /// `stream`.
    ///
    /// Replying after a single `read` can leave part of the request unread. Closing a
    /// socket with unread input makes the OS reset the connection, which the client may
    /// then report instead of the response. Reads give up after a short timeout so a
    /// misbehaving client cannot hang the server thread.
    fn read_http_request(stream: &mut TcpStream) {
        let _ = stream.set_read_timeout(Some(Duration::from_secs(5)));
        let mut request = Vec::new();
        let mut chunk = [0_u8; 4096];
        loop {
            if let Some(header_end) = find_header_end(&request) {
                let body_len = content_length(&request[..header_end]);
                if request.len() >= header_end + body_len {
                    return;
                }
            }
            match stream.read(&mut chunk) {
                Ok(0) | Err(_) => return,
                Ok(read) => request.extend_from_slice(&chunk[..read]),
            }
        }
    }

    /// Byte offset just past the blank line that ends the HTTP headers, if received yet.
    fn find_header_end(request: &[u8]) -> Option<usize> {
        request
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|start| start + 4)
    }

    /// Value of the `Content-Length` header in `headers`, or 0 when absent or unparseable.
    fn content_length(headers: &[u8]) -> usize {
        String::from_utf8_lossy(headers)
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse().ok())
            .unwrap_or(0)
    }

    #[tokio::test]
    async fn fetch_retries_same_endpoint_until_success() {
        let _guard = fetch_test_lock().lock().await;
        let (endpoint, requests, handle) = spawn_overpass_server(vec![
            ("429 Too Many Requests", r#"{"elements":[]}"#),
            ("200 OK", overpass_fixture_json()),
        ]);

        let bbox = BoundingBox::try_new(39.94, -75.18, 39.97, -75.15).expect("bbox should build");
        let config = NetworkConfig::default()
            .overpass_url(endpoint)
            .overpass_max_retries(1)
            .overpass_retry_backoff(Duration::from_millis(1));

        let network = RoadNetwork::fetch(&bbox, &config, None)
            .await
            .expect("fetch should succeed after retry");

        assert_eq!(network.node_count(), 2);
        assert_eq!(requests.load(Ordering::Relaxed), 2);
        handle.join().expect("server should join");
    }

    #[tokio::test]
    async fn fetch_falls_back_to_second_endpoint() {
        let _guard = fetch_test_lock().lock().await;
        let (primary, primary_requests, primary_handle) =
            spawn_overpass_server(vec![("503 Service Unavailable", r#"{"elements":[]}"#)]);
        let (secondary, secondary_requests, secondary_handle) =
            spawn_overpass_server(vec![("200 OK", overpass_fixture_json())]);

        let bbox = BoundingBox::try_new(39.94, -75.18, 39.97, -75.15).expect("bbox should build");
        let config = NetworkConfig::default()
            .overpass_endpoints(vec![primary, secondary])
            .overpass_max_retries(0)
            .overpass_retry_backoff(Duration::from_millis(1));

        let network = RoadNetwork::fetch(&bbox, &config, None)
            .await
            .expect("fetch should fall back to second endpoint");

        assert_eq!(network.node_count(), 2);
        assert_eq!(primary_requests.load(Ordering::Relaxed), 1);
        assert_eq!(secondary_requests.load(Ordering::Relaxed), 1);
        primary_handle.join().expect("primary server should join");
        secondary_handle
            .join()
            .expect("secondary server should join");
    }
}