// zlayer_proxy/routes.rs

//! Service registry for route resolution
//!
//! This module provides a production-ready `ServiceRegistry` for mapping incoming
//! requests to backend services based on host patterns (including wildcards) and
//! path prefixes.  Routes are stored in longest-prefix-first order, so `resolve()`
//! finds the most specific match with a single linear scan (O(n) in the number of routes).
7
8use std::net::SocketAddr;
9use tokio::sync::RwLock;
10use zlayer_spec::{EndpointSpec, ExposeType, Protocol};
11
12// ---------------------------------------------------------------------------
13// ResolvedService
14// ---------------------------------------------------------------------------
15
16/// Fully-resolved service information returned by the registry.
17#[derive(Clone, Debug)]
18pub struct ResolvedService {
19    /// Service name (e.g. "api", "frontend")
20    pub name: String,
21    /// Backend addresses for load balancing
22    pub backends: Vec<SocketAddr>,
23    /// Whether to use TLS for upstream connections
24    pub use_tls: bool,
25    /// SNI hostname for TLS connections
26    pub sni_hostname: String,
27    /// Exposure type (public / internal)
28    pub expose: ExposeType,
29    /// Protocol (http, https, tcp, udp, websocket)
30    pub protocol: Protocol,
31    /// Whether to strip the matched path prefix before forwarding
32    pub strip_prefix: bool,
33    /// The path prefix this service was registered with
34    pub path_prefix: String,
35    /// The port the container actually listens on
36    pub target_port: u16,
37}
38
39// ---------------------------------------------------------------------------
40// RouteEntry
41// ---------------------------------------------------------------------------
42
43/// A single route entry in the registry.
44#[derive(Debug, Clone)]
45pub struct RouteEntry {
46    /// Owning service name (e.g. "api")
47    pub service_name: String,
48    /// Endpoint name within that service (e.g. "http", "grpc")
49    pub endpoint_name: String,
50    /// Host pattern to match.  `None` means match any host.
51    /// Supports wildcard patterns like `*.example.com`.
52    pub host: Option<String>,
53    /// Path prefix to match.  `"/"` matches all paths.
54    pub path_prefix: String,
55    /// The fully-resolved service returned on match.
56    pub resolved: ResolvedService,
57}
58
59impl RouteEntry {
60    /// Create a `RouteEntry` from a `zlayer_spec::EndpointSpec`.
61    ///
62    /// Fields that cannot be derived from the spec alone (backends, TLS,
63    /// SNI) are given sensible defaults and can be overridden after construction.
64    #[must_use]
65    pub fn from_endpoint(service_name: &str, endpoint: &EndpointSpec) -> Self {
66        let path_prefix = endpoint.path.clone().unwrap_or_else(|| "/".to_string());
67        let target_port = endpoint.target_port();
68
69        Self {
70            service_name: service_name.to_string(),
71            endpoint_name: endpoint.name.clone(),
72            host: endpoint.host.clone(),
73            path_prefix: path_prefix.clone(),
74            resolved: ResolvedService {
75                name: service_name.to_string(),
76                backends: Vec::new(),
77                use_tls: endpoint.protocol == Protocol::Https,
78                sni_hostname: String::new(),
79                expose: endpoint.expose,
80                protocol: endpoint.protocol,
81                strip_prefix: false,
82                path_prefix,
83                target_port,
84            },
85        }
86    }
87
88    /// Check whether this route matches the given host and path.
89    #[must_use]
90    pub fn matches(&self, host: Option<&str>, path: &str) -> bool {
91        // If the route specifies a host pattern the request must supply a
92        // host that satisfies it.
93        if let Some(ref pattern) = self.host {
94            match host {
95                Some(h) => {
96                    if !host_matches(pattern, h) {
97                        return false;
98                    }
99                }
100                None => return false,
101            }
102        }
103
104        path_matches(&self.path_prefix, path)
105    }
106}
107
108// ---------------------------------------------------------------------------
109// ServiceRegistry
110// ---------------------------------------------------------------------------
111
112/// Production-ready service registry for the `ZLayer` reverse proxy.
113///
114/// Routes are stored as a `Vec<RouteEntry>` behind a `tokio::sync::RwLock`,
115/// kept in **longest-prefix-first** order so that `resolve()` always returns
116/// the most specific match.
117pub struct ServiceRegistry {
118    /// Routes sorted by descending path-prefix length.
119    routes: RwLock<Vec<RouteEntry>>,
120}
121
122impl Default for ServiceRegistry {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl ServiceRegistry {
129    /// Create an empty registry.
130    #[must_use]
131    pub fn new() -> Self {
132        Self {
133            routes: RwLock::new(Vec::new()),
134        }
135    }
136
137    /// Register a route, maintaining longest-prefix-first order.
138    pub async fn register(&self, entry: RouteEntry) {
139        let mut routes = self.routes.write().await;
140
141        let insert_idx = routes
142            .iter()
143            .position(|r| r.path_prefix.len() < entry.path_prefix.len())
144            .unwrap_or(routes.len());
145
146        routes.insert(insert_idx, entry);
147    }
148
149    /// Remove **all** routes belonging to `service_name`.
150    pub async fn unregister_service(&self, service_name: &str) {
151        let mut routes = self.routes.write().await;
152        routes.retain(|r| r.service_name != service_name);
153    }
154
155    /// Resolve an incoming request to the best-matching `ResolvedService`.
156    ///
157    /// Returns `None` when no route matches.
158    pub async fn resolve(&self, host: Option<&str>, path: &str) -> Option<ResolvedService> {
159        let routes = self.routes.read().await;
160
161        // First matching route wins (longest prefix is first due to ordering).
162        for entry in routes.iter() {
163            if entry.matches(host, path) {
164                return Some(entry.resolved.clone());
165            }
166        }
167
168        None
169    }
170
171    /// Replace the backend list for every route belonging to `service_name`.
172    pub async fn update_backends(&self, service_name: &str, backends: Vec<SocketAddr>) {
173        let mut routes = self.routes.write().await;
174        for entry in routes.iter_mut() {
175            if entry.service_name == service_name {
176                entry.resolved.backends.clone_from(&backends);
177            }
178        }
179    }
180
181    /// Append a single backend address to every route belonging to `service_name`.
182    pub async fn add_backend(&self, service_name: &str, addr: SocketAddr) {
183        let mut routes = self.routes.write().await;
184        for entry in routes.iter_mut() {
185            if entry.service_name == service_name && !entry.resolved.backends.contains(&addr) {
186                entry.resolved.backends.push(addr);
187            }
188        }
189    }
190
191    /// Remove a single backend address from every route belonging to `service_name`.
192    pub async fn remove_backend(&self, service_name: &str, addr: SocketAddr) {
193        let mut routes = self.routes.write().await;
194        for entry in routes.iter_mut() {
195            if entry.service_name == service_name {
196                entry.resolved.backends.retain(|a| *a != addr);
197            }
198        }
199    }
200
201    /// Return the unique set of service names across all registered routes.
202    pub async fn list_services(&self) -> Vec<String> {
203        let routes = self.routes.read().await;
204        let mut seen = Vec::new();
205        for entry in routes.iter() {
206            if !seen.contains(&entry.service_name) {
207                seen.push(entry.service_name.clone());
208            }
209        }
210        seen
211    }
212
213    /// Return the total number of registered routes.
214    pub async fn route_count(&self) -> usize {
215        self.routes.read().await.len()
216    }
217
218    /// Return a snapshot of all registered routes.
219    pub async fn list_routes(&self) -> Vec<RouteEntry> {
220        self.routes.read().await.clone()
221    }
222}
223
224// ---------------------------------------------------------------------------
225// Free functions
226// ---------------------------------------------------------------------------
227
/// Transform `path` by optionally stripping `prefix`.
///
/// When `strip` is `true` and `path` starts with `prefix` **on a path-segment
/// boundary**, the leading `prefix` is removed.  If the result would be
/// empty, `"/"` is returned instead.  Trailing slashes on `prefix` are
/// ignored.
///
/// `path` is returned unchanged when `strip` is `false`, when `prefix` is
/// `"/"`, or when `prefix` does not cover `path` (for example prefix `/api`
/// and path `/apiary`, which merely share leading characters).
#[must_use]
pub fn transform_path(prefix: &str, path: &str, strip: bool) -> String {
    if !strip || prefix == "/" {
        return path.to_string();
    }

    let normalized_prefix = prefix.trim_end_matches('/');
    match path.strip_prefix(normalized_prefix) {
        // The path is exactly the prefix: forward the root.
        Some("") => "/".to_string(),
        // Only strip when the prefix ends at a segment boundary, so the
        // forwarded path always keeps its leading slash.
        Some(remainder) if remainder.starts_with('/') => remainder.to_string(),
        _ => path.to_string(),
    }
}
249
/// Check whether `pattern` matches `host`.
///
/// Supports simple wildcard patterns: `*.example.com` matches exactly one
/// non-empty leading label, such as `api.example.com`.  It does **not**
/// match the bare domain (`example.com`), an empty label (`.example.com`),
/// or deeper names (`a.b.example.com`).
///
/// Patterns without the `*.` prefix must equal `host`.  All comparisons
/// ignore ASCII case, because host names are case-insensitive.
fn host_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix("*.") {
        // Wildcard: split off the first label and compare the rest with the
        // pattern's domain part.
        Some(domain) => match host.split_once('.') {
            Some((label, rest)) => !label.is_empty() && rest.eq_ignore_ascii_case(domain),
            None => false,
        },
        None => pattern.eq_ignore_ascii_case(host),
    }
}
262
/// Decide whether `path` falls under `prefix`, honouring path-segment
/// boundaries: `/api` covers `/api`, `/api/` and `/api/foo`, but not
/// `/apiary`.
///
/// The prefix `"/"` covers every path.  Trailing slashes on both `prefix`
/// and `path` are ignored for the comparison.
fn path_matches(prefix: &str, path: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    let wanted = prefix.trim_end_matches('/');
    let trimmed_path = path.trim_end_matches('/');

    match trimmed_path.strip_prefix(wanted) {
        // The path does not even start with the prefix.
        None => false,
        // The path is the prefix itself (modulo trailing slashes).
        Some("") => true,
        // Anything after the prefix must begin a new segment.
        Some(rest) => rest.starts_with('/'),
    }
}
277
278// ---------------------------------------------------------------------------
279// Tests
280// ---------------------------------------------------------------------------
281
#[cfg(test)]
mod tests {
    use super::*;

    // -- Helpers -----------------------------------------------------------

    /// Parse a socket-address literal used as a test fixture.
    fn addr(literal: &str) -> SocketAddr {
        literal.parse().unwrap()
    }

    /// Build a `RouteEntry` for `service` with the given matcher and
    /// backends; every other field gets a fixed, minimal value.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        RouteEntry {
            service_name: service.to_owned(),
            endpoint_name: String::from("http"),
            host: host.map(str::to_owned),
            path_prefix: path.to_owned(),
            resolved: ResolvedService {
                name: service.to_owned(),
                backends,
                use_tls: false,
                sni_hostname: String::new(),
                expose: ExposeType::Internal,
                protocol: Protocol::Http,
                strip_prefix: false,
                path_prefix: path.to_owned(),
                target_port: 8080,
            },
        }
    }

    // -- Ported from routing.rs --------------------------------------------

    #[test]
    fn test_route_path_matching() {
        let entry = make_entry("api", None, "/api/v1", vec![]);

        let hits = ["/api/v1", "/api/v1/", "/api/v1/users", "/api/v1/users/123"];
        for path in hits.iter() {
            assert!(entry.matches(None, path));
        }

        let misses = ["/api/v2", "/api", "/"];
        for path in misses.iter() {
            assert!(!entry.matches(None, path));
        }
    }

    #[test]
    fn test_route_host_matching() {
        let entry = make_entry("api", Some("api.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/anything"));
        assert!(!entry.matches(Some("other.example.com"), "/anything"));
        // A host-restricted route never matches a request without a host.
        assert!(!entry.matches(None, "/anything"));
    }

    #[test]
    fn test_route_wildcard_host() {
        let entry = make_entry("api", Some("*.example.com"), "/", vec![]);

        let hits = ["api.example.com", "www.example.com", "foo.example.com"];
        for host in hits.iter() {
            assert!(entry.matches(Some(*host), "/"));
        }

        let misses = ["example.com", "other.domain.com"];
        for host in misses.iter() {
            assert!(!entry.matches(Some(*host), "/"));
        }
    }

    #[test]
    fn test_route_strip_prefix() {
        let prefix = "/api/v1";

        assert_eq!(transform_path(prefix, "/api/v1/users", true), "/users");
        assert_eq!(
            transform_path(prefix, "/api/v1/users/123", true),
            "/users/123"
        );
        assert_eq!(transform_path(prefix, "/api/v1", true), "/");
        assert_eq!(transform_path(prefix, "/other", true), "/other");
        // With strip disabled the path passes through untouched.
        assert_eq!(
            transform_path(prefix, "/api/v1/users", false),
            "/api/v1/users"
        );
    }

    #[tokio::test]
    async fn test_router_longest_prefix_match() {
        let reg = ServiceRegistry::new();

        let routes = [("root", "/"), ("api", "/api"), ("api-v1", "/api/v1")];
        for (service, prefix) in routes.iter() {
            reg.register(make_entry(service, None, prefix, vec![])).await;
        }

        let expectations = [
            ("/api/v1/users", "api-v1"),
            ("/api/v2/users", "api"),
            ("/other", "root"),
        ];
        for (path, expected) in expectations.iter() {
            let hit = reg.resolve(None, path).await.unwrap();
            assert_eq!(hit.name, *expected);
        }
    }

    #[tokio::test]
    async fn test_router_no_match() {
        let reg = ServiceRegistry::new();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![]))
            .await;

        assert!(reg.resolve(Some("other.example.com"), "/").await.is_none());
    }

    // -- New tests ---------------------------------------------------------

    #[tokio::test]
    async fn test_register_and_resolve_host() {
        let reg = ServiceRegistry::new();
        let host = Some("api.example.com");

        reg.register(make_entry("api", host, "/", vec![addr("127.0.0.1:8080")]))
            .await;

        let hit = reg.resolve(host, "/anything").await.unwrap();
        assert_eq!(hit.name, "api");
        assert_eq!(hit.backends.len(), 1);
    }

    #[tokio::test]
    async fn test_register_and_resolve_path() {
        let reg = ServiceRegistry::new();
        let host = Some("api.example.com");

        reg.register(make_entry(
            "api-v1",
            host,
            "/api/v1",
            vec![addr("127.0.0.1:8081")],
        ))
        .await;
        reg.register(make_entry(
            "api-v2",
            host,
            "/api/v2",
            vec![addr("127.0.0.1:8082")],
        ))
        .await;

        let v1 = reg.resolve(host, "/api/v1/users").await.unwrap();
        assert_eq!(v1.name, "api-v1");

        let v2 = reg.resolve(host, "/api/v2/users").await.unwrap();
        assert_eq!(v2.name, "api-v2");
    }

    #[tokio::test]
    async fn test_resolve_not_found() {
        let reg = ServiceRegistry::new();

        assert!(reg.resolve(Some("unknown.example.com"), "/").await.is_none());
    }

    #[tokio::test]
    async fn test_update_backends() {
        let reg = ServiceRegistry::new();
        let host = Some("api.example.com");

        reg.register(make_entry("api", host, "/", vec![addr("127.0.0.1:8080")]))
            .await;

        let replacement = vec![addr("127.0.0.1:8081"), addr("127.0.0.1:8082")];
        reg.update_backends("api", replacement).await;

        let hit = reg.resolve(host, "/").await.unwrap();
        assert_eq!(hit.backends.len(), 2);
    }

    #[tokio::test]
    async fn test_unregister_service() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("web", None, "/", vec![])).await;
        assert_eq!(reg.route_count().await, 2);

        reg.unregister_service("api").await;
        assert_eq!(reg.route_count().await, 1);

        // With the "api" route gone, "/api/foo" falls through to the "/"
        // route and therefore resolves to "web".
        let fallback = reg.resolve(None, "/api/foo").await;
        assert_eq!(fallback.unwrap().name, "web");
    }

    #[tokio::test]
    async fn test_list_services() {
        let reg = ServiceRegistry::new();

        let routes = [("api", "/api"), ("api", "/api/v2"), ("web", "/")];
        for (service, prefix) in routes.iter() {
            reg.register(make_entry(service, None, prefix, vec![])).await;
        }

        let mut services = reg.list_services().await;
        services.sort();
        assert_eq!(services, vec!["api", "web"]);
    }

    #[tokio::test]
    async fn test_route_count() {
        let reg = ServiceRegistry::new();
        assert_eq!(reg.route_count().await, 0);

        let routes = [("a", "/a"), ("b", "/b"), ("c", "/c")];
        for (service, prefix) in routes.iter() {
            reg.register(make_entry(service, None, prefix, vec![])).await;
        }
        assert_eq!(reg.route_count().await, 3);
    }

    #[tokio::test]
    async fn test_add_remove_backend() {
        let reg = ServiceRegistry::new();
        let first = addr("127.0.0.1:8001");
        let second = addr("127.0.0.1:8002");

        reg.register(make_entry("api", None, "/", vec![first])).await;
        reg.add_backend("api", second).await;

        let hit = reg.resolve(None, "/").await.unwrap();
        assert_eq!(hit.backends.len(), 2);
        assert!(hit.backends.contains(&first));
        assert!(hit.backends.contains(&second));

        // Re-adding a known backend must not duplicate it.
        reg.add_backend("api", second).await;
        let hit = reg.resolve(None, "/").await.unwrap();
        assert_eq!(hit.backends.len(), 2);

        // Dropping the first backend leaves only the second.
        reg.remove_backend("api", first).await;
        let hit = reg.resolve(None, "/").await.unwrap();
        assert_eq!(hit.backends.len(), 1);
        assert_eq!(hit.backends[0], second);
    }

    #[tokio::test]
    async fn test_from_endpoint() {
        let endpoint = EndpointSpec {
            name: "http".to_string(),
            protocol: Protocol::Http,
            port: 80,
            target_port: Some(8080),
            path: Some("/api".to_string()),
            host: None,
            expose: ExposeType::Public,
            stream: None,
            tunnel: None,
        };

        let entry = RouteEntry::from_endpoint("my-service", &endpoint);

        // Matcher fields.
        assert_eq!(entry.service_name, "my-service");
        assert_eq!(entry.endpoint_name, "http");
        assert!(entry.host.is_none());
        assert_eq!(entry.path_prefix, "/api");

        // Resolved service fields.
        let resolved = &entry.resolved;
        assert_eq!(resolved.name, "my-service");
        assert_eq!(resolved.protocol, Protocol::Http);
        assert_eq!(resolved.expose, ExposeType::Public);
        assert_eq!(resolved.target_port, 8080);
        assert!(!resolved.use_tls);
        assert!(resolved.backends.is_empty());
    }
}
574}