Skip to main content

zlayer_proxy/
routes.rs

1//! Service registry for route resolution
2//!
3//! This module provides a production-ready `ServiceRegistry` for mapping incoming
4//! requests to backend services based on host patterns (including wildcards) and
5//! path prefixes.  Routes are stored in longest-prefix-first order so that
6//! `resolve()` returns the most specific match in O(n).
7
8use std::net::SocketAddr;
9use tokio::sync::RwLock;
10use zlayer_spec::{EndpointSpec, ExposeType, Protocol};
11
12// ---------------------------------------------------------------------------
13// ResolvedService
14// ---------------------------------------------------------------------------
15
16/// Fully-resolved service information returned by the registry.
17#[derive(Clone, Debug)]
18pub struct ResolvedService {
19    /// Service name (e.g. "api", "frontend")
20    pub name: String,
21    /// Backend addresses for load balancing
22    pub backends: Vec<SocketAddr>,
23    /// Whether to use TLS for upstream connections
24    pub use_tls: bool,
25    /// SNI hostname for TLS connections
26    pub sni_hostname: String,
27    /// Exposure type (public / internal)
28    pub expose: ExposeType,
29    /// Protocol (http, https, tcp, udp, websocket)
30    pub protocol: Protocol,
31    /// Whether to strip the matched path prefix before forwarding
32    pub strip_prefix: bool,
33    /// The path prefix this service was registered with
34    pub path_prefix: String,
35    /// The port the container actually listens on
36    pub target_port: u16,
37}
38
39// ---------------------------------------------------------------------------
40// RouteEntry
41// ---------------------------------------------------------------------------
42
43/// A single route entry in the registry.
44#[derive(Debug, Clone)]
45pub struct RouteEntry {
46    /// Owning service name (e.g. "api")
47    pub service_name: String,
48    /// Endpoint name within that service (e.g. "http", "grpc")
49    pub endpoint_name: String,
50    /// Host pattern to match.  `None` means match any host.
51    /// Supports wildcard patterns like `*.example.com`.
52    pub host: Option<String>,
53    /// Path prefix to match.  `"/"` matches all paths.
54    pub path_prefix: String,
55    /// The fully-resolved service returned on match.
56    pub resolved: ResolvedService,
57}
58
59impl RouteEntry {
60    /// Create a `RouteEntry` from a `zlayer_spec::EndpointSpec`.
61    ///
62    /// Fields that cannot be derived from the spec alone (backends, TLS,
63    /// SNI) are given sensible defaults and can be overridden after construction.
64    ///
65    /// `resolved.name` uses the composite key form
66    /// [`endpoint_lb_key`]`(service_name, endpoint.name)` so that the
67    /// load balancer can maintain a distinct backend group per endpoint,
68    /// which is required for `target_role` filtering (different endpoints
69    /// on the same service may target different replica groups).
70    #[must_use]
71    pub fn from_endpoint(service_name: &str, endpoint: &EndpointSpec) -> Self {
72        let path_prefix = endpoint.path.clone().unwrap_or_else(|| "/".to_string());
73        let target_port = endpoint.target_port();
74
75        Self {
76            service_name: service_name.to_string(),
77            endpoint_name: endpoint.name.clone(),
78            host: endpoint.host.clone(),
79            path_prefix: path_prefix.clone(),
80            resolved: ResolvedService {
81                name: endpoint_lb_key(service_name, &endpoint.name),
82                backends: Vec::new(),
83                use_tls: endpoint.protocol == Protocol::Https,
84                sni_hostname: String::new(),
85                expose: endpoint.expose,
86                protocol: endpoint.protocol,
87                strip_prefix: false,
88                path_prefix,
89                target_port,
90            },
91        }
92    }
93
94    /// Check whether this route matches the given host and path.
95    #[must_use]
96    pub fn matches(&self, host: Option<&str>, path: &str) -> bool {
97        // If the route specifies a host pattern the request must supply a
98        // host that satisfies it.
99        if let Some(ref pattern) = self.host {
100            match host {
101                Some(h) => {
102                    if !host_matches(pattern, h) {
103                        return false;
104                    }
105                }
106                None => return false,
107            }
108        }
109
110        path_matches(&self.path_prefix, path)
111    }
112}
113
114// ---------------------------------------------------------------------------
115// ServiceRegistry
116// ---------------------------------------------------------------------------
117
118/// Production-ready service registry for the `ZLayer` reverse proxy.
119///
120/// Routes are stored as a `Vec<RouteEntry>` behind a `tokio::sync::RwLock`,
121/// kept in **longest-prefix-first** order so that `resolve()` always returns
122/// the most specific match.
123pub struct ServiceRegistry {
124    /// Routes sorted by descending path-prefix length.
125    routes: RwLock<Vec<RouteEntry>>,
126}
127
128impl Default for ServiceRegistry {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl ServiceRegistry {
135    /// Create an empty registry.
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            routes: RwLock::new(Vec::new()),
140        }
141    }
142
143    /// Register a route, maintaining longest-prefix-first order.
144    pub async fn register(&self, entry: RouteEntry) {
145        let mut routes = self.routes.write().await;
146
147        let insert_idx = routes
148            .iter()
149            .position(|r| r.path_prefix.len() < entry.path_prefix.len())
150            .unwrap_or(routes.len());
151
152        routes.insert(insert_idx, entry);
153    }
154
155    /// Remove **all** routes belonging to `service_name`.
156    pub async fn unregister_service(&self, service_name: &str) {
157        let mut routes = self.routes.write().await;
158        routes.retain(|r| r.service_name != service_name);
159    }
160
161    /// Resolve an incoming request to the best-matching `ResolvedService`.
162    ///
163    /// Returns `None` when no route matches.
164    pub async fn resolve(&self, host: Option<&str>, path: &str) -> Option<ResolvedService> {
165        let routes = self.routes.read().await;
166
167        // First matching route wins (longest prefix is first due to ordering).
168        for entry in routes.iter() {
169            if entry.matches(host, path) {
170                return Some(entry.resolved.clone());
171            }
172        }
173
174        None
175    }
176
177    /// Replace the backend list for every route belonging to `service_name`.
178    pub async fn update_backends(&self, service_name: &str, backends: Vec<SocketAddr>) {
179        let mut routes = self.routes.write().await;
180        for entry in routes.iter_mut() {
181            if entry.service_name == service_name {
182                entry.resolved.backends.clone_from(&backends);
183            }
184        }
185    }
186
187    /// Replace the backend list only for routes matching both `service_name`
188    /// and `endpoint_name`.
189    ///
190    /// This is used by the agent's proxy manager to apply
191    /// `EndpointSpec.target_role` filtering: different endpoints of the same
192    /// service may have different filtered backend sets.
193    pub async fn update_backends_for_endpoint(
194        &self,
195        service_name: &str,
196        endpoint_name: &str,
197        backends: Vec<SocketAddr>,
198    ) {
199        let mut routes = self.routes.write().await;
200        for entry in routes.iter_mut() {
201            if entry.service_name == service_name && entry.endpoint_name == endpoint_name {
202                entry.resolved.backends.clone_from(&backends);
203            }
204        }
205    }
206
207    /// Append a single backend address to every route belonging to `service_name`.
208    pub async fn add_backend(&self, service_name: &str, addr: SocketAddr) {
209        let mut routes = self.routes.write().await;
210        for entry in routes.iter_mut() {
211            if entry.service_name == service_name && !entry.resolved.backends.contains(&addr) {
212                entry.resolved.backends.push(addr);
213            }
214        }
215    }
216
217    /// Remove a single backend address from every route belonging to `service_name`.
218    pub async fn remove_backend(&self, service_name: &str, addr: SocketAddr) {
219        let mut routes = self.routes.write().await;
220        for entry in routes.iter_mut() {
221            if entry.service_name == service_name {
222                entry.resolved.backends.retain(|a| *a != addr);
223            }
224        }
225    }
226
227    /// Return the unique set of service names across all registered routes.
228    pub async fn list_services(&self) -> Vec<String> {
229        let routes = self.routes.read().await;
230        let mut seen = Vec::new();
231        for entry in routes.iter() {
232            if !seen.contains(&entry.service_name) {
233                seen.push(entry.service_name.clone());
234            }
235        }
236        seen
237    }
238
239    /// Return the total number of registered routes.
240    pub async fn route_count(&self) -> usize {
241        self.routes.read().await.len()
242    }
243
244    /// Return a snapshot of all registered routes.
245    pub async fn list_routes(&self) -> Vec<RouteEntry> {
246        self.routes.read().await.clone()
247    }
248}
249
250// ---------------------------------------------------------------------------
251// Free functions
252// ---------------------------------------------------------------------------
253
/// Compose the load-balancer group key for one endpoint of a service.
///
/// The key has the shape `{service}#{endpoint}`, which is how the agent's
/// proxy manager names backend groups. Endpoints with different
/// `target_role` values therefore get independent backend pools.
/// `RouteEntry::from_endpoint` stores this key in `ResolvedService.name`, so
/// `LoadBalancer::select(&resolved.name)` picks the per-endpoint group at
/// request time.
#[must_use]
pub fn endpoint_lb_key(service_name: &str, endpoint_name: &str) -> String {
    let mut key = String::with_capacity(service_name.len() + 1 + endpoint_name.len());
    key.push_str(service_name);
    key.push('#');
    key.push_str(endpoint_name);
    key
}
266
/// Transform `path` by optionally stripping `prefix`.
///
/// When `strip` is `true` and `path` starts with `prefix` on a path-segment
/// boundary, the prefix is removed; trailing slashes on `prefix` are
/// ignored. If the result would be empty, `"/"` is returned instead.
///
/// `path` is returned unchanged when:
/// - `strip` is `false`;
/// - `prefix` is `"/"`;
/// - `path` does not start with `prefix` on a segment boundary.
///
/// For example, prefix `/api` leaves `/apiary` untouched instead of
/// producing the malformed relative path `ary`. This mirrors the boundary
/// rule used for route matching.
#[must_use]
pub fn transform_path(prefix: &str, path: &str, strip: bool) -> String {
    if !strip || prefix == "/" {
        return path.to_string();
    }

    let normalized_prefix = prefix.trim_end_matches('/');
    match path.strip_prefix(normalized_prefix) {
        // Exact match of the prefix: forward the root path.
        Some("") => "/".to_string(),
        // Prefix ends on a segment boundary: forward the remainder.
        Some(remainder) if remainder.starts_with('/') => remainder.to_string(),
        // Not a prefix, or only a partial-segment prefix (`/api` vs `/apiary`).
        _ => path.to_string(),
    }
}
288
/// Check whether the host pattern `pattern` matches the request host `host`.
///
/// Comparison is ASCII case-insensitive, because DNS names (and therefore
/// the HTTP `Host` value) are case-insensitive.
///
/// A pattern of the form `*.example.com` matches any host that ends in
/// `.example.com` and has a non-empty label directly in front of it. That
/// includes `api.example.com` and also deeper names such as
/// `a.b.example.com`. It does not match the apex `example.com` or the
/// degenerate `.example.com`. Any other pattern must equal `host` exactly,
/// ignoring ASCII case.
///
/// `host` is compared verbatim; any `:port` suffix is not stripped here.
fn host_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix('*') {
        Some(suffix) if suffix.starts_with('.') => {
            // `suffix` is e.g. ".example.com". Compare bytes so non-ASCII
            // hosts can never cause a char-boundary slicing panic.
            let bytes = host.as_bytes();
            if bytes.len() <= suffix.len() {
                return false;
            }
            let (label, tail) = bytes.split_at(bytes.len() - suffix.len());
            // `label` is non-empty (checked above) and must not end in an
            // empty label such as in "a..example.com".
            tail.eq_ignore_ascii_case(suffix.as_bytes()) && label.last() != Some(&b'.')
        }
        _ => pattern.eq_ignore_ascii_case(host),
    }
}
301
/// Return `true` when `path` begins with `prefix` on a path-segment boundary.
///
/// Trailing slashes on both arguments are ignored. For example, `/api`
/// matches `/api`, `/api/` and `/api/foo`, but not `/apiary`. The prefix
/// `"/"` matches every path.
fn path_matches(prefix: &str, path: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    let base = prefix.trim_end_matches('/');
    let trimmed_path = path.trim_end_matches('/');

    match trimmed_path.strip_prefix(base) {
        // Either the whole path was consumed, or the next byte starts a new segment.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
316
317// ---------------------------------------------------------------------------
318// Tests
319// ---------------------------------------------------------------------------
320
#[cfg(test)]
mod tests {
    use super::*;

    // -- Helpers -----------------------------------------------------------

    /// Parse a socket-address literal used as test input.
    fn sock(literal: &str) -> SocketAddr {
        literal.parse().expect("valid socket address literal")
    }

    /// Minimal HTTP/internal `ResolvedService` on port 8080 with prefix "/".
    fn make_resolved(name: &str, backends: Vec<SocketAddr>) -> ResolvedService {
        ResolvedService {
            name: name.to_owned(),
            backends,
            use_tls: false,
            sni_hostname: String::new(),
            expose: ExposeType::Internal,
            protocol: Protocol::Http,
            strip_prefix: false,
            path_prefix: String::from("/"),
            target_port: 8080,
        }
    }

    /// A `RouteEntry` for endpoint "http" of `service`, mounted at `path`.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        let resolved = ResolvedService {
            path_prefix: path.to_owned(),
            ..make_resolved(service, backends)
        };
        RouteEntry {
            service_name: service.to_owned(),
            endpoint_name: String::from("http"),
            host: host.map(str::to_owned),
            path_prefix: path.to_owned(),
            resolved,
        }
    }

    // -- Ported from routing.rs --------------------------------------------

    #[test]
    fn test_route_path_matching() {
        let entry = make_entry("api", None, "/api/v1", vec![]);
        let cases = [
            ("/api/v1", true),
            ("/api/v1/", true),
            ("/api/v1/users", true),
            ("/api/v1/users/123", true),
            ("/api/v2", false),
            ("/api", false),
            ("/", false),
        ];
        for (path, expected) in cases {
            assert_eq!(entry.matches(None, path), expected, "path {path:?}");
        }
    }

    #[test]
    fn test_route_host_matching() {
        let entry = make_entry("api", Some("api.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/anything"));
        for rejected in [Some("other.example.com"), None] {
            assert!(!entry.matches(rejected, "/anything"), "host {rejected:?}");
        }
    }

    #[test]
    fn test_route_wildcard_host() {
        let entry = make_entry("api", Some("*.example.com"), "/", vec![]);

        for host in ["api.example.com", "www.example.com", "foo.example.com"] {
            assert!(entry.matches(Some(host), "/"), "{host} should match");
        }
        for host in ["example.com", "other.domain.com"] {
            assert!(!entry.matches(Some(host), "/"), "{host} should not match");
        }
    }

    #[test]
    fn test_route_strip_prefix() {
        let stripped = [
            ("/api/v1/users", "/users"),
            ("/api/v1/users/123", "/users/123"),
            ("/api/v1", "/"),
            ("/other", "/other"),
        ];
        for (input, expected) in stripped {
            assert_eq!(transform_path("/api/v1", input, true), expected, "input {input:?}");
        }

        // strip=false should be a no-op
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users", false),
            "/api/v1/users"
        );
    }

    #[tokio::test]
    async fn test_router_longest_prefix_match() {
        let reg = ServiceRegistry::new();
        for (name, prefix) in [("root", "/"), ("api", "/api"), ("api-v1", "/api/v1")] {
            reg.register(make_entry(name, None, prefix, vec![])).await;
        }

        let expectations = [
            ("/api/v1/users", "api-v1"),
            ("/api/v2/users", "api"),
            ("/other", "root"),
        ];
        for (path, expected) in expectations {
            let hit = reg.resolve(None, path).await.expect("a route should match");
            assert_eq!(hit.name, expected, "path {path:?}");
        }
    }

    #[tokio::test]
    async fn test_router_no_match() {
        let reg = ServiceRegistry::new();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![]))
            .await;

        assert!(reg.resolve(Some("other.example.com"), "/").await.is_none());
    }

    // -- New tests ---------------------------------------------------------

    #[tokio::test]
    async fn test_register_and_resolve_host() {
        let reg = ServiceRegistry::new();
        let backend = sock("127.0.0.1:8080");
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![backend]))
            .await;

        let hit = reg
            .resolve(Some("api.example.com"), "/anything")
            .await
            .expect("host route should match");
        assert_eq!(hit.name, "api");
        assert_eq!(hit.backends.len(), 1);
    }

    #[tokio::test]
    async fn test_register_and_resolve_path() {
        let reg = ServiceRegistry::new();
        let routes = [
            ("api-v1", "/api/v1", sock("127.0.0.1:8081")),
            ("api-v2", "/api/v2", sock("127.0.0.1:8082")),
        ];
        for (name, prefix, addr) in routes {
            reg.register(make_entry(name, Some("api.example.com"), prefix, vec![addr]))
                .await;
        }

        for (path, expected) in [("/api/v1/users", "api-v1"), ("/api/v2/users", "api-v2")] {
            let hit = reg
                .resolve(Some("api.example.com"), path)
                .await
                .expect("path route should match");
            assert_eq!(hit.name, expected, "path {path:?}");
        }
    }

    #[tokio::test]
    async fn test_resolve_not_found() {
        let reg = ServiceRegistry::new();
        assert!(reg.resolve(Some("unknown.example.com"), "/").await.is_none());
    }

    #[tokio::test]
    async fn test_update_backends() {
        let reg = ServiceRegistry::new();
        reg.register(make_entry(
            "api",
            Some("api.example.com"),
            "/",
            vec![sock("127.0.0.1:8080")],
        ))
        .await;

        let replacement = vec![sock("127.0.0.1:8081"), sock("127.0.0.1:8082")];
        reg.update_backends("api", replacement).await;

        let hit = reg
            .resolve(Some("api.example.com"), "/")
            .await
            .expect("route should still match");
        assert_eq!(hit.backends.len(), 2);
    }

    #[tokio::test]
    async fn test_unregister_service() {
        let reg = ServiceRegistry::new();
        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        assert_eq!(reg.route_count().await, 2);
        reg.unregister_service("api").await;
        assert_eq!(reg.route_count().await, 1);

        // With "api" gone, the catch-all "/" route now serves /api/foo.
        let hit = reg.resolve(None, "/api/foo").await.expect("catch-all should match");
        assert_eq!(hit.name, "web");
    }

    #[tokio::test]
    async fn test_list_services() {
        let reg = ServiceRegistry::new();
        for (name, prefix) in [("api", "/api"), ("api", "/api/v2"), ("web", "/")] {
            reg.register(make_entry(name, None, prefix, vec![])).await;
        }

        let mut services = reg.list_services().await;
        services.sort_unstable();
        assert_eq!(services, vec!["api", "web"]);
    }

    #[tokio::test]
    async fn test_route_count() {
        let reg = ServiceRegistry::new();
        assert_eq!(reg.route_count().await, 0);

        for name in ["a", "b", "c"] {
            reg.register(make_entry(name, None, &format!("/{name}"), vec![]))
                .await;
        }
        assert_eq!(reg.route_count().await, 3);
    }

    #[tokio::test]
    async fn test_add_remove_backend() {
        let reg = ServiceRegistry::new();
        let first = sock("127.0.0.1:8001");
        let second = sock("127.0.0.1:8002");
        reg.register(make_entry("api", None, "/", vec![first])).await;

        reg.add_backend("api", second).await;
        let backends = reg.resolve(None, "/").await.unwrap().backends;
        assert_eq!(backends.len(), 2);
        assert!(backends.contains(&first));
        assert!(backends.contains(&second));

        // Re-adding an existing backend must not create a duplicate.
        reg.add_backend("api", second).await;
        assert_eq!(reg.resolve(None, "/").await.unwrap().backends.len(), 2);

        // Removing the first backend leaves only the second.
        reg.remove_backend("api", first).await;
        let backends = reg.resolve(None, "/").await.unwrap().backends;
        assert_eq!(backends, vec![second]);
    }

    #[tokio::test]
    async fn test_from_endpoint() {
        let endpoint = EndpointSpec {
            name: "http".to_string(),
            protocol: Protocol::Http,
            port: 80,
            target_port: Some(8080),
            path: Some("/api".to_string()),
            host: None,
            expose: ExposeType::Public,
            stream: None,
            tunnel: None,
            target_role: None,
        };

        let entry = RouteEntry::from_endpoint("my-service", &endpoint);

        // Route-level fields.
        assert_eq!(entry.service_name, "my-service");
        assert_eq!(entry.endpoint_name, "http");
        assert!(entry.host.is_none());
        assert_eq!(entry.path_prefix, "/api");

        // Resolved service: name is the composite per-endpoint LB key.
        let resolved = &entry.resolved;
        assert_eq!(resolved.name, endpoint_lb_key("my-service", "http"));
        assert_eq!(resolved.protocol, Protocol::Http);
        assert_eq!(resolved.expose, ExposeType::Public);
        assert_eq!(resolved.target_port, 8080);
        assert!(!resolved.use_tls);
        assert!(resolved.backends.is_empty());
    }

    #[test]
    fn test_endpoint_lb_key_format() {
        for (service, endpoint, expected) in [
            ("api", "http", "api#http"),
            ("postgres", "read", "postgres#read"),
        ] {
            assert_eq!(endpoint_lb_key(service, endpoint), expected);
        }
    }

    #[tokio::test]
    async fn test_update_backends_for_endpoint_isolates_endpoints() {
        // Two endpoints on the same service must keep independent backend
        // pools when updated via update_backends_for_endpoint.
        let reg = ServiceRegistry::new();
        for endpoint in ["write", "read"] {
            let mut entry = make_entry("postgres", None, &format!("/{endpoint}"), vec![]);
            entry.endpoint_name = endpoint.to_owned();
            reg.register(entry).await;
        }

        let primary = sock("10.0.0.1:5432");
        let replica = sock("10.0.0.2:5432");
        reg.update_backends_for_endpoint("postgres", "write", vec![primary])
            .await;
        reg.update_backends_for_endpoint("postgres", "read", vec![replica])
            .await;

        for (path, expected) in [("/write", primary), ("/read", replica)] {
            let hit = reg.resolve(None, path).await.unwrap();
            assert_eq!(hit.backends, vec![expected], "path {path:?}");
        }
    }
}