Skip to main content

solverforge_maps/routing/
fetch.rs

1//! Overpass API fetching and caching for road networks.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::path::Path;
6use std::sync::Arc;
7use std::time::Duration;
8
9use tokio::sync::{mpsc::Sender, Mutex, OwnedMutexGuard};
10use tokio::time::sleep;
11use tracing::{debug, info};
12
13use super::bbox::BoundingBox;
14use super::cache::{
15    cache, in_flight_loads, record_disk_hit, record_in_flight_wait, record_load_request,
16    record_memory_hit, record_network_fetch, CachedEdge, CachedNetwork, CachedNode, NetworkRef,
17    CACHE_VERSION,
18};
19use super::config::{ConnectivityPolicy, NetworkConfig};
20use super::coord::Coord;
21use super::error::RoutingError;
22use super::network::{EdgeData, RoadNetwork};
23use super::osm::OverpassResponse;
24use super::progress::RoutingProgress;
25
26impl RoadNetwork {
27    pub async fn load_or_fetch(
28        bbox: &BoundingBox,
29        config: &NetworkConfig,
30        progress: Option<&Sender<RoutingProgress>>,
31    ) -> Result<NetworkRef, RoutingError> {
32        let cache_key = bbox.cache_key();
33        record_load_request();
34
35        if let Some(tx) = progress {
36            let _ = tx.send(RoutingProgress::CheckingCache { percent: 0 }).await;
37        }
38
39        {
40            let cache_guard = cache().read().await;
41            if cache_guard.contains_key(&cache_key) {
42                record_memory_hit();
43                info!("Using in-memory cached road network for {}", cache_key);
44                if let Some(tx) = progress {
45                    let _ = tx
46                        .send(RoutingProgress::CheckingCache { percent: 10 })
47                        .await;
48                }
49                return Ok(NetworkRef::new(cache_guard, cache_key));
50            }
51        }
52        if let Some(tx) = progress {
53            let _ = tx.send(RoutingProgress::CheckingCache { percent: 5 }).await;
54        }
55
56        Self::load_or_insert(cache_key, async {
57            tokio::fs::create_dir_all(&config.cache_dir).await?;
58            let cache_path = config.cache_dir.join(format!("{}.json", bbox.cache_key()));
59
60            if tokio::fs::try_exists(&cache_path).await.unwrap_or(false) {
61                info!("Loading road network from file cache: {:?}", cache_path);
62                if let Some(tx) = progress {
63                    let _ = tx.send(RoutingProgress::CheckingCache { percent: 8 }).await;
64                }
65                match Self::load_from_file(&cache_path, config).await {
66                    Ok(network) => {
67                        record_disk_hit();
68                        if let Some(tx) = progress {
69                            let _ = tx
70                                .send(RoutingProgress::BuildingGraph { percent: 50 })
71                                .await;
72                        }
73                        return Ok(network);
74                    }
75                    Err(e) => info!("File cache invalid ({}), downloading fresh", e),
76                }
77            } else {
78                info!("Downloading road network from Overpass API");
79            }
80
81            record_network_fetch();
82            let network = Self::fetch_from_api(bbox, config, progress).await?;
83            network.save_to_file(&cache_path).await?;
84            info!("Saved road network to file cache: {:?}", cache_path);
85            Ok(network)
86        })
87        .await
88    }
89
90    pub async fn fetch(
91        bbox: &BoundingBox,
92        config: &NetworkConfig,
93        progress: Option<&Sender<RoutingProgress>>,
94    ) -> Result<Self, RoutingError> {
95        Self::fetch_from_api(bbox, config, progress).await
96    }
97
98    async fn fetch_from_api(
99        bbox: &BoundingBox,
100        config: &NetworkConfig,
101        progress: Option<&Sender<RoutingProgress>>,
102    ) -> Result<Self, RoutingError> {
103        let highway_regex = config.highway_regex();
104        let query = format!(
105            r#"[out:json][timeout:120];
106(
107  way["highway"~"{}"]
108    ({},{},{},{});
109);
110(._;>;);
111out body;"#,
112            highway_regex, bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
113        );
114
115        debug!("Overpass query:\n{}", query);
116        info!(
117            "Preparing Overpass query for bbox: {:.4},{:.4} to {:.4},{:.4}",
118            bbox.min_lat, bbox.min_lng, bbox.max_lat, bbox.max_lng
119        );
120
121        if let Some(tx) = progress {
122            let _ = tx
123                .send(RoutingProgress::DownloadingNetwork {
124                    percent: 10,
125                    bytes: 0,
126                })
127                .await;
128        }
129
130        let client = reqwest::Client::builder()
131            .connect_timeout(config.connect_timeout)
132            .read_timeout(config.read_timeout)
133            .timeout(config.read_timeout)
134            .user_agent("SolverForge/0.5.0")
135            .build()
136            .map_err(|e| RoutingError::Network(e.to_string()))?;
137
138        if let Some(tx) = progress {
139            let _ = tx
140                .send(RoutingProgress::DownloadingNetwork {
141                    percent: 15,
142                    bytes: 0,
143                })
144                .await;
145        }
146
147        let bytes = fetch_overpass_bytes(&client, &query, config, progress).await?;
148
149        let bytes_len = bytes.len();
150        if let Some(tx) = progress {
151            let _ = tx
152                .send(RoutingProgress::DownloadingNetwork {
153                    percent: 30,
154                    bytes: bytes_len,
155                })
156                .await;
157        }
158
159        if let Some(tx) = progress {
160            let _ = tx
161                .send(RoutingProgress::ParsingOsm {
162                    percent: 32,
163                    nodes: 0,
164                    edges: 0,
165                })
166                .await;
167        }
168
169        let osm_data: OverpassResponse =
170            serde_json::from_slice(&bytes).map_err(|e| RoutingError::Parse(e.to_string()))?;
171
172        info!("Downloaded {} OSM elements", osm_data.elements.len());
173
174        if let Some(tx) = progress {
175            let _ = tx
176                .send(RoutingProgress::ParsingOsm {
177                    percent: 35,
178                    nodes: osm_data.elements.len(),
179                    edges: 0,
180                })
181                .await;
182        }
183
184        if let Some(tx) = progress {
185            let _ = tx
186                .send(RoutingProgress::BuildingGraph { percent: 40 })
187                .await;
188        }
189
190        let network = Self::build_from_osm(&osm_data, config)?;
191
192        if let Some(tx) = progress {
193            let _ = tx
194                .send(RoutingProgress::BuildingGraph { percent: 50 })
195                .await;
196            let _ = tx.send(RoutingProgress::Complete).await;
197        }
198
199        Ok(network)
200    }
201
202    pub(super) fn build_from_osm(
203        osm: &OverpassResponse,
204        config: &NetworkConfig,
205    ) -> Result<Self, RoutingError> {
206        let mut network = Self::new();
207
208        let mut nodes: HashMap<i64, (f64, f64)> = HashMap::new();
209        for elem in &osm.elements {
210            if elem.elem_type == "node" {
211                if let (Some(lat), Some(lon)) = (elem.lat, elem.lon) {
212                    nodes.insert(elem.id, (lat, lon));
213                }
214            }
215        }
216
217        info!("Parsed {} nodes", nodes.len());
218
219        let mut way_count = 0;
220        for elem in &osm.elements {
221            if elem.elem_type == "way" {
222                if let Some(ref node_ids) = elem.nodes {
223                    let highway = elem.tags.as_ref().and_then(|t| t.highway.as_deref());
224                    let oneway = elem.tags.as_ref().and_then(|t| t.oneway.as_deref());
225                    let maxspeed = elem.tags.as_ref().and_then(|t| t.maxspeed.as_deref());
226                    let speed = config
227                        .speed_profile
228                        .speed_mps(maxspeed, highway.unwrap_or("residential"));
229                    let is_oneway_forward = matches!(oneway, Some("yes") | Some("1"));
230                    let is_oneway_reverse = matches!(oneway, Some("-1"));
231
232                    for window in node_ids.windows(2) {
233                        let n1_id = window[0];
234                        let n2_id = window[1];
235
236                        let Some(&(lat1, lng1)) = nodes.get(&n1_id) else {
237                            continue;
238                        };
239                        let Some(&(lat2, lng2)) = nodes.get(&n2_id) else {
240                            continue;
241                        };
242
243                        let idx1 = network.get_or_create_node(lat1, lng1);
244                        let idx2 = network.get_or_create_node(lat2, lng2);
245
246                        let coord1 = Coord::new(lat1, lng1);
247                        let coord2 = Coord::new(lat2, lng2);
248                        let distance = super::geo::haversine_distance(coord1, coord2);
249                        let travel_time = distance / speed;
250
251                        let edge_data = EdgeData {
252                            travel_time_s: travel_time,
253                            distance_m: distance,
254                        };
255
256                        if is_oneway_reverse {
257                            // oneway=-1 means traffic flows opposite to way direction
258                            network.add_edge(idx2, idx1, edge_data);
259                        } else {
260                            // Forward direction (always added unless reverse-only)
261                            network.add_edge(idx1, idx2, edge_data.clone());
262                            if !is_oneway_forward {
263                                // Bidirectional road
264                                network.add_edge(idx2, idx1, edge_data);
265                            }
266                        }
267                    }
268
269                    way_count += 1;
270                }
271            }
272        }
273
274        info!(
275            "Built graph with {} nodes and {} edges from {} ways",
276            network.node_count(),
277            network.edge_count(),
278            way_count
279        );
280
281        let scc_count = network.strongly_connected_components();
282        match config.connectivity_policy {
283            ConnectivityPolicy::KeepAll => {
284                if scc_count > 1 {
285                    info!(
286                        "Road network has {} SCCs, preserving all components by configuration",
287                        scc_count
288                    );
289                }
290            }
291            ConnectivityPolicy::LargestStronglyConnectedComponent => {
292                if scc_count > 1 {
293                    info!(
294                        "Road network has {} SCCs, filtering to largest component",
295                        scc_count
296                    );
297                    network.filter_to_largest_scc();
298                    info!(
299                        "After SCC filter: {} nodes, {} edges",
300                        network.node_count(),
301                        network.edge_count()
302                    );
303                }
304            }
305        }
306
307        network.build_spatial_index();
308
309        Ok(network)
310    }
311
312    async fn load_from_file(path: &Path, config: &NetworkConfig) -> Result<Self, RoutingError> {
313        let data = tokio::fs::read_to_string(path).await?;
314
315        let cached: CachedNetwork = match serde_json::from_str(&data) {
316            Ok(c) => c,
317            Err(e) => {
318                info!("Cache file corrupted, will re-download: {}", e);
319                let _ = tokio::fs::remove_file(path).await;
320                return Err(RoutingError::Parse(e.to_string()));
321            }
322        };
323
324        if cached.version != CACHE_VERSION {
325            info!(
326                "Cache version mismatch (got {}, need {}), will re-download",
327                cached.version, CACHE_VERSION
328            );
329            let _ = tokio::fs::remove_file(path).await;
330            return Err(RoutingError::Parse("cache version mismatch".into()));
331        }
332
333        let mut network = Self::new();
334
335        for node in &cached.nodes {
336            network.add_node_at(node.lat, node.lng);
337        }
338
339        for edge in &cached.edges {
340            network.add_edge_by_index(edge.from, edge.to, edge.travel_time_s, edge.distance_m);
341        }
342
343        let scc_count = network.strongly_connected_components();
344        match config.connectivity_policy {
345            ConnectivityPolicy::KeepAll => {
346                if scc_count > 1 {
347                    info!(
348                        "Cached network has {} SCCs, preserving all components by configuration",
349                        scc_count
350                    );
351                }
352            }
353            ConnectivityPolicy::LargestStronglyConnectedComponent => {
354                if scc_count > 1 {
355                    info!(
356                        "Cached network has {} SCCs, filtering to largest component",
357                        scc_count
358                    );
359                    network.filter_to_largest_scc();
360                    info!(
361                        "After SCC filter: {} nodes, {} edges",
362                        network.node_count(),
363                        network.edge_count()
364                    );
365                }
366            }
367        }
368
369        network.build_spatial_index();
370
371        Ok(network)
372    }
373
374    async fn save_to_file(&self, path: &Path) -> Result<(), RoutingError> {
375        let nodes: Vec<CachedNode> = self
376            .nodes_iter()
377            .map(|(lat, lng)| CachedNode { lat, lng })
378            .collect();
379
380        let edges: Vec<CachedEdge> = self
381            .edges_iter()
382            .map(|(from, to, travel_time_s, distance_m)| CachedEdge {
383                from,
384                to,
385                travel_time_s,
386                distance_m,
387            })
388            .collect();
389
390        let cached = CachedNetwork {
391            version: CACHE_VERSION,
392            nodes,
393            edges,
394        };
395        let data =
396            serde_json::to_string(&cached).map_err(|e| RoutingError::Parse(e.to_string()))?;
397        tokio::fs::write(path, data).await?;
398
399        Ok(())
400    }
401
402    async fn load_or_insert<F>(cache_key: String, load: F) -> Result<NetworkRef, RoutingError>
403    where
404        F: Future<Output = Result<RoadNetwork, RoutingError>>,
405    {
406        if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
407            record_memory_hit();
408            return Ok(cached);
409        }
410
411        let (slot, _slot_guard, waited) = acquire_in_flight_slot(&cache_key).await;
412        if waited {
413            record_in_flight_wait();
414        }
415
416        if let Some(cached) = Self::get_cached_network(cache_key.clone()).await {
417            record_memory_hit();
418            cleanup_in_flight_slot(&cache_key, &slot).await;
419            return Ok(cached);
420        }
421
422        let network = load.await?;
423
424        {
425            let mut cache_guard = cache().write().await;
426            cache_guard.entry(cache_key.clone()).or_insert(network);
427        }
428
429        cleanup_in_flight_slot(&cache_key, &slot).await;
430
431        Self::get_cached_network(cache_key).await.ok_or_else(|| {
432            RoutingError::Network("cached network disappeared after insertion".to_string())
433        })
434    }
435
436    async fn get_cached_network(cache_key: String) -> Option<NetworkRef> {
437        let cache_guard = cache().read().await;
438        if cache_guard.contains_key(&cache_key) {
439            Some(NetworkRef::new(cache_guard, cache_key))
440        } else {
441            None
442        }
443    }
444}
445
446async fn fetch_overpass_bytes(
447    client: &reqwest::Client,
448    query: &str,
449    config: &NetworkConfig,
450    progress: Option<&Sender<RoutingProgress>>,
451) -> Result<Vec<u8>, RoutingError> {
452    let endpoints = overpass_endpoints(config);
453    let mut failures = Vec::new();
454
455    for (endpoint_index, endpoint) in endpoints.iter().enumerate() {
456        for attempt in 0..=config.overpass_max_retries {
457            info!(
458                "Sending request to Overpass API endpoint {} attempt {}: {}",
459                endpoint_index + 1,
460                attempt + 1,
461                endpoint
462            );
463
464            let response = client
465                .post(endpoint)
466                .body(query.to_owned())
467                .header("Content-Type", "text/plain")
468                .send()
469                .await;
470
471            match response {
472                Ok(response) if response.status().is_success() => {
473                    info!(
474                        "Received successful Overpass response from {} with status {}",
475                        endpoint,
476                        response.status()
477                    );
478
479                    if let Some(tx) = progress {
480                        let _ = tx
481                            .send(RoutingProgress::DownloadingNetwork {
482                                percent: 25,
483                                bytes: 0,
484                            })
485                            .await;
486                    }
487
488                    return response
489                        .bytes()
490                        .await
491                        .map(|bytes| bytes.to_vec())
492                        .map_err(|error| {
493                            RoutingError::Network(format!(
494                                "Overpass response body read failed from {} on attempt {}: {}",
495                                endpoint,
496                                attempt + 1,
497                                error
498                            ))
499                        });
500                }
501                Ok(response) => {
502                    let status = response.status();
503                    failures.push(format!(
504                        "{} attempt {} returned HTTP {}",
505                        endpoint,
506                        attempt + 1,
507                        status
508                    ));
509
510                    if is_retryable_status(status) && attempt < config.overpass_max_retries {
511                        sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
512                        continue;
513                    }
514
515                    break;
516                }
517                Err(error) => {
518                    failures.push(format!(
519                        "{} attempt {} failed: {}",
520                        endpoint,
521                        attempt + 1,
522                        error
523                    ));
524
525                    if is_retryable_error(&error) && attempt < config.overpass_max_retries {
526                        sleep(retry_backoff(config.overpass_retry_backoff, attempt)).await;
527                        continue;
528                    }
529
530                    break;
531                }
532            }
533        }
534    }
535
536    Err(RoutingError::Network(format!(
537        "Overpass fetch failed after trying {} endpoint(s): {}",
538        endpoints.len(),
539        failures.join("; ")
540    )))
541}
542
543fn overpass_endpoints(config: &NetworkConfig) -> Vec<String> {
544    if config.overpass_endpoints.is_empty() {
545        vec![config.overpass_url.clone()]
546    } else {
547        config.overpass_endpoints.clone()
548    }
549}
550
/// Linear backoff before the retry that follows zero-based `attempt`: `base * (attempt + 1)`.
///
/// The multiplier is clamped to `u32::MAX` and the product saturates at `Duration::MAX`, so
/// no input overflows or panics.
fn retry_backoff(base: Duration, attempt: usize) -> Duration {
    // `(attempt + 1) as u32` could overflow (panic in debug) or silently wrap to a tiny
    // multiplier for huge attempt counts; clamp instead.
    let multiplier = u32::try_from(attempt.saturating_add(1)).unwrap_or(u32::MAX);
    base.saturating_mul(multiplier)
}
554
555fn is_retryable_status(status: reqwest::StatusCode) -> bool {
556    status.is_server_error()
557        || status == reqwest::StatusCode::TOO_MANY_REQUESTS
558        || status == reqwest::StatusCode::REQUEST_TIMEOUT
559}
560
561fn is_retryable_error(error: &reqwest::Error) -> bool {
562    error.is_timeout() || error.is_connect() || error.is_request()
563}
564
565async fn acquire_in_flight_slot(cache_key: &str) -> (Arc<Mutex<()>>, OwnedMutexGuard<()>, bool) {
566    let slot = {
567        let mut in_flight = in_flight_loads().lock().await;
568        match in_flight.get(cache_key) {
569            Some(slot) => slot.clone(),
570            None => {
571                let slot = Arc::new(Mutex::new(()));
572                in_flight.insert(cache_key.to_string(), slot.clone());
573                slot
574            }
575        }
576    };
577
578    match slot.clone().try_lock_owned() {
579        Ok(guard) => (slot, guard, false),
580        Err(_) => {
581            let guard = slot.clone().lock_owned().await;
582            (slot, guard, true)
583        }
584    }
585}
586
587async fn cleanup_in_flight_slot(cache_key: &str, slot: &Arc<Mutex<()>>) {
588    let mut in_flight = in_flight_loads().lock().await;
589    let should_remove = in_flight
590        .get(cache_key)
591        .map(|current| Arc::ptr_eq(current, slot) && Arc::strong_count(slot) == 2)
592        .unwrap_or(false);
593
594    if should_remove {
595        in_flight.remove(cache_key);
596    }
597}
598
// Unit tests for this module live in the sibling file `fetch_tests.rs`.
#[cfg(test)]
#[path = "fetch_tests.rs"]
mod tests;