1use std::net::SocketAddr;
9use tokio::sync::RwLock;
10use zlayer_spec::{EndpointSpec, ExposeType, Protocol};
11
12#[derive(Clone, Debug)]
18pub struct ResolvedService {
19 pub name: String,
21 pub backends: Vec<SocketAddr>,
23 pub use_tls: bool,
25 pub sni_hostname: String,
27 pub expose: ExposeType,
29 pub protocol: Protocol,
31 pub strip_prefix: bool,
33 pub path_prefix: String,
35 pub target_port: u16,
37}
38
39#[derive(Debug, Clone)]
45pub struct RouteEntry {
46 pub service_name: String,
48 pub endpoint_name: String,
50 pub host: Option<String>,
53 pub path_prefix: String,
55 pub resolved: ResolvedService,
57}
58
59impl RouteEntry {
60 #[must_use]
65 pub fn from_endpoint(service_name: &str, endpoint: &EndpointSpec) -> Self {
66 let path_prefix = endpoint.path.clone().unwrap_or_else(|| "/".to_string());
67 let target_port = endpoint.target_port();
68
69 Self {
70 service_name: service_name.to_string(),
71 endpoint_name: endpoint.name.clone(),
72 host: endpoint.host.clone(),
73 path_prefix: path_prefix.clone(),
74 resolved: ResolvedService {
75 name: service_name.to_string(),
76 backends: Vec::new(),
77 use_tls: endpoint.protocol == Protocol::Https,
78 sni_hostname: String::new(),
79 expose: endpoint.expose,
80 protocol: endpoint.protocol,
81 strip_prefix: false,
82 path_prefix,
83 target_port,
84 },
85 }
86 }
87
88 #[must_use]
90 pub fn matches(&self, host: Option<&str>, path: &str) -> bool {
91 if let Some(ref pattern) = self.host {
94 match host {
95 Some(h) => {
96 if !host_matches(pattern, h) {
97 return false;
98 }
99 }
100 None => return false,
101 }
102 }
103
104 path_matches(&self.path_prefix, path)
105 }
106}
107
108pub struct ServiceRegistry {
118 routes: RwLock<Vec<RouteEntry>>,
120}
121
122impl Default for ServiceRegistry {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl ServiceRegistry {
129 #[must_use]
131 pub fn new() -> Self {
132 Self {
133 routes: RwLock::new(Vec::new()),
134 }
135 }
136
137 pub async fn register(&self, entry: RouteEntry) {
139 let mut routes = self.routes.write().await;
140
141 let insert_idx = routes
142 .iter()
143 .position(|r| r.path_prefix.len() < entry.path_prefix.len())
144 .unwrap_or(routes.len());
145
146 routes.insert(insert_idx, entry);
147 }
148
149 pub async fn unregister_service(&self, service_name: &str) {
151 let mut routes = self.routes.write().await;
152 routes.retain(|r| r.service_name != service_name);
153 }
154
155 pub async fn resolve(&self, host: Option<&str>, path: &str) -> Option<ResolvedService> {
159 let routes = self.routes.read().await;
160
161 for entry in routes.iter() {
163 if entry.matches(host, path) {
164 return Some(entry.resolved.clone());
165 }
166 }
167
168 None
169 }
170
171 pub async fn update_backends(&self, service_name: &str, backends: Vec<SocketAddr>) {
173 let mut routes = self.routes.write().await;
174 for entry in routes.iter_mut() {
175 if entry.service_name == service_name {
176 entry.resolved.backends.clone_from(&backends);
177 }
178 }
179 }
180
181 pub async fn add_backend(&self, service_name: &str, addr: SocketAddr) {
183 let mut routes = self.routes.write().await;
184 for entry in routes.iter_mut() {
185 if entry.service_name == service_name && !entry.resolved.backends.contains(&addr) {
186 entry.resolved.backends.push(addr);
187 }
188 }
189 }
190
191 pub async fn remove_backend(&self, service_name: &str, addr: SocketAddr) {
193 let mut routes = self.routes.write().await;
194 for entry in routes.iter_mut() {
195 if entry.service_name == service_name {
196 entry.resolved.backends.retain(|a| *a != addr);
197 }
198 }
199 }
200
201 pub async fn list_services(&self) -> Vec<String> {
203 let routes = self.routes.read().await;
204 let mut seen = Vec::new();
205 for entry in routes.iter() {
206 if !seen.contains(&entry.service_name) {
207 seen.push(entry.service_name.clone());
208 }
209 }
210 seen
211 }
212
213 pub async fn route_count(&self) -> usize {
215 self.routes.read().await.len()
216 }
217
218 pub async fn list_routes(&self) -> Vec<RouteEntry> {
220 self.routes.read().await.clone()
221 }
222}
223
/// Rewrites `path` for forwarding to a backend, optionally stripping `prefix`.
///
/// When `strip` is `false`, or `prefix` is `/`, the path is returned unchanged.
/// Otherwise the prefix (ignoring trailing slashes) is removed only when it
/// ends on a `/` segment boundary, the same rule `path_matches` uses:
/// `/api/v1` turns `/api/v1/users` into `/users` and `/api/v1` into `/`, but
/// leaves `/api/v1users` alone instead of producing the relative path `users`.
/// Paths that do not start with the prefix are returned unchanged.
#[must_use]
pub fn transform_path(prefix: &str, path: &str, strip: bool) -> String {
    if !strip || prefix == "/" {
        return path.to_string();
    }

    let normalized_prefix = prefix.trim_end_matches('/');
    match path.strip_prefix(normalized_prefix) {
        // The request targeted the prefix itself: the backend sees its root.
        Some("") => "/".to_string(),
        // Stripped on a segment boundary: the remainder is still absolute.
        Some(rest) if rest.starts_with('/') => rest.to_string(),
        // No prefix, or the prefix ends mid-segment: forward untouched.
        _ => path.to_string(),
    }
}
249
/// Returns `true` when the request `host` satisfies the route host `pattern`.
///
/// Host names are compared ASCII case-insensitively, as HTTP/DNS host names
/// are case-insensitive. A pattern of the form `*.suffix` matches any host
/// ending in `.suffix` that has at least one character before that dot
/// (multi-label hosts such as `a.b.suffix` included); the bare `suffix` does
/// not match. Any other pattern must equal the host.
fn host_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix('*') {
        // `suffix` keeps its leading dot, so `example.com` never matches `*.example.com`.
        Some(suffix) if suffix.starts_with('.') => {
            // Compare bytes, not `str` slices, so a multi-byte char in `host`
            // can never cause a char-boundary panic.
            host.len() > suffix.len()
                && host.as_bytes()[host.len() - suffix.len()..]
                    .eq_ignore_ascii_case(suffix.as_bytes())
        }
        _ => pattern.eq_ignore_ascii_case(host),
    }
}
262
/// Returns `true` when `path` falls under `prefix` on a `/` segment boundary.
///
/// `/` matches everything. Trailing slashes on both sides are ignored, so the
/// prefix `/api/v1` matches `/api/v1`, `/api/v1/` and `/api/v1/users`, but not
/// `/api/v1users` or `/api`.
fn path_matches(prefix: &str, path: &str) -> bool {
    if prefix == "/" {
        return true;
    }

    let wanted = prefix.trim_end_matches('/');
    let candidate = path.trim_end_matches('/');

    match candidate.strip_prefix(wanted) {
        // Either the prefix consumed the whole path, or what follows starts a
        // new segment.
        Some(rest) => rest.is_empty() || rest.starts_with('/'),
        None => false,
    }
}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ResolvedService` with fixed defaults (plain HTTP, internal
    /// exposure, no TLS, target port 8080) and the given backends.
    fn make_resolved(name: &str, backends: Vec<SocketAddr>) -> ResolvedService {
        ResolvedService {
            name: name.to_string(),
            backends,
            use_tls: false,
            sni_hostname: String::new(),
            expose: ExposeType::Internal,
            protocol: Protocol::Http,
            strip_prefix: false,
            path_prefix: "/".to_string(),
            target_port: 8080,
        }
    }

    /// Builds a `RouteEntry` for `service` with endpoint name `"http"`, the
    /// optional host pattern and the path prefix (mirrored into
    /// `resolved.path_prefix`).
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        let mut resolved = make_resolved(service, backends);
        resolved.path_prefix = path.to_string();
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved,
        }
    }

    /// Path prefixes match whole segments only, with or without a trailing slash.
    #[test]
    fn test_route_path_matching() {
        let entry = make_entry("api", None, "/api/v1", vec![]);

        assert!(entry.matches(None, "/api/v1"));
        assert!(entry.matches(None, "/api/v1/"));
        assert!(entry.matches(None, "/api/v1/users"));
        assert!(entry.matches(None, "/api/v1/users/123"));
        assert!(!entry.matches(None, "/api/v2"));
        assert!(!entry.matches(None, "/api"));
        assert!(!entry.matches(None, "/"));
    }

    /// An exact host pattern rejects other hosts and requests without a host.
    #[test]
    fn test_route_host_matching() {
        let entry = make_entry("api", Some("api.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/anything"));
        assert!(!entry.matches(Some("other.example.com"), "/anything"));
        assert!(!entry.matches(None, "/anything"));
    }

    /// `*.example.com` matches subdomains but not the apex or other domains.
    #[test]
    fn test_route_wildcard_host() {
        let entry = make_entry("api", Some("*.example.com"), "/", vec![]);

        assert!(entry.matches(Some("api.example.com"), "/"));
        assert!(entry.matches(Some("www.example.com"), "/"));
        assert!(entry.matches(Some("foo.example.com"), "/"));
        assert!(!entry.matches(Some("example.com"), "/"));
        assert!(!entry.matches(Some("other.domain.com"), "/"));
    }

    /// `transform_path` strips the prefix only when asked, mapping an exact
    /// prefix match to `/` and leaving non-matching paths alone.
    #[test]
    fn test_route_strip_prefix() {
        assert_eq!(transform_path("/api/v1", "/api/v1/users", true), "/users");
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users/123", true),
            "/users/123"
        );
        assert_eq!(transform_path("/api/v1", "/api/v1", true), "/");
        assert_eq!(transform_path("/api/v1", "/other", true), "/other");
        // With `strip == false` the path is forwarded unchanged.
        assert_eq!(
            transform_path("/api/v1", "/api/v1/users", false),
            "/api/v1/users"
        );
    }

    /// The longest matching prefix wins regardless of registration order.
    #[tokio::test]
    async fn test_router_longest_prefix_match() {
        let reg = ServiceRegistry::new();

        // Registered shortest-first on purpose: `register` must reorder them.
        reg.register(make_entry("root", None, "/", vec![])).await;
        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api-v1", None, "/api/v1", vec![]))
            .await;

        let m = reg.resolve(None, "/api/v1/users").await.unwrap();
        assert_eq!(m.name, "api-v1");

        let m = reg.resolve(None, "/api/v2/users").await.unwrap();
        assert_eq!(m.name, "api");

        let m = reg.resolve(None, "/other").await.unwrap();
        assert_eq!(m.name, "root");
    }

    /// A host-restricted route does not answer for a different host.
    #[tokio::test]
    async fn test_router_no_match() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", Some("api.example.com"), "/", vec![]))
            .await;

        let result = reg.resolve(Some("other.example.com"), "/").await;
        assert!(result.is_none());
    }

    /// Resolving by host returns the registered service with its backends.
    #[tokio::test]
    async fn test_register_and_resolve_host() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/anything")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api");
        assert_eq!(resolved.backends.len(), 1);
    }

    /// Two services on the same host are told apart by path prefix.
    #[tokio::test]
    async fn test_register_and_resolve_path() {
        let reg = ServiceRegistry::new();

        let addr1: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8082".parse().unwrap();
        reg.register(make_entry(
            "api-v1",
            Some("api.example.com"),
            "/api/v1",
            vec![addr1],
        ))
        .await;
        reg.register(make_entry(
            "api-v2",
            Some("api.example.com"),
            "/api/v2",
            vec![addr2],
        ))
        .await;

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v1/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v1");

        let resolved = reg
            .resolve(Some("api.example.com"), "/api/v2/users")
            .await
            .unwrap();
        assert_eq!(resolved.name, "api-v2");
    }

    /// An empty registry resolves nothing.
    #[tokio::test]
    async fn test_resolve_not_found() {
        let reg = ServiceRegistry::new();
        let result = reg.resolve(Some("unknown.example.com"), "/").await;
        assert!(result.is_none());
    }

    /// `update_backends` replaces (not appends to) the backend list.
    #[tokio::test]
    async fn test_update_backends() {
        let reg = ServiceRegistry::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        reg.register(make_entry("api", Some("api.example.com"), "/", vec![addr]))
            .await;

        let new_backends: Vec<SocketAddr> = vec![
            "127.0.0.1:8081".parse().unwrap(),
            "127.0.0.1:8082".parse().unwrap(),
        ];
        reg.update_backends("api", new_backends).await;

        let resolved = reg.resolve(Some("api.example.com"), "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
    }

    /// Unregistering a service removes its routes; requests then fall through
    /// to the remaining `/` catch-all.
    #[tokio::test]
    async fn test_unregister_service() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        assert_eq!(reg.route_count().await, 2);
        reg.unregister_service("api").await;
        assert_eq!(reg.route_count().await, 1);

        let result = reg.resolve(None, "/api/foo").await;
        // Only the `/` route is left, so it catches `/api/foo`.
        assert_eq!(result.unwrap().name, "web");
    }

    /// A service with several routes is listed once.
    #[tokio::test]
    async fn test_list_services() {
        let reg = ServiceRegistry::new();

        reg.register(make_entry("api", None, "/api", vec![])).await;
        reg.register(make_entry("api", None, "/api/v2", vec![]))
            .await;
        reg.register(make_entry("web", None, "/", vec![])).await;

        // Sorted because `list_services` returns names in route-table order.
        let mut services = reg.list_services().await;
        services.sort();
        assert_eq!(services, vec!["api", "web"]);
    }

    /// `route_count` counts each registered route.
    #[tokio::test]
    async fn test_route_count() {
        let reg = ServiceRegistry::new();
        assert_eq!(reg.route_count().await, 0);

        reg.register(make_entry("a", None, "/a", vec![])).await;
        reg.register(make_entry("b", None, "/b", vec![])).await;
        reg.register(make_entry("c", None, "/c", vec![])).await;
        assert_eq!(reg.route_count().await, 3);
    }

    /// Backends can be added (without duplicates) and removed individually.
    #[tokio::test]
    async fn test_add_remove_backend() {
        let reg = ServiceRegistry::new();

        let b1: SocketAddr = "127.0.0.1:8001".parse().unwrap();
        reg.register(make_entry("api", None, "/", vec![b1])).await;

        let b2: SocketAddr = "127.0.0.1:8002".parse().unwrap();
        reg.add_backend("api", b2).await;

        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
        assert!(resolved.backends.contains(&b1));
        assert!(resolved.backends.contains(&b2));

        // Adding an address that is already present is a no-op.
        reg.add_backend("api", b2).await;
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);

        // Removing `b1` leaves only `b2`.
        reg.remove_backend("api", b1).await;
        let resolved = reg.resolve(None, "/").await.unwrap();
        assert_eq!(resolved.backends.len(), 1);
        assert_eq!(resolved.backends[0], b2);
    }

    /// `from_endpoint` copies endpoint fields and starts with no backends and
    /// no TLS for a plain HTTP endpoint.
    #[tokio::test]
    async fn test_from_endpoint() {
        let endpoint = EndpointSpec {
            name: "http".to_string(),
            protocol: Protocol::Http,
            port: 80,
            target_port: Some(8080),
            path: Some("/api".to_string()),
            host: None,
            expose: ExposeType::Public,
            stream: None,
            tunnel: None,
        };

        let entry = RouteEntry::from_endpoint("my-service", &endpoint);
        assert_eq!(entry.service_name, "my-service");
        assert_eq!(entry.endpoint_name, "http");
        assert!(entry.host.is_none());
        assert_eq!(entry.path_prefix, "/api");
        assert_eq!(entry.resolved.name, "my-service");
        assert_eq!(entry.resolved.protocol, Protocol::Http);
        assert_eq!(entry.resolved.expose, ExposeType::Public);
        assert_eq!(entry.resolved.target_port, 8080);
        assert!(!entry.resolved.use_tls);
        assert!(entry.resolved.backends.is_empty());
    }
}