1use std::collections::HashMap;
4use std::future::Future;
5use std::path::Path;
6use std::sync::Arc;
7use std::time::Duration;
8
9use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard};
10use tokio::time::sleep;
11use tracing::{debug, info};
12
13use super::bbox::BoundingBox;
14use super::cache::{
15 cache, in_flight_loads, record_hit, record_miss, CachedEdge, CachedNetwork, CachedNode,
16 NetworkRef, CACHE_VERSION,
17};
18use super::config::{ConnectivityPolicy, NetworkConfig};
19use super::coord::Coord;
20use super::error::RoutingError;
21use super::network::{EdgeData, RoadNetwork};
22use super::osm::OverpassResponse;
23use super::progress::RoutingProgress;
24
25impl RoadNetwork {
26 pub async fn load_or_fetch(
27 bbox: &BoundingBox,
28 config: &NetworkConfig,
29 progress: Option<&Sender<RoutingProgress>>,
30 ) -> Result<NetworkRef, RoutingError> {
31 let cache_key = bbox.cache_key();
32
33 if let Some(tx) = progress {
34 let _ = tx.send(RoutingProgress::CheckingCache { percent: 0 }).await;
35 }
36
37 {
38 let cache_guard = cache().read().await;
39 if cache_guard.contains_key(&cache_key) {
40 record_hit();
41 info!("Using in-memory cached road network for {}", cache_key);
42 if let Some(tx) = progress {
43 let _ = tx
44 .send(RoutingProgress::CheckingCache { percent: 10 })
45 .await;
46 }
47 return Ok(NetworkRef::new(cache_guard, cache_key));
48 }
49 }
50 record_miss();
51
52 if let Some(tx) = progress {
53 let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await;
54 }
55
56 Self::load_or_insert(cache_key, async {
57 tokio::fs::create_dir_all(&config.cache_dir).await?;
58 let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key()));
59
60 if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
61 info!("Loading road network from file cache: {:?}", cache_path);
62 if let Some(tx) = progress {
63 let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
64 }
65 match Self::load_from_file(&cache_path, config).await {
66 Ok(network) => {
67 if let Some(tx) = progress {
68 let _ = tx
69 .send(RoutingProgress::BuildingGraph { percent: 50 })
70 .await;
71 }
72 return Ok(network);
73 }
74 Err(e) => info!("File cache invalid ({}), downloading fresh", e),
75 }
76 } else {
77 info!("Downloading road network from Overpass API");
78 }
79
80 let network = Self::fetch_from_api(bbox, config, progress).await?;
81 network.save_to_file(&cache_path).await?;
82 info!("Saved road network to file cache: {:?}", cache_path);
83 Ok(network)
84 })
85 .await
86 }
87
88 pub async fn fetch(
89 bbox: &BoundingBox,
90 config: &NetworkConfig,
91 progress: Option<&Sender<RoutingProgress>>,
92 ) -> Result<Self, RoutingError> {
93 Self::fetch_from_api(bbox, config, progress).await
94 }
95
96 async fn fetch_from_api(
97 bbox: &BoundingBox,
98 config: &NetworkConfig,
99 progress: Option<&Sender<RoutingProgress>>,
100 ) -> Result<Self, RoutingError> {
101 let highway_regex = config.highway_regex();
102 let query = format!(
103 r#"[out:json][timeout:120];
104(
105 way["highway"~"{}"]
106 ({},{},{},{});
107);
108(._;>;);
109out body;"#,
110 highway_regex, bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
111 );
112
113 debug!("Overpass query:\n{}", query);
114 info!(
115 "Preparing Overpass query for bbox: {:.4},{:.4} to {:.4},{:.4}",
116 bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
117 );
118
119 if let Some(tx) = progress {
120 let _ = tx
121 .send(RoutingProgress::DownloadingNetwork {
122 percent: 10,
123 bytes: 0,
124 })
125 .await;
126 }
127
128 let client = reqwest::Client::builder()
129 .connect_timeout(config.connect_timeout)
130 .read_timeout(config.read_timeout)
131 .timeout(config.read_timeout)
132 .user_agent("SolverForge/0.5.0")
133 .build()
134 .map_err(|e| RoutingError::Network(e.to_string()))?;
135
136 if let Some(tx) = progress {
137 let _ = tx
138 .send(RoutingProgress::DownloadingNetwork {
139 percent: 15,
140 bytes: 0,
141 })
142 .await;
143 }
144
145 let bytes = fetch_overpass_bytes(&client, &query, config, progress).await?;
146
147 let bytes_len = bytes.len();
148 if let Some(tx) = progress {
149 let _ = tx
150 .send(RoutingProgress::DownloadingNetwork {
151 percent: 30,
152 bytes: bytes_len,
153 })
154 .await;
155 }
156
157 if let Some(tx) = progress {
158 let _ = tx
159 .send(RoutingProgress::ParsingOsm {
160 percent: 32,
161 nodes: 0,
162 edges: 0,
163 })
164 .await;
165 }
166
167 let osm_data: OverpassResponse =
168 serde_json::from_slice(&bytes).map_err(|e| RoutingError::Parse(e.to_string()))?;
169
170 info!("Downloaded {} OSM elements", osm_data.elements.len());
171
172 if let Some(tx) = progress {
173 let _ = tx
174 .send(RoutingProgress::ParsingOsm {
175 percent: 35,
176 nodes: osm_data.elements.len(),
177 edges: 0,
178 })
179 .await;
180 }
181
182 if let Some(tx) = progress {
183 let _ = tx
184 .send(RoutingProgress::BuildingGraph { percent: 40 })
185 .await;
186 }
187
188 let network = Self::build_from_osm(&osm_data, config)?;
189
190 if let Some(tx) = progress {
191 let _ = tx
192 .send(RoutingProgress::BuildingGraph { percent: 50 })
193 .await;
194 let _ = tx.send(RoutingProgress::Complete).await;
195 }
196
197 Ok(network)
198 }
199
200 pub(super) fn build_from_osm(
201 osm: &OverpassResponse,
202 config: &NetworkConfig,
203 ) -> Result<Self, RoutingError> {
204 let mut network = Self::new();
205
206 let mut nodes: HashMap<i64, (f64, f64)> = HashMap::new();
207 for elem in &osm.elements {
208 if elem.elem_type == "node" {
209 if let (Some(lat), Some(lon)) = (elem.lat, elem.lon) {
210 nodes.insert(elem.id, (lat, lon));
211 }
212 }
213 }
214
215 info!("Parsed {} nodes", nodes.len());
216
217 let mut way_count = 0;
218 for elem in &osm.elements {
219 if elem.elem_type == "way" {
220 if let Some(ref node_ids) = elem.nodes {
221 let highway = elem.tags.as_ref().and_then(|t| t.highway.as_deref());
222 let oneway = elem.tags.as_ref().and_then(|t| t.oneway.as_deref());
223 let maxspeed = elem.tags.as_ref().and_then(|t| t.maxspeed.as_deref());
224 let speed = config
225 .speed_profile
226 .speed_mps(maxspeed, highway.unwrap_or("residential"));
227 let is_oneway_forward = matches!(oneway, Some("yes") | Some("1"));
228 let is_oneway_reverse = matches!(oneway, Some("-1"));
229
230 for window in node_ids.windows(2) {
231 let n1_id = window[0];
232 let n2_id = window[1];
233
234 let Some(&(lat1, lng1)) = nodes.get(&n1_id) else {
235 continue;
236 };
237 let Some(&(lat2, lng2)) = nodes.get(&n2_id) else {
238 continue;
239 };
240
241 let idx1 = network.get_or_create_node(lat1, lng1);
242 let idx2 = network.get_or_create_node(lat2, lng2);
243
244 let coord1 = Coord::new(lat1, lng1);
245 let coord2 = Coord::new(lat2, lng2);
246 let distance = super::geo::haversine_distance(coord1, coord2);
247 let travel_time = distance / speed;
248
249 let edge_data = EdgeData {
250 travel_time_s: travel_time,
251 distance_m: distance,
252 };
253
254 if is_oneway_reverse {
255 network.add_edge(idx2, idx1, edge_data);
257 } else {
258 network.add_edge(idx1, idx2, edge_data.clone());
260 if !is_oneway_forward {
261 network.add_edge(idx2, idx1, edge_data);
263 }
264 }
265 }
266
267 way_count += 1;
268 }
269 }
270 }
271
272 info!(
273 "Built graph with {} nodes and {} edges from {} ways",
274 network.node_count(),
275 network.edge_count(),
276 way_count
277 );
278
279 let scc_count = network.strongly_connected_components();
280 match config.connectivity_policy {
281 ConnectivityPolicy::KeepAll => {
282 if scc_count > 1 {
283 info!(
284 "Road network has {} SCCs, preserving all components by configuration",
285 scc_count
286 );
287 }
288 }
289 ConnectivityPolicy::LargestStronglyConnectedComponent => {
290 if scc_count > 1 {
291 info!(
292 "Road network has {} SCCs, filtering to largest component",
293 scc_count
294 );
295 network.filter_to_largest_scc();
296 info!(
297 "After SCC filter: {} nodes, {} edges",
298 network.node_count(),
299 network.edge_count()
300 );
301 }
302 }
303 }
304
305 network.build_spatial_index();
306
307 Ok(network)
308 }
309
310 async fn load_from_file(path: &Path, config: &NetworkConfig) -> Result<Self, RoutingError> {
311 let data = tokio::fs::read_to_string(path).await?;
312
313 let cached: CachedNetwork = match serde_json::from_str(&data) {
314 Ok(c) => c,
315 Err(e) => {
316 info!("Cache file corrupted, will re-download: {}", e);
317 let _ = tokio::fs::remove_file(path).await;
318 return Err(RoutingError::Parse(e.to_string()));
319 }
320 };
321
322 if cached.version != CACHE_VERSION {
323 info!(
324 "Cache version mismatch (got {}, need {}), will re-download",
325 cached.version, CACHE_VERSION
326 );
327 let _ = tokio::fs::remove_file(path).await;
328 return Err(RoutingError::Parse("cache version mismatch".into()));
329 }
330
331 let mut network = Self::new();
332
333 for node in &cached.nodes {
334 network.add_node_at(node.lat, node.lng);
335 }
336
337 for edge in &cached.edges {
338 network.add_edge_by_index(edge.from, edge.to, edge.travel_time_s, edge.distance_m);
339 }
340
341 let scc_count = network.strongly_connected_components();
342 match config.connectivity_policy {
343 ConnectivityPolicy::KeepAll => {
344 if scc_count > 1 {
345 info!(
346 "Cached network has {} SCCs, preserving all components by configuration",
347 scc_count
348 );
349 }
350 }
351 ConnectivityPolicy::LargestStronglyConnectedComponent => {
352 if scc_count > 1 {
353 info!(
354 "Cached network has {} SCCs, filtering to largest component",
355 scc_count
356 );
357 network.filter_to_largest_scc();
358 info!(
359 "After SCC filter: {} nodes, {} edges",
360 network.node_count(),
361 network.edge_count()
362 );
363 }
364 }
365 }
366
367 network.build_spatial_index();
368
369 Ok(network)
370 }
371
372 async fn save_to_file(&self, path: &Path) -> Result<(), RoutingError> {
373 let nodes: Vec<CachedNode> = self
374 .nodes_iter()
375 .map(|(lat, lng)| CachedNode { lat, lng })
376 .collect();
377
378 let edges: Vec<CachedEdge> = self
379 .edges_iter()
380 .map(|(from, to, travel_time_s, distance_m)| CachedEdge {
381 from,
382 to,
383 travel_time_s,
384 distance_m,
385 })
386 .collect();
387
388 let cached = CachedNetwork {
389 version: CACHE_VERSION,
390 nodes,
391 edges,
392 };
393 let data =
394 serde_json::to_string(&cached).map_err(|e| RoutingError::Parse(e.to_string()))?;
395 tokio::fs::write(path, data).await?;
396
397 Ok(())
398 }
399
400 async fn load_or_insert<F>(cache_key: String, load: F) -> Result<NetworkRef, RoutingError>
401 where
402 F: Future<Output = Result<RoadNetwork, RoutingError>>,
403 {
404 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
405 return Ok(cached);
406 }
407
408 record_miss();
409
410 let (slot, _slot_guard) = acquire_in_flight_slot(&cache_key).await;
411
412 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
413 cleanup_in_flight_slot(&cache_key, &slot).await;
414 return Ok(cached);
415 }
416
417 let network = load.await?;
418
419 {
420 let mut cache_guard = cache().write().await;
421 cache_guard.entry(cache_key.clone()).or_insert(network);
422 }
423
424 cleanup_in_flight_slot(&cache_key, &slot).await;
425
426 Self::get_cached_network(cache_key).await.ok_or_else(|| {
427 RoutingError::Network("cached network disappeared after insertion".to_string())
428 })
429 }
430
431 async fn get_cached_network(cache_key: String) -> Option<NetworkRef> {
432 let cache_guard = cache().read().await;
433 if cache_guard.contains_key(&cache_key) {
434 record_hit();
435 Some(NetworkRef::new(cache_guard, cache_key))
436 } else {
437 None
438 }
439 }
440}
441
442async fn fetch_overpass_bytes(
443 client: &reqwest::Client,
444 query: &str,
445 config: &NetworkConfig,
446 progress: Option<&Sender<RoutingProgress>>,
447) -> Result<Vec<u8>, RoutingError> {
448 let endpoints = overpass_endpoints(config);
449 let mut failures = Vec::new();
450
451 for (endpoint_index, endpoint) in endpoints.iter().enumerate() {
452 for attempt in 0..=config.overpass_max_retries {
453 info!(
454 "Sending request to Overpass API endpoint {} attempt {}: {}",
455 endpoint_index + 1,
456 attempt + 1,
457 endpoint
458 );
459
460 let response = client
461 .post(endpoint)
462 .body(query.to_owned())
463 .header("Content-Type", "text/plain")
464 .send()
465 .await;
466
467 match response {
468 Ok(response) if response.status().is_success() => {
469 info!(
470 "Received successful Overpass response from {} with status {}",
471 endpoint,
472 response.status()
473 );
474
475 if let Some(tx) = progress {
476 let _ = tx
477 .send(RoutingProgress::DownloadingNetwork {
478 percent: 25,
479 bytes: 0,
480 })
481 .await;
482 }
483
484 return response
485 .bytes()
486 .await
487 .map(|bytes| bytes.to_vec())
488 .map_err(|error| {
489 RoutingError::Network(format!(
490 "Overpass response body read failed from {} on attempt {}: {}",
491 endpoint,
492 attempt + 1,
493 error
494 ))
495 });
496 }
497 Ok(response) => {
498 let status = response.status();
499 failures.push(format!(
500 "{} attempt {} returned HTTP {}",
501 endpoint,
502 attempt + 1,
503 status
504 ));
505
506 if is_retryable_status(status) && attempt < config.overpass_max_retries {
507 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
508 continue;
509 }
510
511 break;
512 }
513 Err(error) => {
514 failures.push(format!(
515 "{} attempt {} failed: {}",
516 endpoint,
517 attempt + 1,
518 error
519 ));
520
521 if is_retryable_error(&error) && attempt < config.overpass_max_retries {
522 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
523 continue;
524 }
525
526 break;
527 }
528 }
529 }
530 }
531
532 Err(RoutingError::Network(format!(
533 "Overpass fetch failed after trying {} endpoint(s): {}",
534 endpoints.len(),
535 failures.join("; ")
536 )))
537}
538
539fn overpass_endpoints(config: &NetworkConfig) -> Vec<String> {
540 if config.overpass_endpoints.is_empty() {
541 vec![config.overpass_url.clone()]
542 } else {
543 config.overpass_endpoints.clone()
544 }
545}
546
/// Linear backoff to wait after zero-based `attempt` fails: `base * (attempt + 1)`.
///
/// Both the multiplier and the product saturate (at `u32::MAX` and `Duration::MAX`).
/// A huge retry count therefore cannot overflow, and cannot wrap around to a short or
/// zero delay.
fn retry_backoff(base: Duration, attempt: usize) -> Duration {
    let multiplier = u32::try_from(attempt.saturating_add(1)).unwrap_or(u32::MAX);
    base.saturating_mul(multiplier)
}
550
551fn is_retryable_status(status: reqwest::StatusCode) -> bool {
552 status.is_server_error()
553 || status == reqwest::StatusCode::TOO_MANY_REQUESTS
554 || status == reqwest::StatusCode::REQUEST_TIMEOUT
555}
556
557fn is_retryable_error(error: &reqwest::Error) -> bool {
558 error.is_timeout() || error.is_connect() || error.is_request()
559}
560
561async fn acquire_in_flight_slot(cache_key: &str) -> (Arc<Mutex<()>>, OwnedMutexGuard<()>) {
562 let slot = {
563 let mut in_flight = in_flight_loads().lock().await;
564 in_flight
565 .entry(cache_key.to_string())
566 .or_insert_with(|| Arc::new(Mutex::new(())))
567 .clone()
568 };
569
570 let guard = slot.clone().lock_owned().await;
571 (slot, guard)
572}
573
574async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc<Mutex<()>>) {
575 let mut in_flight = in_flight_loads().lock().await;
576 let should_remove = in_flight
577 .get(cache_key)
578 .map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2)
579 .unwrap_or(false);
580
581 if should_remove {
582 in_flight.remove(cache_key);
583 }
584}
585
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::sync::OnceLock;
    use std::thread;
    use std::time::{Duration, Instant};

    use tokio::sync::Mutex;
    use tokio::time::sleep;

    use super::*;
    use crate::routing::BoundingBox;

    /// Serializes tests that share the global network cache, the in-flight map, or
    /// timing-sensitive local stub servers.
    static FETCH_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    fn fetch_test_lock() -> &'static Mutex<()> {
        FETCH_TEST_LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Two nodes joined by a single edge (60 s, 1 km).
    fn test_network() -> RoadNetwork {
        RoadNetwork::from_test_data(&[(0.0, 0.0), (0.0, 0.01)], &[(0, 1, 60.0, 1_000.0)])
    }

    /// Empties the in-memory network cache and the in-flight map.
    async fn reset_test_state() {
        RoadNetwork::clear_cache().await;
        in_flight_loads().lock().await.clear();
    }

    /// Runs `load_or_insert` for `key` with a loader that bumps `loads`, sleeps for
    /// `delay`, and yields `test_network()`. The returned reference is dropped right away
    /// so the cache read lock is not held by the test.
    async fn counted_slow_load(
        key: &str,
        delay: Duration,
        loads: Arc<AtomicUsize>,
    ) -> Result<(), RoutingError> {
        RoadNetwork::load_or_insert(key.to_string(), async move {
            loads.fetch_add(1, Ordering::Relaxed);
            sleep(delay).await;
            Ok(test_network())
        })
        .await
        .map(drop)
    }

    #[tokio::test]
    async fn load_or_insert_allows_different_keys_to_progress_concurrently() {
        let _serial = fetch_test_lock().lock().await;
        reset_test_state().await;

        let loads = Arc::new(AtomicUsize::new(0));
        let delay = Duration::from_millis(100);

        let started = Instant::now();
        let (left, right) = tokio::join!(
            counted_slow_load("region-a", delay, Arc::clone(&loads)),
            counted_slow_load("region-b", delay, Arc::clone(&loads))
        );
        left.expect("first load should succeed");
        right.expect("second load should succeed");

        assert!(
            started.elapsed() < Duration::from_millis(180),
            "different keys should not serialize slow loads"
        );
    }

    #[tokio::test]
    async fn load_or_insert_deduplicates_same_key_work() {
        let _serial = fetch_test_lock().lock().await;
        reset_test_state().await;

        let loads = Arc::new(AtomicUsize::new(0));
        let delay = Duration::from_millis(50);

        let (left, right) = tokio::join!(
            counted_slow_load("region-a", delay, Arc::clone(&loads)),
            counted_slow_load("region-a", delay, Arc::clone(&loads))
        );
        left.expect("first load should succeed");
        right.expect("second load should succeed");

        assert_eq!(loads.load(Ordering::Relaxed), 1);
    }

    /// Two nodes and one residential way connecting them.
    fn overpass_fixture_json() -> &'static str {
        r#"{
            "elements": [
                {"type": "node", "id": 1, "lat": 39.95, "lon": -75.16},
                {"type": "node", "id": 2, "lat": 39.96, "lon": -75.17},
                {"type": "way", "id": 10, "nodes": [1, 2], "tags": {"highway": "residential"}}
            ]
        }"#
    }

    /// Bounding box enclosing both fixture nodes.
    fn fixture_bbox() -> BoundingBox {
        BoundingBox::try_new(39.94, -75.18, 39.97, -75.15).expect("bbox should build")
    }

    /// Starts a blocking HTTP stub on an ephemeral localhost port.
    ///
    /// The stub answers one connection per `(status line, body)` pair, in order, then
    /// exits. Returns the endpoint URL, a counter of responses written, and the server
    /// thread handle.
    fn spawn_overpass_server(
        responses: Vec<(&'static str, &'static str)>,
    ) -> (String, Arc<AtomicUsize>, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let local_addr = listener.local_addr().expect("listener addr");
        let address = format!("http://{}/api/interpreter", local_addr);

        let served = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&served);

        let handle = thread::spawn(move || {
            for (status, body) in responses {
                let (mut stream, _) = listener.accept().expect("connection should arrive");
                // Drain (most of) the request; its contents are irrelevant to the stub.
                let mut request = [0_u8; 4096];
                let _ = stream.read(&mut request);

                let reply = format!(
                    "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: application/json\r\nConnection: close\r\n\r\n{}",
                    status,
                    body.len(),
                    body
                );
                stream
                    .write_all(reply.as_bytes())
                    .expect("response should write");
                counter.fetch_add(1, Ordering::Relaxed);
            }
        });

        (address, served, handle)
    }

    #[tokio::test]
    async fn fetch_retries_same_endpoint_until_success() {
        let _serial = fetch_test_lock().lock().await;
        let (endpoint, requests, server) = spawn_overpass_server(vec![
            ("429 Too Many Requests", r#"{"elements":[]}"#),
            ("200 OK", overpass_fixture_json()),
        ]);

        let config = NetworkConfig::default()
            .overpass_url(endpoint)
            .overpass_max_retries(1)
            .overpass_retry_backoff(Duration::from_millis(1));

        let network = RoadNetwork::fetch(&fixture_bbox(), &config, None)
            .await
            .expect("fetch should succeed after retry");

        assert_eq!(network.node_count(), 2);
        assert_eq!(requests.load(Ordering::Relaxed), 2);
        server.join().expect("server should join");
    }

    #[tokio::test]
    async fn fetch_falls_back_to_second_endpoint() {
        let _serial = fetch_test_lock().lock().await;
        let (primary, primary_requests, primary_server) =
            spawn_overpass_server(vec![("503 Service Unavailable", r#"{"elements":[]}"#)]);
        let (secondary, secondary_requests, secondary_server) =
            spawn_overpass_server(vec![("200 OK", overpass_fixture_json())]);

        let config = NetworkConfig::default()
            .overpass_endpoints(vec![primary, secondary])
            .overpass_max_retries(0)
            .overpass_retry_backoff(Duration::from_millis(1));

        let network = RoadNetwork::fetch(&fixture_bbox(), &config, None)
            .await
            .expect("fetch should fall back to second endpoint");

        assert_eq!(network.node_count(), 2);
        assert_eq!(primary_requests.load(Ordering::Relaxed), 1);
        assert_eq!(secondary_requests.load(Ordering::Relaxed), 1);
        primary_server.join().expect("primary server should join");
        secondary_server
            .join()
            .expect("secondary server should join");
    }
}