1use std::net::SocketAddr;
9use tokio::sync::RwLock;
10use zlayer_spec::{EndpointSpec, ExposeType, Protocol};
11
/// Routing target handed back by [`ServiceRegistry::resolve`].
///
/// `resolve` returns a clone of the value stored in the matching
/// [`RouteEntry`], so callers receive a snapshot that later registry updates
/// do not affect.
#[derive(Clone, Debug)]
pub struct ResolvedService {
    /// Load-balancer key. [`RouteEntry::from_endpoint`] sets this to
    /// [`endpoint_lb_key`]`(service, endpoint)`, i.e. `"service#endpoint"`.
    pub name: String,
    /// Backend socket addresses. Empty when built by `from_endpoint`; the
    /// registry's `update_backends*`, `add_backend` and `remove_backend`
    /// methods maintain it afterwards.
    pub backends: Vec<SocketAddr>,
    /// TLS flag. `from_endpoint` sets it to `true` exactly when the endpoint
    /// protocol is `Protocol::Https`. NOTE(review): presumably tells the proxy
    /// to use TLS towards backends — confirm with the consumer.
    pub use_tls: bool,
    /// SNI host name. `from_endpoint` leaves it empty; how an empty value is
    /// interpreted is decided by the consumer (not visible here).
    pub sni_hostname: String,
    /// Exposure level copied from the endpoint spec.
    pub expose: ExposeType,
    /// Protocol copied from the endpoint spec.
    pub protocol: Protocol,
    /// Whether the route prefix should be removed before forwarding
    /// (presumably passed to [`transform_path`] — TODO confirm). `from_endpoint`
    /// sets it to `false`.
    pub strip_prefix: bool,
    /// Copy of the owning route's path prefix.
    pub path_prefix: String,
    /// Port on the backend to forward to, taken from `EndpointSpec::target_port()`.
    pub target_port: u16,
}
38
/// A single routing rule.
///
/// A rule is a path prefix, optionally restricted to a host pattern, mapped to
/// the [`ResolvedService`] returned when a request matches it (see
/// [`RouteEntry::matches`]).
#[derive(Debug, Clone)]
pub struct RouteEntry {
    /// Owning service. `ServiceRegistry::unregister_service` and the
    /// backend-update methods select routes by this name.
    pub service_name: String,
    /// Endpoint name within the service. `update_backends_for_endpoint` uses
    /// it to target a single endpoint.
    pub endpoint_name: String,
    /// Optional host constraint: either an exact host or a `*.suffix`
    /// wildcard. `None` means the route matches regardless of host, including
    /// requests that carry no host at all.
    pub host: Option<String>,
    /// Path prefix, matched on `/` segment boundaries; `"/"` matches every path.
    pub path_prefix: String,
    /// Target handed out by `ServiceRegistry::resolve` when this route matches.
    pub resolved: ResolvedService,
}
58
59impl RouteEntry {
60 #[must_use]
71 pub fn from_endpoint(service_name: &str, endpoint: &EndpointSpec) -> Self {
72 let path_prefix = endpoint.path.clone().unwrap_or_else(|| "/".to_string());
73 let target_port = endpoint.target_port();
74
75 Self {
76 service_name: service_name.to_string(),
77 endpoint_name: endpoint.name.clone(),
78 host: endpoint.host.clone(),
79 path_prefix: path_prefix.clone(),
80 resolved: ResolvedService {
81 name: endpoint_lb_key(service_name, &endpoint.name),
82 backends: Vec::new(),
83 use_tls: endpoint.protocol == Protocol::Https,
84 sni_hostname: String::new(),
85 expose: endpoint.expose,
86 protocol: endpoint.protocol,
87 strip_prefix: false,
88 path_prefix,
89 target_port,
90 },
91 }
92 }
93
94 #[must_use]
96 pub fn matches(&self, host: Option<&str>, path: &str) -> bool {
97 if let Some(ref pattern) = self.host {
100 match host {
101 Some(h) => {
102 if !host_matches(pattern, h) {
103 return false;
104 }
105 }
106 None => return false,
107 }
108 }
109
110 path_matches(&self.path_prefix, path)
111 }
112}
113
/// Route table consulted when dispatching requests to services.
///
/// The table is guarded by a `tokio::sync::RwLock`: lookups take the read
/// lock, while registration and backend updates take the write lock.
pub struct ServiceRegistry {
    /// Routes in the order established by [`ServiceRegistry::register`];
    /// `resolve` scans them front to back and returns the first match.
    routes: RwLock<Vec<RouteEntry>>,
}
127
128impl Default for ServiceRegistry {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl ServiceRegistry {
135 #[must_use]
137 pub fn new() -> Self {
138 Self {
139 routes: RwLock::new(Vec::new()),
140 }
141 }
142
143 pub async fn register(&self, entry: RouteEntry) {
145 let mut routes = self.routes.write().await;
146
147 let insert_idx = routes
148 .iter()
149 .position(|r| r.path_prefix.len() < entry.path_prefix.len())
150 .unwrap_or(routes.len());
151
152 routes.insert(insert_idx, entry);
153 }
154
155 pub async fn unregister_service(&self, service_name: &str) {
157 let mut routes = self.routes.write().await;
158 routes.retain(|r| r.service_name != service_name);
159 }
160
161 pub async fn resolve(&self, host: Option<&str>, path: &str) -> Option<ResolvedService> {
165 let routes = self.routes.read().await;
166
167 for entry in routes.iter() {
169 if entry.matches(host, path) {
170 return Some(entry.resolved.clone());
171 }
172 }
173
174 None
175 }
176
177 pub async fn update_backends(&self, service_name: &str, backends: Vec<SocketAddr>) {
179 let mut routes = self.routes.write().await;
180 for entry in routes.iter_mut() {
181 if entry.service_name == service_name {
182 entry.resolved.backends.clone_from(&backends);
183 }
184 }
185 }
186
187 pub async fn update_backends_for_endpoint(
194 &self,
195 service_name: &str,
196 endpoint_name: &str,
197 backends: Vec<SocketAddr>,
198 ) {
199 let mut routes = self.routes.write().await;
200 for entry in routes.iter_mut() {
201 if entry.service_name == service_name && entry.endpoint_name == endpoint_name {
202 entry.resolved.backends.clone_from(&backends);
203 }
204 }
205 }
206
207 pub async fn add_backend(&self, service_name: &str, addr: SocketAddr) {
209 let mut routes = self.routes.write().await;
210 for entry in routes.iter_mut() {
211 if entry.service_name == service_name && !entry.resolved.backends.contains(&addr) {
212 entry.resolved.backends.push(addr);
213 }
214 }
215 }
216
217 pub async fn remove_backend(&self, service_name: &str, addr: SocketAddr) {
219 let mut routes = self.routes.write().await;
220 for entry in routes.iter_mut() {
221 if entry.service_name == service_name {
222 entry.resolved.backends.retain(|a| *a != addr);
223 }
224 }
225 }
226
227 pub async fn list_services(&self) -> Vec<String> {
229 let routes = self.routes.read().await;
230 let mut seen = Vec::new();
231 for entry in routes.iter() {
232 if !seen.contains(&entry.service_name) {
233 seen.push(entry.service_name.clone());
234 }
235 }
236 seen
237 }
238
239 pub async fn route_count(&self) -> usize {
241 self.routes.read().await.len()
242 }
243
244 pub async fn list_routes(&self) -> Vec<RouteEntry> {
246 self.routes.read().await.clone()
247 }
248}
249
/// Builds the load-balancer key for one endpoint of a service.
///
/// The key is the service name and endpoint name joined by `#`, e.g.
/// `endpoint_lb_key("postgres", "read") == "postgres#read"`.
#[must_use]
pub fn endpoint_lb_key(service_name: &str, endpoint_name: &str) -> String {
    [service_name, endpoint_name].join("#")
}
266
/// Rewrites `path` before forwarding, optionally removing the route prefix.
///
/// When `strip` is `false` or `prefix` is the root `"/"`, `path` is returned
/// unchanged. Otherwise the prefix (trailing `/` ignored) is removed only when
/// it ends on a path-segment boundary, matching how routes are matched. An
/// empty remainder becomes `"/"`:
///
/// ```text
/// ("/api/v1", "/api/v1/users") -> "/users"
/// ("/api/v1", "/api/v1")       -> "/"
/// ("/api/v1", "/api/v1x")      -> "/api/v1x"   (not a segment boundary)
/// ("/api/v1", "/other")        -> "/other"
/// ```
#[must_use]
pub fn transform_path(prefix: &str, path: &str, strip: bool) -> String {
    if !strip || prefix == "/" {
        return path.to_string();
    }

    let normalized_prefix = prefix.trim_end_matches('/');
    match path.strip_prefix(normalized_prefix) {
        Some("") => "/".to_string(),
        // Only a remainder that starts a new segment is a real prefix match;
        // e.g. "/api/v1x" must not become "x".
        Some(remainder) if remainder.starts_with('/') => remainder.to_string(),
        _ => path.to_string(),
    }
}
288
/// Returns `true` when `host` satisfies the route host `pattern`.
///
/// Pattern forms:
/// * `"*.example.com"` — wildcard. It matches any host ending in
///   `.example.com` with a non-empty part before it (`api.example.com`,
///   `a.b.example.com`). It does not match the bare apex `example.com`, nor
///   `.example.com` itself.
/// * Anything else — exact match.
///
/// Comparisons are ASCII case-insensitive because host names are
/// case-insensitive (RFC 3986 §3.2.2). A request for `API.Example.com` must
/// therefore hit the same route as one for `api.example.com`.
fn host_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix('*') {
        Some(suffix) if suffix.starts_with('.') => {
            // Compare raw bytes so a non-ASCII host can never cause a slice
            // on a non-char boundary.
            let (host_bytes, suffix_bytes) = (host.as_bytes(), suffix.as_bytes());
            host_bytes.len() > suffix_bytes.len()
                && host_bytes[host_bytes.len() - suffix_bytes.len()..]
                    .eq_ignore_ascii_case(suffix_bytes)
        }
        _ => pattern.eq_ignore_ascii_case(host),
    }
}
301
/// Segment-aware prefix test used by [`RouteEntry::matches`].
///
/// The root prefix `"/"` matches every path. Otherwise trailing slashes are
/// ignored on both sides, and `path` must either equal the prefix or continue
/// it with a `/`. So `/api` matches `/api`, `/api/` and `/api/x`, but not
/// `/apix`.
fn path_matches(prefix: &str, path: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    let base = prefix.trim_end_matches('/');
    match path.trim_end_matches('/').strip_prefix(base) {
        // Same path as the prefix (modulo trailing slashes).
        Some("") => true,
        // Longer path: only a continuation that opens a new segment counts.
        Some(rest) => rest.starts_with('/'),
        None => false,
    }
}
316
#[cfg(test)]
/// Unit tests for route matching, path rewriting and registry operations.
mod tests {
    use super::*;

    /// Builds a plain-HTTP, internal `ResolvedService` fixture named `name`,
    /// with the given backends and a catch-all `/` prefix.
    fn make_resolved(name: &str, backends: Vec<SocketAddr>) -> ResolvedService {
        ResolvedService {
            name: name.to_string(),
            backends,
            use_tls: false,
            sni_hostname: String::new(),
            expose: ExposeType::Internal,
            protocol: Protocol::Http,
            strip_prefix: false,
            path_prefix: "/".to_string(),
            target_port: 8080,
        }
    }

    /// Builds a `RouteEntry` fixture for `service` with endpoint name `"http"`.
    ///
    /// Here `resolved.name` is the bare service name, not the
    /// `service#endpoint` key that `RouteEntry::from_endpoint` produces. The
    /// resolve-based tests below assert against the bare name.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        let mut resolved = make_resolved(service, backends);
        // Keep the resolved prefix in sync with the route's prefix.
        resolved.path_prefix = path.to_string();
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved,
        }
    }

    /// Prefixes match on `/` segment boundaries, tolerate a trailing slash,
    /// and never match a parent path or a sibling.
    #[test]
    fn test_route_path_matching() {
        let entry = make_entry("api", None, "/api/v1", vec![]);

        assert!(entry.matches(None, "/api/v1"));
        assert!(entry.matches(None, "/api/v1/"));
        assert!(entry.matches(None, "/api/v1/users"));
        assert!(entry.matches(None, "/api/v1/users/123"));
        assert!(!entry.matches(None, "/api/v2"));
        assert!(!entry.matches(None, "/api"));
        assert!(!entry.matches(None, "/"));
    }

    /// An exact host constraint rejects other hosts and host-less requests.
    #[test]
    fn test_route_host_matching() {
        let entry = make_entry("api", Some("api.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/anything"));
        assert!(!entry.matches(Some("other.example.com"), "/anything"));
        assert!(!entry.matches(None, "/anything"));
    }

    /// `*.example.com` matches subdomains but not the apex or other domains.
    #[test]
    fn test_route_wildcard_host() {
        let entry = make_entry("api", Some("*.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/"));
        assert!(entry.matches(Some("www.example.com"), "/"));
        assert!(entry.matches(Some("foo.example.com"), "/"));
        assert!(!entry.matches(Some("example.com"), "/"));
        assert!(!entry.matches(Some("other.domain.com"), "/"));
    }

    /// Prefix stripping: the remainder is kept, and an exact match becomes "/".
    /// Non-matching paths and `strip == false` pass through unchanged.
    #[test]
    fn test_route_strip_prefix() {
        assert_eq!(transform_path("/api/v1", "/api/v1/users", true), "/users");
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users/123", true),
            "/users/123"
        );
        assert_eq!(transform_path("/api/v1", "/api/v1", true), "/");
        assert_eq!(transform_path("/api/v1", "/other", true), "/other");
        assert_eq!(
            // Stripping disabled: the path is returned as-is.
            transform_path("/api/v1", "/api/v1/users", false),
            "/api/v1/users"
        );
    }

    /// Registration order must not matter: the most specific prefix wins.
    #[tokio::test]
    async fn test_router_longest_prefix_match() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("root", None, "/", vec![])).await;
        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api-v1", None, "/api/v1", vec![]))
            .await;

        let m = reg.resolve(None, "/api/v1/users").await.unwrap();
        assert_eq!(m.name, "api-v1");

        let m = reg.resolve(None, "/api/v2/users").await.unwrap();
        assert_eq!(m.name, "api");

        let m = reg.resolve(None, "/other").await.unwrap();
        assert_eq!(m.name, "root");
    }

    /// A host-scoped route yields no match for a different host.
    #[tokio::test]
    async fn test_router_no_match() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", Some("api.example.com"), "/", vec![]))
            .await;

        let result = reg.resolve(Some("other.example.com"), "/").await;
        assert!(result.is_none());
    }

    /// A host-routed service resolves with its backends attached.
    #[tokio::test]
    async fn test_register_and_resolve_host() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/anything")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api");
        assert_eq!(resolved.backends.len(), 1);
    }

    /// Two routes on the same host are distinguished by path prefix.
    #[tokio::test]
    async fn test_register_and_resolve_path() {
        let reg = ServiceRegistry::new();

        let addr1: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8082".parse().unwrap();
        reg.register(make_entry(
            "api-v1",
            Some("api.example.com"),
            "/api/v1",
            vec![addr1],
        ))
        .await;
        reg.register(make_entry(
            "api-v2",
            Some("api.example.com"),
            "/api/v2",
            vec![addr2],
        ))
        .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v1/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v1");

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v2/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v2");
    }

    /// An empty registry resolves nothing.
    #[tokio::test]
    async fn test_resolve_not_found() {
        let reg = ServiceRegistry::new();
        let result = reg.resolve(Some("unknown.example.com"), "/").await;
        assert!(result.is_none());
    }

    /// `update_backends` replaces, rather than appends to, the backend list.
    #[tokio::test]
    async fn test_update_backends() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let new_backends: Vec<SocketAddr> = vec![
            "127.0.0.1:8081".parse().unwrap(),
            "127.0.0.1:8082".parse().unwrap(),
        ];
        reg.update_backends("api", new_backends).await;

        let resolved = reg.resolve(Some("api.example.com"), "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
    }

    /// Unregistering a service drops its routes; other routes keep serving.
    #[tokio::test]
    async fn test_unregister_service() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        assert_eq!(reg.route_count().await, 2);
        reg.unregister_service("api").await;
        assert_eq!(reg.route_count().await, 1);

        let result = reg.resolve(None, "/api/foo").await;
        // With `api` gone, the catch-all `/` route now serves `/api/foo`.
        assert_eq!(result.unwrap().name, "web");
    }

    /// A service with several routes is listed once.
    #[tokio::test]
    async fn test_list_services() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api", None, "/api/v2", vec![]))
            .await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        let mut services = reg.list_services().await;
        services.sort();
        assert_eq!(services, vec!["api", "web"]);
    }

    /// `route_count` counts every registered route.
    #[tokio::test]
    async fn test_route_count() {
        let reg = ServiceRegistry::new();
        assert_eq!(reg.route_count().await, 0);

        reg.register(make_entry("a", None, "/a", vec![])).await;
        reg.register(make_entry("b", None, "/b", vec![])).await;
        reg.register(make_entry("c", None, "/c", vec![])).await;
        assert_eq!(reg.route_count().await, 3);
    }

    /// Incremental backend changes: add is idempotent, and remove drops only
    /// the given address.
    #[tokio::test]
    async fn test_add_remove_backend() {
        let reg = ServiceRegistry::new();

        let b1: SocketAddr = "127.0.0.1:8001".parse().unwrap();
        reg.register(make_entry("api", None, "/", vec![b1])).await;

        let b2: SocketAddr = "127.0.0.1:8002".parse().unwrap();
        reg.add_backend("api", b2).await;

        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
        assert!(resolved.backends.contains(&b1));
        assert!(resolved.backends.contains(&b2));

        reg.add_backend("api", b2).await;
        // Re-adding an existing backend must not create a duplicate.
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);

        reg.remove_backend("api", b1).await;
        // Only b2 remains.
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 1);
        assert_eq!(resolved.backends[0], b2);
    }

    /// `from_endpoint` copies spec fields and keys the target by `service#endpoint`.
    #[tokio::test]
    async fn test_from_endpoint() {
        let endpoint = EndpointSpec {
            name: "http".to_string(),
            protocol: Protocol::Http,
            port: 80,
            target_port: Some(8080),
            path: Some("/api".to_string()),
            host: None,
            expose: ExposeType::Public,
            stream: None,
            tunnel: None,
            target_role: None,
        };

        let entry = RouteEntry::from_endpoint("my-service", &endpoint);
        assert_eq!(entry.service_name, "my-service");
        assert_eq!(entry.endpoint_name, "http");
        assert!(entry.host.is_none());
        assert_eq!(entry.path_prefix, "/api");
        assert_eq!(entry.resolved.name, endpoint_lb_key("my-service", "http"));
        assert_eq!(entry.resolved.protocol, Protocol::Http);
        assert_eq!(entry.resolved.expose, ExposeType::Public);
        assert_eq!(entry.resolved.target_port, 8080);
        assert!(!entry.resolved.use_tls);
        assert!(entry.resolved.backends.is_empty());
    }

    /// The load-balancer key format is `service#endpoint`.
    #[test]
    fn test_endpoint_lb_key_format() {
        assert_eq!(endpoint_lb_key("api", "http"), "api#http");
        assert_eq!(endpoint_lb_key("postgres", "read"), "postgres#read");
    }

    /// Per-endpoint updates leave the service's other endpoints untouched.
    #[tokio::test]
    async fn test_update_backends_for_endpoint_isolates_endpoints() {
        let reg = ServiceRegistry::new();

        // Two endpoints of the same service, distinguished by endpoint name
        // and routed by distinct path prefixes.
        let mut http_entry = make_entry("postgres", None, "/write", vec![]);
        http_entry.endpoint_name = "write".to_string();
        let mut read_entry = make_entry("postgres", None, "/read", vec![]);
        read_entry.endpoint_name = "read".to_string();

        reg.register(http_entry).await;
        reg.register(read_entry).await;

        let primary: SocketAddr = "10.0.0.1:5432".parse().unwrap();
        let replica: SocketAddr = "10.0.0.2:5432".parse().unwrap();

        reg.update_backends_for_endpoint("postgres", "write", vec![primary])
            .await;
        reg.update_backends_for_endpoint("postgres", "read", vec![replica])
            .await;

        let write_resolved = reg.resolve(None, "/write").await.unwrap();
        assert_eq!(write_resolved.backends, vec![primary]);

        let read_resolved = reg.resolve(None, "/read").await.unwrap();
        assert_eq!(read_resolved.backends, vec![replica]);
    }
}