1use std::net::SocketAddr;
9use tokio::sync::RwLock;
10use zlayer_spec::{EndpointSpec, ExposeType, Protocol};
11
/// The routing result for a matched request: everything needed to forward it upstream.
///
/// `ServiceRegistry::resolve` hands out a clone of this value. Callers therefore receive a
/// snapshot of the backend list as it was at lookup time.
#[derive(Clone, Debug)]
pub struct ResolvedService {
    /// Load-balancer key for the endpoint. `RouteEntry::from_endpoint` builds it with
    /// `endpoint_lb_key`, giving `"{deployment or _}/{service}#{endpoint}"`.
    pub name: String,
    /// Upstream socket addresses. `RouteEntry::from_endpoint` leaves this empty.
    /// `ServiceRegistry::update_backends`, `update_backends_for_endpoint` and `add_backend`
    /// populate it, and `remove_backend` prunes it.
    pub backends: Vec<SocketAddr>,
    /// TLS flag. `RouteEntry::from_endpoint` sets it exactly when the endpoint protocol is
    /// `Protocol::Https`.
    pub use_tls: bool,
    /// Presumably the SNI name to present when `use_tls` is set.
    /// `RouteEntry::from_endpoint` leaves it empty.
    /// NOTE(review): where this is populated is not visible in this file; confirm.
    pub sni_hostname: String,
    /// Exposure level copied from the endpoint spec.
    pub expose: ExposeType,
    /// Protocol copied from the endpoint spec.
    pub protocol: Protocol,
    /// Whether `path_prefix` should be stripped before forwarding. `RouteEntry::from_endpoint`
    /// sets it to `false`. Presumably this is passed as the `strip` argument of
    /// `transform_path`; TODO confirm.
    pub strip_prefix: bool,
    /// The route's path prefix (`"/"` when the endpoint spec has no path).
    pub path_prefix: String,
    /// Backend port, taken from `EndpointSpec::target_port()`.
    pub target_port: u16,
}
38
/// A registered route: the match criteria (`host`, `path_prefix`) plus the service it resolves to.
#[derive(Debug, Clone)]
pub struct RouteEntry {
    /// Owning service name. `ServiceRegistry` selects routes by this field when unregistering
    /// a service and when updating its backends.
    pub service_name: String,
    /// Endpoint name within the service. `ServiceRegistry::update_backends_for_endpoint` uses
    /// it to update a single endpoint of a multi-endpoint service.
    pub endpoint_name: String,
    /// Optional host pattern: an exact host or a `*.suffix` wildcard. `None` skips the host
    /// check entirely, so the route matches any host, including requests with no host.
    pub host: Option<String>,
    /// Path prefix, matched on segment boundaries; `"/"` matches every path.
    pub path_prefix: String,
    /// Routing result that `ServiceRegistry::resolve` clones out when this entry matches.
    pub resolved: ResolvedService,
}
58
59impl RouteEntry {
60 #[must_use]
78 pub fn from_endpoint(
79 deployment: Option<&str>,
80 service_name: &str,
81 endpoint: &EndpointSpec,
82 ) -> Self {
83 let path_prefix = endpoint.path.clone().unwrap_or_else(|| "/".to_string());
84 let target_port = endpoint.target_port();
85
86 Self {
87 service_name: service_name.to_string(),
88 endpoint_name: endpoint.name.clone(),
89 host: endpoint.host.clone(),
90 path_prefix: path_prefix.clone(),
91 resolved: ResolvedService {
92 name: endpoint_lb_key(deployment, service_name, &endpoint.name),
93 backends: Vec::new(),
94 use_tls: endpoint.protocol == Protocol::Https,
95 sni_hostname: String::new(),
96 expose: endpoint.expose,
97 protocol: endpoint.protocol,
98 strip_prefix: false,
99 path_prefix,
100 target_port,
101 },
102 }
103 }
104
105 #[must_use]
107 pub fn matches(&self, host: Option<&str>, path: &str) -> bool {
108 if let Some(ref pattern) = self.host {
111 match host {
112 Some(h) => {
113 if !host_matches(pattern, h) {
114 return false;
115 }
116 }
117 None => return false,
118 }
119 }
120
121 path_matches(&self.path_prefix, path)
122 }
123}
124
/// In-memory routing table mapping `(host, path)` requests to resolved services.
///
/// Routes live in a `Vec` behind a `tokio::sync::RwLock`. `resolve` and the listing methods
/// take the read lock; registration and backend updates take the write lock. `register`
/// keeps longer path prefixes ahead of shorter ones, and `resolve` returns the first matching
/// route.
pub struct ServiceRegistry {
    /// Ordered route table; see `ServiceRegistry::register` for the ordering rule.
    routes: RwLock<Vec<RouteEntry>>,
}
138
139impl Default for ServiceRegistry {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl ServiceRegistry {
146 #[must_use]
148 pub fn new() -> Self {
149 Self {
150 routes: RwLock::new(Vec::new()),
151 }
152 }
153
154 pub async fn register(&self, entry: RouteEntry) {
156 let mut routes = self.routes.write().await;
157
158 let insert_idx = routes
159 .iter()
160 .position(|r| r.path_prefix.len() < entry.path_prefix.len())
161 .unwrap_or(routes.len());
162
163 routes.insert(insert_idx, entry);
164 }
165
166 pub async fn unregister_service(&self, service_name: &str) {
168 let mut routes = self.routes.write().await;
169 routes.retain(|r| r.service_name != service_name);
170 }
171
172 pub async fn resolve(&self, host: Option<&str>, path: &str) -> Option<ResolvedService> {
176 let routes = self.routes.read().await;
177
178 for entry in routes.iter() {
180 if entry.matches(host, path) {
181 return Some(entry.resolved.clone());
182 }
183 }
184
185 None
186 }
187
188 pub async fn update_backends(&self, service_name: &str, backends: Vec<SocketAddr>) {
190 let mut routes = self.routes.write().await;
191 for entry in routes.iter_mut() {
192 if entry.service_name == service_name {
193 entry.resolved.backends.clone_from(&backends);
194 }
195 }
196 }
197
198 pub async fn update_backends_for_endpoint(
205 &self,
206 service_name: &str,
207 endpoint_name: &str,
208 backends: Vec<SocketAddr>,
209 ) {
210 let mut routes = self.routes.write().await;
211 for entry in routes.iter_mut() {
212 if entry.service_name == service_name && entry.endpoint_name == endpoint_name {
213 entry.resolved.backends.clone_from(&backends);
214 }
215 }
216 }
217
218 pub async fn add_backend(&self, service_name: &str, addr: SocketAddr) {
220 let mut routes = self.routes.write().await;
221 for entry in routes.iter_mut() {
222 if entry.service_name == service_name && !entry.resolved.backends.contains(&addr) {
223 entry.resolved.backends.push(addr);
224 }
225 }
226 }
227
228 pub async fn remove_backend(&self, service_name: &str, addr: SocketAddr) {
230 let mut routes = self.routes.write().await;
231 for entry in routes.iter_mut() {
232 if entry.service_name == service_name {
233 entry.resolved.backends.retain(|a| *a != addr);
234 }
235 }
236 }
237
238 pub async fn list_services(&self) -> Vec<String> {
240 let routes = self.routes.read().await;
241 let mut seen = Vec::new();
242 for entry in routes.iter() {
243 if !seen.contains(&entry.service_name) {
244 seen.push(entry.service_name.clone());
245 }
246 }
247 seen
248 }
249
250 pub async fn route_count(&self) -> usize {
252 self.routes.read().await.len()
253 }
254
255 pub async fn list_routes(&self) -> Vec<RouteEntry> {
257 self.routes.read().await.clone()
258 }
259}
260
/// Builds the load-balancer key identifying one endpoint of one service.
///
/// The format is `"{scope}/{service_name}#{endpoint_name}"`, where `scope` is the deployment
/// name. Services without a deployment use `"_"`. Including the deployment keeps same-named
/// services in different deployments from sharing a key.
///
/// Note that a deployment literally named `"_"` produces the same key as an unscoped service.
#[must_use]
pub fn endpoint_lb_key(
    deployment: Option<&str>,
    service_name: &str,
    endpoint_name: &str,
) -> String {
    let scope: &str = match deployment {
        Some(name) => name,
        None => "_",
    };
    format!("{scope}/{service_name}#{endpoint_name}")
}
295
/// Rewrites `path` for forwarding to a backend, optionally removing the route `prefix`.
///
/// When `strip` is `false` or `prefix` is `"/"`, `path` is returned unchanged. Otherwise a
/// trailing `/` on `prefix` is ignored, and the prefix is removed only when it ends on a path
/// segment boundary. This is the same boundary rule route matching uses:
///
/// - `/api/v1` + `/api/v1/users` becomes `/users`
/// - `/api/v1` + `/api/v1` becomes `/` (never an empty path)
/// - `/api/v1` + `/api/v1?x=1` becomes `/?x=1` (a query directly after the prefix is re-rooted)
/// - `/api/v1` + `/api/v10/users` is unchanged (`v10` is a different segment)
/// - `/api/v1` + `/other` is unchanged
#[must_use]
pub fn transform_path(prefix: &str, path: &str, strip: bool) -> String {
    if !strip || prefix == "/" {
        return path.to_string();
    }

    let normalized_prefix = prefix.trim_end_matches('/');
    match path.strip_prefix(normalized_prefix) {
        // The path is exactly the prefix: forward the root rather than an empty path.
        Some("") => "/".to_string(),
        // The prefix ended on a segment boundary, so the remainder is already absolute.
        Some(rest) if rest.starts_with('/') => rest.to_string(),
        // A query string follows the prefix directly: keep it, but on an absolute path.
        Some(rest) if rest.starts_with('?') => format!("/{rest}"),
        // Either no prefix match, or it only matched part of a segment
        // (e.g. `/api/v1` vs `/api/v10`); stripping would yield a malformed path.
        _ => path.to_string(),
    }
}
317
/// Returns `true` if `host` satisfies the route host `pattern`.
///
/// Two pattern forms are supported:
/// - `*.suffix` is a wildcard for any host ending in `.suffix` with at least one character
///   before that dot. So `*.example.com` matches `api.example.com` and `a.b.example.com`,
///   but not the apex `example.com` nor the empty-label `.example.com`.
/// - Any other pattern is an exact host match.
///
/// Comparison is ASCII case-insensitive because DNS host names are case-insensitive
/// (RFC 4343), so `API.Example.com` matches `api.example.com`. The comparison is done on
/// bytes, so non-ASCII input cannot cause a char-boundary panic.
///
/// NOTE(review): no `:port` suffix is stripped here; confirm callers pass a bare host.
fn host_matches(pattern: &str, host: &str) -> bool {
    if pattern.starts_with("*.") {
        // Keep the leading dot (".example.com") so matches stay on a label boundary.
        let suffix = &pattern.as_bytes()[1..];
        let host = host.as_bytes();
        host.len() > suffix.len() && host[host.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
    } else {
        pattern.eq_ignore_ascii_case(host)
    }
}
330
/// Returns `true` when `path` falls under the route prefix `prefix` on a segment boundary.
///
/// `"/"` matches every path. For any other prefix, trailing slashes are ignored on both sides.
/// The prefix must then be followed by either the end of the path or a `/`. So `/api/v1`
/// matches `/api/v1`, `/api/v1/` and `/api/v1/users`, but not `/api/v10` or `/api`.
fn path_matches(prefix: &str, path: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    let base = prefix.trim_end_matches('/');
    let trimmed = path.trim_end_matches('/');

    match trimmed.strip_prefix(base) {
        None => false,
        // Whole path (minus trailing slashes) equals the prefix.
        Some("") => true,
        // Something follows the prefix; it must begin a new segment.
        Some(_) => path.as_bytes().get(base.len()) == Some(&b'/'),
    }
}
345
// Unit tests for route matching, path rewriting, load-balancer keys and the async
// `ServiceRegistry` API.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ResolvedService` with fixed defaults (plain HTTP, internal, port 8080,
    /// prefix "/"). Only `name` and `backends` vary.
    fn make_resolved(name: &str, backends: Vec<SocketAddr>) -> ResolvedService {
        ResolvedService {
            name: name.to_string(),
            backends,
            use_tls: false,
            sni_hostname: String::new(),
            expose: ExposeType::Internal,
            protocol: Protocol::Http,
            strip_prefix: false,
            path_prefix: "/".to_string(),
            target_port: 8080,
        }
    }

    /// Builds a route for endpoint "http" of `service`. The resolved service's `name` is the
    /// bare service name rather than an `endpoint_lb_key`, so assertions can compare against
    /// it directly.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        let mut resolved = make_resolved(service, backends);
        resolved.path_prefix = path.to_string();
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved,
        }
    }

    // Prefix matching respects segment boundaries and tolerates a trailing slash.
    #[test]
    fn test_route_path_matching() {
        let entry = make_entry("api", None, "/api/v1", vec![]);

        assert!(entry.matches(None, "/api/v1"));
        assert!(entry.matches(None, "/api/v1/"));
        assert!(entry.matches(None, "/api/v1/users"));
        assert!(entry.matches(None, "/api/v1/users/123"));
        assert!(!entry.matches(None, "/api/v2"));
        assert!(!entry.matches(None, "/api"));
        assert!(!entry.matches(None, "/"));
    }

    // A route with an exact host rejects other hosts and requests with no host.
    #[test]
    fn test_route_host_matching() {
        let entry = make_entry("api", Some("api.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/anything"));
        assert!(!entry.matches(Some("other.example.com"), "/anything"));
        assert!(!entry.matches(None, "/anything"));
    }

    // "*.example.com" matches subdomains but not the apex domain.
    #[test]
    fn test_route_wildcard_host() {
        let entry = make_entry("api", Some("*.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/"));
        assert!(entry.matches(Some("www.example.com"), "/"));
        assert!(entry.matches(Some("foo.example.com"), "/"));
        assert!(!entry.matches(Some("example.com"), "/"));
        assert!(!entry.matches(Some("other.domain.com"), "/"));
    }

    // `transform_path` strips the prefix only when asked, and never returns an empty path.
    #[test]
    fn test_route_strip_prefix() {
        assert_eq!(transform_path("/api/v1", "/api/v1/users", true), "/users");
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users/123", true),
            "/users/123"
        );
        assert_eq!(transform_path("/api/v1", "/api/v1", true), "/");
        assert_eq!(transform_path("/api/v1", "/other", true), "/other");
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users", false),
            "/api/v1/users"
        );
    }

    // Registration order is root, shorter, longer; resolution must still prefer the
    // longest matching prefix.
    #[tokio::test]
    async fn test_router_longest_prefix_match() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("root", None, "/", vec![])).await;
        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api-v1", None, "/api/v1", vec![]))
            .await;

        let m = reg.resolve(None, "/api/v1/users").await.unwrap();
        assert_eq!(m.name, "api-v1");

        let m = reg.resolve(None, "/api/v2/users").await.unwrap();
        assert_eq!(m.name, "api");

        let m = reg.resolve(None, "/other").await.unwrap();
        assert_eq!(m.name, "root");
    }

    // A host-restricted route does not serve other hosts.
    #[tokio::test]
    async fn test_router_no_match() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", Some("api.example.com"), "/", vec![]))
            .await;

        let result = reg.resolve(Some("other.example.com"), "/").await;
        assert!(result.is_none());
    }

    // Resolution by host returns the registered backends.
    #[tokio::test]
    async fn test_register_and_resolve_host() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/anything")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api");
        assert_eq!(resolved.backends.len(), 1);
    }

    // Two routes on the same host are distinguished by path prefix.
    #[tokio::test]
    async fn test_register_and_resolve_path() {
        let reg = ServiceRegistry::new();

        let addr1: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8082".parse().unwrap();
        reg.register(make_entry(
            "api-v1",
            Some("api.example.com"),
            "/api/v1",
            vec![addr1],
        ))
        .await;
        reg.register(make_entry(
            "api-v2",
            Some("api.example.com"),
            "/api/v2",
            vec![addr2],
        ))
        .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v1/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v1");

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v2/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v2");
    }

    // An empty registry resolves nothing.
    #[tokio::test]
    async fn test_resolve_not_found() {
        let reg = ServiceRegistry::new();
        let result = reg.resolve(Some("unknown.example.com"), "/").await;
        assert!(result.is_none());
    }

    // `update_backends` replaces the backend list wholesale.
    #[tokio::test]
    async fn test_update_backends() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let new_backends: Vec<SocketAddr> = vec![
            "127.0.0.1:8081".parse().unwrap(),
            "127.0.0.1:8082".parse().unwrap(),
        ];
        reg.update_backends("api", new_backends).await;

        let resolved = reg.resolve(Some("api.example.com"), "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
    }

    // After unregistering "api", its paths fall through to the root route.
    #[tokio::test]
    async fn test_unregister_service() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        assert_eq!(reg.route_count().await, 2);
        reg.unregister_service("api").await;
        assert_eq!(reg.route_count().await, 1);

        let result = reg.resolve(None, "/api/foo").await;
        assert_eq!(result.unwrap().name, "web");
    }

    // A service with several routes is listed once.
    #[tokio::test]
    async fn test_list_services() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api", None, "/api/v2", vec![]))
            .await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        let mut services = reg.list_services().await;
        services.sort();
        assert_eq!(services, vec!["api", "web"]);
    }

    #[tokio::test]
    async fn test_route_count() {
        let reg = ServiceRegistry::new();
        assert_eq!(reg.route_count().await, 0);

        reg.register(make_entry("a", None, "/a", vec![])).await;
        reg.register(make_entry("b", None, "/b", vec![])).await;
        reg.register(make_entry("c", None, "/c", vec![])).await;
        assert_eq!(reg.route_count().await, 3);
    }

    // `add_backend` is idempotent for an existing address; `remove_backend` removes only
    // the given address.
    #[tokio::test]
    async fn test_add_remove_backend() {
        let reg = ServiceRegistry::new();

        let b1: SocketAddr = "127.0.0.1:8001".parse().unwrap();
        reg.register(make_entry("api", None, "/", vec![b1])).await;

        let b2: SocketAddr = "127.0.0.1:8002".parse().unwrap();
        reg.add_backend("api", b2).await;

        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
        assert!(resolved.backends.contains(&b1));
        assert!(resolved.backends.contains(&b2));

        // Adding an address that is already present must not duplicate it.
        reg.add_backend("api", b2).await;
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);

        reg.remove_backend("api", b1).await;
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 1);
        assert_eq!(resolved.backends[0], b2);
    }

    // `from_endpoint` copies spec fields and scopes the LB key by deployment.
    #[tokio::test]
    async fn test_from_endpoint() {
        let endpoint = EndpointSpec {
            name: "http".to_string(),
            protocol: Protocol::Http,
            port: 80,
            target_port: Some(8080),
            path: Some("/api".to_string()),
            host: None,
            expose: ExposeType::Public,
            stream: None,
            tunnel: None,
            target_role: None,
        };

        let entry = RouteEntry::from_endpoint(Some("prod"), "my-service", &endpoint);
        assert_eq!(entry.service_name, "my-service");
        assert_eq!(entry.endpoint_name, "http");
        assert!(entry.host.is_none());
        assert_eq!(entry.path_prefix, "/api");
        assert_eq!(
            entry.resolved.name,
            endpoint_lb_key(Some("prod"), "my-service", "http")
        );
        assert_eq!(entry.resolved.protocol, Protocol::Http);
        assert_eq!(entry.resolved.expose, ExposeType::Public);
        assert_eq!(entry.resolved.target_port, 8080);
        assert!(!entry.resolved.use_tls);
        assert!(entry.resolved.backends.is_empty());
    }

    // Key format, plus isolation of same-named services across deployments.
    #[test]
    fn test_endpoint_lb_key_format() {
        assert_eq!(
            endpoint_lb_key(Some("prod"), "api", "http"),
            "prod/api#http"
        );
        assert_eq!(
            endpoint_lb_key(Some("staging"), "api", "http"),
            "staging/api#http"
        );
        assert_ne!(
            endpoint_lb_key(Some("a"), "zregistry", "port-8080"),
            endpoint_lb_key(Some("b"), "zregistry", "port-8080")
        );
        assert_eq!(endpoint_lb_key(None, "postgres", "read"), "_/postgres#read");
    }

    // Per-endpoint backend updates must not leak into sibling endpoints of the same service.
    #[tokio::test]
    async fn test_update_backends_for_endpoint_isolates_endpoints() {
        let reg = ServiceRegistry::new();

        let mut http_entry = make_entry("postgres", None, "/write", vec![]);
        http_entry.endpoint_name = "write".to_string();
        let mut read_entry = make_entry("postgres", None, "/read", vec![]);
        read_entry.endpoint_name = "read".to_string();

        reg.register(http_entry).await;
        reg.register(read_entry).await;

        let primary: SocketAddr = "10.0.0.1:5432".parse().unwrap();
        let replica: SocketAddr = "10.0.0.2:5432".parse().unwrap();

        reg.update_backends_for_endpoint("postgres", "write", vec![primary])
            .await;
        reg.update_backends_for_endpoint("postgres", "read", vec![replica])
            .await;

        let write_resolved = reg.resolve(None, "/write").await.unwrap();
        assert_eq!(write_resolved.backends, vec![primary]);

        let read_resolved = reg.resolve(None, "/read").await.unwrap();
        assert_eq!(read_resolved.backends, vec![replica]);
    }
}