1use std::collections::HashMap;
4use std::future::Future;
5use std::path::Path;
6use std::sync::Arc;
7use std::time::Duration;
8
9use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard};
10use tokio::time::sleep;
11use tracing::{debug, info};
12
13use super::bbox::BoundingBox;
14use super::cache::{
15 cache, in_flight_loads, record_disk_hit, record_in_flight_wait, record_load_request,
16 record_memory_hit, record_network_fetch, CachedEdge, CachedNetwork, CachedNode, NetworkRef,
17 CACHE_VERSION,
18};
19use super::config::{ConnectivityPolicy, NetworkConfig};
20use super::coord::Coord;
21use super::error::RoutingError;
22use super::network::{EdgeData, RoadNetwork};
23use super::osm::OverpassResponse;
24use super::progress::RoutingProgress;
25
26impl RoadNetwork {
27 pub async fn load_or_fetch(
28 bbox: &BoundingBox,
29 config: &NetworkConfig,
30 progress: Option<&Sender<RoutingProgress>>,
31 ) -> Result<NetworkRef, RoutingError> {
32 let cache_key = bbox.cache_key();
33 record_load_request();
34
35 if let Some(tx) = progress {
36 let _ = tx.send(RoutingProgress::CheckingCache { percent: 0 }).await;
37 }
38
39 {
40 let cache_guard = cache().read().await;
41 if cache_guard.contains_key(&cache_key) {
42 record_memory_hit();
43 info!("Using in-memory cached road network for {}", cache_key);
44 if let Some(tx) = progress {
45 let _ = tx
46 .send(RoutingProgress::CheckingCache { percent: 10 })
47 .await;
48 }
49 return Ok(NetworkRef::new(cache_guard, cache_key));
50 }
51 }
52 if let Some(tx) = progress {
53 let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await;
54 }
55
56 Self::load_or_insert(cache_key, async {
57 tokio::fs::create_dir_all(&config.cache_dir).await?;
58 let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key()));
59
60 if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
61 info!("Loading road network from file cache: {:?}", cache_path);
62 if let Some(tx) = progress {
63 let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
64 }
65 match Self::load_from_file(&cache_path, config).await {
66 Ok(network) => {
67 record_disk_hit();
68 if let Some(tx) = progress {
69 let _ = tx
70 .send(RoutingProgress::BuildingGraph { percent: 50 })
71 .await;
72 }
73 return Ok(network);
74 }
75 Err(e) => info!("File cache invalid ({}), downloading fresh", e),
76 }
77 } else {
78 info!("Downloading road network from Overpass API");
79 }
80
81 record_network_fetch();
82 let network = Self::fetch_from_api(bbox, config, progress).await?;
83 network.save_to_file(&cache_path).await?;
84 info!("Saved road network to file cache: {:?}", cache_path);
85 Ok(network)
86 })
87 .await
88 }
89
90 pub async fn fetch(
91 bbox: &BoundingBox,
92 config: &NetworkConfig,
93 progress: Option<&Sender<RoutingProgress>>,
94 ) -> Result<Self, RoutingError> {
95 Self::fetch_from_api(bbox, config, progress).await
96 }
97
98 async fn fetch_from_api(
99 bbox: &BoundingBox,
100 config: &NetworkConfig,
101 progress: Option<&Sender<RoutingProgress>>,
102 ) -> Result<Self, RoutingError> {
103 let highway_regex = config.highway_regex();
104 let query = format!(
105 r#"[out:json][timeout:120];
106(
107 way["highway"~"{}"]
108 ({},{},{},{});
109);
110(._;>;);
111out body;"#,
112 highway_regex, bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
113 );
114
115 debug!("Overpass query:\n{}", query);
116 info!(
117 "Preparing Overpass query for bbox: {:.4},{:.4} to {:.4},{:.4}",
118 bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
119 );
120
121 if let Some(tx) = progress {
122 let _ = tx
123 .send(RoutingProgress::DownloadingNetwork {
124 percent: 10,
125 bytes: 0,
126 })
127 .await;
128 }
129
130 let client = reqwest::Client::builder()
131 .connect_timeout(config.connect_timeout)
132 .read_timeout(config.read_timeout)
133 .timeout(config.read_timeout)
134 .user_agent("SolverForge/0.5.0")
135 .build()
136 .map_err(|e| RoutingError::Network(e.to_string()))?;
137
138 if let Some(tx) = progress {
139 let _ = tx
140 .send(RoutingProgress::DownloadingNetwork {
141 percent: 15,
142 bytes: 0,
143 })
144 .await;
145 }
146
147 let bytes = fetch_overpass_bytes(&client, &query, config, progress).await?;
148
149 let bytes_len = bytes.len();
150 if let Some(tx) = progress {
151 let _ = tx
152 .send(RoutingProgress::DownloadingNetwork {
153 percent: 30,
154 bytes: bytes_len,
155 })
156 .await;
157 }
158
159 if let Some(tx) = progress {
160 let _ = tx
161 .send(RoutingProgress::ParsingOsm {
162 percent: 32,
163 nodes: 0,
164 edges: 0,
165 })
166 .await;
167 }
168
169 let osm_data: OverpassResponse =
170 serde_json::from_slice(&bytes).map_err(|e| RoutingError::Parse(e.to_string()))?;
171
172 info!("Downloaded {} OSM elements", osm_data.elements.len());
173
174 if let Some(tx) = progress {
175 let _ = tx
176 .send(RoutingProgress::ParsingOsm {
177 percent: 35,
178 nodes: osm_data.elements.len(),
179 edges: 0,
180 })
181 .await;
182 }
183
184 if let Some(tx) = progress {
185 let _ = tx
186 .send(RoutingProgress::BuildingGraph { percent: 40 })
187 .await;
188 }
189
190 let network = Self::build_from_osm(&osm_data, config)?;
191
192 if let Some(tx) = progress {
193 let _ = tx
194 .send(RoutingProgress::BuildingGraph { percent: 50 })
195 .await;
196 let _ = tx.send(RoutingProgress::Complete).await;
197 }
198
199 Ok(network)
200 }
201
202 pub(super) fn build_from_osm(
203 osm: &OverpassResponse,
204 config: &NetworkConfig,
205 ) -> Result<Self, RoutingError> {
206 let mut network = Self::new();
207
208 let mut nodes: HashMap<i64, (f64, f64)> = HashMap::new();
209 for elem in &osm.elements {
210 if elem.elem_type == "node" {
211 if let (Some(lat), Some(lon)) = (elem.lat, elem.lon) {
212 nodes.insert(elem.id, (lat, lon));
213 }
214 }
215 }
216
217 info!("Parsed {} nodes", nodes.len());
218
219 let mut way_count = 0;
220 for elem in &osm.elements {
221 if elem.elem_type == "way" {
222 if let Some(ref node_ids) = elem.nodes {
223 let highway = elem.tags.as_ref().and_then(|t| t.highway.as_deref());
224 let oneway = elem.tags.as_ref().and_then(|t| t.oneway.as_deref());
225 let maxspeed = elem.tags.as_ref().and_then(|t| t.maxspeed.as_deref());
226 let speed = config
227 .speed_profile
228 .speed_mps(maxspeed, highway.unwrap_or("residential"));
229 let is_oneway_forward = matches!(oneway, Some("yes") | Some("1"));
230 let is_oneway_reverse = matches!(oneway, Some("-1"));
231
232 for window in node_ids.windows(2) {
233 let n1_id = window[0];
234 let n2_id = window[1];
235
236 let Some(&(lat1, lng1)) = nodes.get(&n1_id) else {
237 continue;
238 };
239 let Some(&(lat2, lng2)) = nodes.get(&n2_id) else {
240 continue;
241 };
242
243 let idx1 = network.get_or_create_node(lat1, lng1);
244 let idx2 = network.get_or_create_node(lat2, lng2);
245
246 let coord1 = Coord::new(lat1, lng1);
247 let coord2 = Coord::new(lat2, lng2);
248 let distance = super::geo::haversine_distance(coord1, coord2);
249 let travel_time = distance / speed;
250
251 let edge_data = EdgeData {
252 travel_time_s: travel_time,
253 distance_m: distance,
254 };
255
256 if is_oneway_reverse {
257 network.add_edge(idx2, idx1, edge_data);
259 } else {
260 network.add_edge(idx1, idx2, edge_data.clone());
262 if !is_oneway_forward {
263 network.add_edge(idx2, idx1, edge_data);
265 }
266 }
267 }
268
269 way_count += 1;
270 }
271 }
272 }
273
274 info!(
275 "Built graph with {} nodes and {} edges from {} ways",
276 network.node_count(),
277 network.edge_count(),
278 way_count
279 );
280
281 let scc_count = network.strongly_connected_components();
282 match config.connectivity_policy {
283 ConnectivityPolicy::KeepAll => {
284 if scc_count > 1 {
285 info!(
286 "Road network has {} SCCs, preserving all components by configuration",
287 scc_count
288 );
289 }
290 }
291 ConnectivityPolicy::LargestStronglyConnectedComponent => {
292 if scc_count > 1 {
293 info!(
294 "Road network has {} SCCs, filtering to largest component",
295 scc_count
296 );
297 network.filter_to_largest_scc();
298 info!(
299 "After SCC filter: {} nodes, {} edges",
300 network.node_count(),
301 network.edge_count()
302 );
303 }
304 }
305 }
306
307 network.build_spatial_index();
308
309 Ok(network)
310 }
311
312 async fn load_from_file(path: &Path, config: &NetworkConfig) -> Result<Self, RoutingError> {
313 let data = tokio::fs::read_to_string(path).await?;
314
315 let cached: CachedNetwork = match serde_json::from_str(&data) {
316 Ok(c) => c,
317 Err(e) => {
318 info!("Cache file corrupted, will re-download: {}", e);
319 let _ = tokio::fs::remove_file(path).await;
320 return Err(RoutingError::Parse(e.to_string()));
321 }
322 };
323
324 if cached.version != CACHE_VERSION {
325 info!(
326 "Cache version mismatch (got {}, need {}), will re-download",
327 cached.version, CACHE_VERSION
328 );
329 let _ = tokio::fs::remove_file(path).await;
330 return Err(RoutingError::Parse("cache version mismatch".into()));
331 }
332
333 let mut network = Self::new();
334
335 for node in &cached.nodes {
336 network.add_node_at(node.lat, node.lng);
337 }
338
339 for edge in &cached.edges {
340 network.add_edge_by_index(edge.from, edge.to, edge.travel_time_s, edge.distance_m);
341 }
342
343 let scc_count = network.strongly_connected_components();
344 match config.connectivity_policy {
345 ConnectivityPolicy::KeepAll => {
346 if scc_count > 1 {
347 info!(
348 "Cached network has {} SCCs, preserving all components by configuration",
349 scc_count
350 );
351 }
352 }
353 ConnectivityPolicy::LargestStronglyConnectedComponent => {
354 if scc_count > 1 {
355 info!(
356 "Cached network has {} SCCs, filtering to largest component",
357 scc_count
358 );
359 network.filter_to_largest_scc();
360 info!(
361 "After SCC filter: {} nodes, {} edges",
362 network.node_count(),
363 network.edge_count()
364 );
365 }
366 }
367 }
368
369 network.build_spatial_index();
370
371 Ok(network)
372 }
373
374 async fn save_to_file(&self, path: &Path) -> Result<(), RoutingError> {
375 let nodes: Vec<CachedNode> = self
376 .nodes_iter()
377 .map(|(lat, lng)| CachedNode { lat, lng })
378 .collect();
379
380 let edges: Vec<CachedEdge> = self
381 .edges_iter()
382 .map(|(from, to, travel_time_s, distance_m)| CachedEdge {
383 from,
384 to,
385 travel_time_s,
386 distance_m,
387 })
388 .collect();
389
390 let cached = CachedNetwork {
391 version: CACHE_VERSION,
392 nodes,
393 edges,
394 };
395 let data =
396 serde_json::to_string(&cached).map_err(|e| RoutingError::Parse(e.to_string()))?;
397 tokio::fs::write(path, data).await?;
398
399 Ok(())
400 }
401
402 async fn load_or_insert<F>(cache_key: String, load: F) -> Result<NetworkRef, RoutingError>
403 where
404 F: Future<Output = Result<RoadNetwork, RoutingError>>,
405 {
406 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
407 record_memory_hit();
408 return Ok(cached);
409 }
410
411 let (slot, _slot_guard, waited) = acquire_in_flight_slot(&cache_key).await;
412 if waited {
413 record_in_flight_wait();
414 }
415
416 if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
417 record_memory_hit();
418 cleanup_in_flight_slot(&cache_key, &slot).await;
419 return Ok(cached);
420 }
421
422 let network = load.await?;
423
424 {
425 let mut cache_guard = cache().write().await;
426 cache_guard.entry(cache_key.clone()).or_insert(network);
427 }
428
429 cleanup_in_flight_slot(&cache_key, &slot).await;
430
431 Self::get_cached_network(cache_key).await.ok_or_else(|| {
432 RoutingError::Network("cached network disappeared after insertion".to_string())
433 })
434 }
435
436 async fn get_cached_network(cache_key: String) -> Option<NetworkRef> {
437 let cache_guard = cache().read().await;
438 if cache_guard.contains_key(&cache_key) {
439 Some(NetworkRef::new(cache_guard, cache_key))
440 } else {
441 None
442 }
443 }
444}
445
446async fn fetch_overpass_bytes(
447 client: &reqwest::Client,
448 query: &str,
449 config: &NetworkConfig,
450 progress: Option<&Sender<RoutingProgress>>,
451) -> Result<Vec<u8>, RoutingError> {
452 let endpoints = overpass_endpoints(config);
453 let mut failures = Vec::new();
454
455 for (endpoint_index, endpoint) in endpoints.iter().enumerate() {
456 for attempt in 0..=config.overpass_max_retries {
457 info!(
458 "Sending request to Overpass API endpoint {} attempt {}: {}",
459 endpoint_index + 1,
460 attempt + 1,
461 endpoint
462 );
463
464 let response = client
465 .post(endpoint)
466 .body(query.to_owned())
467 .header("Content-Type", "text/plain")
468 .send()
469 .await;
470
471 match response {
472 Ok(response) if response.status().is_success() => {
473 info!(
474 "Received successful Overpass response from {} with status {}",
475 endpoint,
476 response.status()
477 );
478
479 if let Some(tx) = progress {
480 let _ = tx
481 .send(RoutingProgress::DownloadingNetwork {
482 percent: 25,
483 bytes: 0,
484 })
485 .await;
486 }
487
488 return response
489 .bytes()
490 .await
491 .map(|bytes| bytes.to_vec())
492 .map_err(|error| {
493 RoutingError::Network(format!(
494 "Overpass response body read failed from {} on attempt {}: {}",
495 endpoint,
496 attempt + 1,
497 error
498 ))
499 });
500 }
501 Ok(response) => {
502 let status = response.status();
503 failures.push(format!(
504 "{} attempt {} returned HTTP {}",
505 endpoint,
506 attempt + 1,
507 status
508 ));
509
510 if is_retryable_status(status) && attempt < config.overpass_max_retries {
511 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
512 continue;
513 }
514
515 break;
516 }
517 Err(error) => {
518 failures.push(format!(
519 "{} attempt {} failed: {}",
520 endpoint,
521 attempt + 1,
522 error
523 ));
524
525 if is_retryable_error(&error) && attempt < config.overpass_max_retries {
526 sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
527 continue;
528 }
529
530 break;
531 }
532 }
533 }
534 }
535
536 Err(RoutingError::Network(format!(
537 "Overpass fetch failed after trying {} endpoint(s): {}",
538 endpoints.len(),
539 failures.join("; ")
540 )))
541}
542
543fn overpass_endpoints(config: &NetworkConfig) -> Vec<String> {
544 if config.overpass_endpoints.is_empty() {
545 vec![config.overpass_url.clone()]
546 } else {
547 config.overpass_endpoints.clone()
548 }
549}
550
/// Computes the delay before the next retry as `base * (attempt + 1)`.
///
/// `attempt` is zero-based, so the first retry waits `base`, the second
/// `2 * base`, and so on. Both the multiplier and the product saturate instead
/// of overflowing: the multiplier is capped at `u32::MAX` and the result at
/// `Duration::MAX`.
fn retry_backoff(base: Duration, attempt: usize) -> Duration {
    // `attempt + 1` could overflow, and a plain `as u32` cast would wrap large
    // values (possibly to zero, i.e. no backoff at all), so clamp explicitly.
    let multiplier = u32::try_from(attempt.saturating_add(1)).unwrap_or(u32::MAX);
    base.saturating_mul(multiplier)
}
554
555fn is_retryable_status(status: reqwest::StatusCode) -> bool {
556 status.is_server_error()
557 || status == reqwest::StatusCode::TOO_MANY_REQUESTS
558 || status == reqwest::StatusCode::REQUEST_TIMEOUT
559}
560
561fn is_retryable_error(error: &reqwest::Error) -> bool {
562 error.is_timeout() || error.is_connect() || error.is_request()
563}
564
565async fn acquire_in_flight_slot(cache_key: &str) -> (Arc<Mutex<()>>, OwnedMutexGuard<()>, bool) {
566 let slot = {
567 let mut in_flight = in_flight_loads().lock().await;
568 match in_flight.get(cache_key) {
569 Some(slot) => slot.clone(),
570 None => {
571 let slot = Arc::new(Mutex::new(()));
572 in_flight.insert(cache_key.to_string(), slot.clone());
573 slot
574 }
575 }
576 };
577
578 match slot.clone().try_lock_owned() {
579 Ok(guard) => (slot, guard, false),
580 Err(_) => {
581 let guard = slot.clone().lock_owned().await;
582 (slot, guard, true)
583 }
584 }
585}
586
587async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc<Mutex<()>>) {
588 let mut in_flight = in_flight_loads().lock().await;
589 let should_remove = in_flight
590 .get(cache_key)
591 .map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2)
592 .unwrap_or(false);
593
594 if should_remove {
595 in_flight.remove(cache_key);
596 }
597}
598
599#[cfg(test)]
600#[path = "fetch_tests.rs"]
601mod tests;