1#[cfg(test)]
35mod tests;
36
37use std::io::{Read, Write};
38use std::net::TcpStream;
39use std::sync::{Arc, RwLock};
40use std::time::Duration;
41
42use crate::application::Application;
43use crate::core::New;
44use crate::mime_type::MimeType;
45use crate::range::Range;
46use crate::request::Request;
47use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
48use crate::server::ConnectionInfo;
49
/// A single routable backend extracted from a Kubernetes `Ingress` object.
///
/// A request whose `Host` header and URI satisfy [`IngressRule::matches`] is
/// forwarded to `service_name` in `namespace` on `service_port`.
#[derive(Debug, Clone, PartialEq)]
pub struct IngressRule {
    /// Host from `spec.rules[].host`; an empty string matches every host.
    pub host: String,
    /// Path prefix from `spec.rules[].http.paths[].path`.
    pub path: String,
    /// Name of the backend Service.
    pub service_name: String,
    /// Numeric port of the backend Service.
    pub service_port: u16,
    /// Namespace of the Ingress; the backend Service is resolved inside it.
    pub namespace: String,
}

impl IngressRule {
    /// Returns the backend address in in-cluster DNS form:
    /// `"<service_name>.<namespace>.svc.cluster.local:<service_port>"`.
    pub fn upstream_addr(&self) -> String {
        format!(
            "{}.{}.svc.cluster.local:{}",
            self.service_name, self.namespace, self.service_port
        )
    }

    /// Returns `true` if a request for `host` and `uri` is covered by this rule.
    ///
    /// * **Host** — an empty rule host matches any request. Otherwise the rule host is
    ///   compared ASCII-case-insensitively against `host` with any trailing `:port`
    ///   removed, because an HTTP `Host` header may carry a port while an Ingress host
    ///   never does.
    /// * **Path** — Kubernetes `Prefix` semantics. The query string and fragment are
    ///   ignored, and the prefix must end on a `/` element boundary. So `/api` (or
    ///   `/api/`) matches `/api`, `/api/v1` and `/api?x=1`, but not `/apiv2`. A rule
    ///   path of `/` or `""` matches every URI.
    pub fn matches(&self, host: &str, uri: &str) -> bool {
        if !self.host.is_empty() && !self.host.eq_ignore_ascii_case(Self::strip_port(host)) {
            return false;
        }

        let request_path = uri
            .split(|c: char| c == '?' || c == '#')
            .next()
            .unwrap_or(uri);
        // Trailing slashes on the rule path are not significant for prefix matching.
        let prefix = self.path.trim_end_matches('/');
        if prefix.is_empty() {
            return true;
        }
        match request_path.strip_prefix(prefix) {
            // Only accept the match if it ends exactly at a path-element boundary.
            Some(rest) => rest.is_empty() || rest.starts_with('/'),
            None => false,
        }
    }

    /// Removes an optional `:port` suffix from an HTTP `Host` header value.
    ///
    /// Bracketed IPv6 literals (`[::1]:8080`) keep their brackets. A trailing
    /// segment that is not all digits is left untouched.
    fn strip_port(host: &str) -> &str {
        if host.starts_with('[') {
            return match host.find(']') {
                Some(close) => &host[..=close],
                None => host,
            };
        }
        match host.rfind(':') {
            Some(colon) if host[colon + 1..].bytes().all(|b| b.is_ascii_digit()) => {
                &host[..colon]
            }
            _ => host,
        }
    }
}
91
/// Returns the value of the first `"field":` string key found in `json`.
///
/// This is a lightweight text scan, not a JSON parser. It locates the literal
/// `"field":`, skips JSON whitespace, and expects a string literal. The returned
/// slice is the raw text between the quotes: escape sequences such as `\"` are
/// stepped over when looking for the closing quote, but they are not decoded.
///
/// Returns `None` if the key is absent, its value is not a string, or the string
/// is unterminated.
fn extract_str_field<'a>(json: &'a str, field: &str) -> Option<&'a str> {
    let needle = format!("\"{}\":", field);
    let start = json.find(needle.as_str())?;
    let value = json[start + needle.len()..]
        .trim_start_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r'));
    let inner = value.strip_prefix('"')?;

    // Find the closing quote, skipping the byte after each backslash so that an
    // escaped `\"` does not end the string early. The quote is ASCII, so its byte
    // index is always a valid char boundary for slicing.
    let bytes = inner.as_bytes();
    let mut i = 0;
    loop {
        match *bytes.get(i)? {
            b'"' => return Some(&inner[..i]),
            b'\\' => i += 2,
            _ => i += 1,
        }
    }
}
107
/// Returns the unsigned integer value of the first `"field":` key in `json` as a `u16`.
///
/// JSON whitespace after the colon is skipped, and the run of ASCII digits that
/// follows is parsed. The number may be the last thing in the input.
///
/// Returns `None` in any of these cases:
/// * the key is absent;
/// * no digits follow the key (for example a string or a negative value);
/// * the value exceeds `u16::MAX`.
fn extract_u16_field(json: &str, field: &str) -> Option<u16> {
    let needle = format!("\"{}\":", field);
    let start = json.find(needle.as_str())?;
    let value = json[start + needle.len()..]
        .trim_start_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r'));
    // A number at the very end of the input has no terminating character.
    let end = value
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(value.len());
    value[..end].parse().ok()
}
117
118pub fn parse_ingress_list(json: &str) -> Vec<IngressRule> {
126 let mut rules = Vec::new();
127
128 let spec_sections: Vec<&str> = json.split("\"spec\"").collect();
131 for section in spec_sections.iter().skip(1) {
132 let namespace = extract_str_field(section, "namespace")
136 .unwrap_or("default")
137 .to_string();
138
139 let rules_sections: Vec<&str> = section.split("\"rules\"").collect();
141 for rules_section in rules_sections.iter().skip(1) {
142 let host = extract_str_field(rules_section, "host").unwrap_or("").to_string();
144
145 let paths_sections: Vec<&str> = rules_section.split("\"paths\"").collect();
147 for paths_section in paths_sections.iter().skip(1) {
148 let path_entries: Vec<&str> = paths_section.split("\"path\"").collect();
152 for path_entry in path_entries.iter().skip(1) {
153 let path = extract_str_field(path_entry, "path")
154 .or_else(|| {
155 let after_colon = path_entry.trim_start_matches(':').trim_start_matches(' ');
157 if after_colon.starts_with('"') {
158 let inner = &after_colon[1..];
159 inner.find('"').map(|end| &inner[..end])
160 } else {
161 None
162 }
163 })
164 .unwrap_or("/")
165 .to_string();
166
167 let service_name =
168 extract_str_field(path_entry, "name").unwrap_or("").to_string();
169 let service_port =
170 extract_u16_field(path_entry, "number").unwrap_or(80);
171
172 if !service_name.is_empty() {
173 rules.push(IngressRule {
174 host: host.clone(),
175 path,
176 service_name,
177 service_port,
178 namespace: namespace.clone(),
179 });
180 }
181 }
182 }
183 }
184 }
185
186 rules
187}
188
189pub struct KubernetesIngressWatcher {
194 api_server: String,
195 token: String,
196 namespace: String,
197 poll_interval_secs: u64,
198 rules: Arc<RwLock<Vec<IngressRule>>>,
199}
200
201impl KubernetesIngressWatcher {
202 pub fn new(api_server: impl Into<String>, token: impl Into<String>) -> Self {
207 Self {
208 api_server: api_server.into(),
209 token: token.into(),
210 namespace: "default".to_string(),
211 poll_interval_secs: 30,
212 rules: Arc::new(RwLock::new(Vec::new())),
213 }
214 }
215
216 pub fn from_service_account() -> Result<Self, String> {
227 Err(
228 "In-cluster TLS (https://kubernetes.default.svc) is not yet supported. \
229 Use `kubectl proxy` and set RWS_K8S_API_SERVER=http://localhost:8001 \
230 along with RWS_K8S_TOKEN and RWS_K8S_NAMESPACE, then call \
231 KubernetesIngressWatcher::from_env()."
232 .to_string(),
233 )
234 }
235
236 pub fn from_env() -> Result<Self, String> {
241 let api_server = std::env::var("RWS_K8S_API_SERVER").map_err(|_| {
242 "RWS_K8S_API_SERVER environment variable is not set".to_string()
243 })?;
244 let token = std::env::var("RWS_K8S_TOKEN").unwrap_or_default();
245 let namespace = std::env::var("RWS_K8S_NAMESPACE").unwrap_or_else(|_| "default".to_string());
246 let mut watcher = Self::new(api_server, token);
247 watcher.namespace = namespace;
248 Ok(watcher)
249 }
250
251 pub fn namespace(mut self, ns: impl Into<String>) -> Self {
253 self.namespace = ns.into();
254 self
255 }
256
257 pub fn poll_interval_secs(mut self, secs: u64) -> Self {
259 self.poll_interval_secs = secs;
260 self
261 }
262
263 pub fn start(&self) {
266 self.clone_inner().poll_loop();
267 }
268
269 fn clone_inner(&self) -> WatcherHandle {
270 WatcherHandle {
271 api_server: self.api_server.clone(),
272 token: self.token.clone(),
273 namespace: self.namespace.clone(),
274 poll_interval_secs: self.poll_interval_secs,
275 rules: Arc::clone(&self.rules),
276 }
277 }
278
279 pub fn rules(&self) -> Vec<IngressRule> {
281 self.rules.read().unwrap().clone()
282 }
283
284 pub fn poll(&self) -> Result<(), String> {
288 let new_rules = self.do_poll()?;
289 *self.rules.write().unwrap() = new_rules;
290 Ok(())
291 }
292
293 fn do_poll(&self) -> Result<Vec<IngressRule>, String> {
294 let path = if self.namespace.is_empty() || self.namespace == "all" {
295 "/apis/networking.k8s.io/v1/ingresses".to_string()
296 } else {
297 format!(
298 "/apis/networking.k8s.io/v1/namespaces/{}/ingresses",
299 self.namespace
300 )
301 };
302
303 let body = http_get_plain(&self.api_server, &path, &self.token)?;
304 Ok(parse_ingress_list(&body))
305 }
306}
307
308struct WatcherHandle {
311 api_server: String,
312 token: String,
313 namespace: String,
314 poll_interval_secs: u64,
315 rules: Arc<RwLock<Vec<IngressRule>>>,
316}
317
318impl WatcherHandle {
319 fn poll_loop(self) {
320 self.poll_once();
322 let interval = Duration::from_secs(self.poll_interval_secs);
323 std::thread::spawn(move || loop {
324 std::thread::sleep(interval);
325 self.poll_once();
326 });
327 }
328
329 fn poll_once(&self) {
330 let path = if self.namespace.is_empty() || self.namespace == "all" {
331 "/apis/networking.k8s.io/v1/ingresses".to_string()
332 } else {
333 format!(
334 "/apis/networking.k8s.io/v1/namespaces/{}/ingresses",
335 self.namespace
336 )
337 };
338 match http_get_plain(&self.api_server, &path, &self.token) {
339 Ok(body) => {
340 let new_rules = parse_ingress_list(&body);
341 *self.rules.write().unwrap() = new_rules;
342 }
343 Err(e) => {
344 eprintln!("ingress watcher: poll failed: {}", e);
345 }
346 }
347 }
348}
349
/// Performs a blocking HTTP/1.1 `GET` of `path` on a plain-HTTP API endpoint and
/// returns the response body.
///
/// `api_server` must look like `http://host[:port]`. The port defaults to 80, and
/// anything after the first `/` that follows the authority is ignored. A non-empty
/// `token` is sent as `Authorization: Bearer <token>`.
///
/// HTTP/1.1 servers may send a `Transfer-Encoding: chunked` body when its length is
/// not known in advance, so chunked bodies are decoded. A body with a
/// `Content-Length` header is cut to that length.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the URL is not `http://`;
/// * connecting (5 s timeout), writing (5 s) or reading (10 s) fails;
/// * the response is malformed or truncated;
/// * the status is not 2xx;
/// * the body is not UTF-8.
fn http_get_plain(api_server: &str, path: &str, token: &str) -> Result<String, String> {
    let rest = api_server.strip_prefix("http://").ok_or_else(|| {
        format!(
            "ingress watcher: api_server must start with http://, got: {}",
            api_server
        )
    })?;
    // Only the authority (`host[:port]`) is used. It is also sent verbatim as the
    // Host header, which must carry a non-default port (RFC 9110 §7.2).
    let authority = rest.split('/').next().unwrap_or(rest);
    let (host, port) = split_host_port(authority);

    let addr = format!("{}:{}", host, port);
    let mut stream = connect_with_timeout(&addr, Duration::from_secs(5))?;
    stream
        .set_read_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;
    stream
        .set_write_timeout(Some(Duration::from_secs(5)))
        .map_err(|e| e.to_string())?;

    let auth_header = if token.is_empty() {
        String::new()
    } else {
        format!("Authorization: Bearer {}\r\n", token)
    };
    let request = format!(
        "GET {} HTTP/1.1\r\nHost: {}\r\n{}Accept: application/json\r\nConnection: close\r\n\r\n",
        path, authority, auth_header
    );
    stream.write_all(request.as_bytes()).map_err(|e| e.to_string())?;

    // `Connection: close` makes the server close the socket after one response, so
    // reading to EOF captures the whole message. `read_to_end` also retries on
    // `ErrorKind::Interrupted` instead of failing the poll.
    let mut raw = Vec::with_capacity(8192);
    stream
        .read_to_end(&mut raw)
        .map_err(|e| format!("ingress watcher: read failed: {}", e))?;

    let body = parse_http_response(&raw)?;
    String::from_utf8(body)
        .map_err(|e| format!("ingress watcher: non-UTF-8 response body: {}", e))
}

/// Splits `host[:port]` into host and port, defaulting to port 80 when there is no
/// numeric `:port` suffix.
fn split_host_port(authority: &str) -> (&str, u16) {
    match authority.rfind(':') {
        Some(colon) => match authority[colon + 1..].parse::<u16>() {
            Ok(port) => (&authority[..colon], port),
            Err(_) => (authority, 80),
        },
        None => (authority, 80),
    }
}

/// Resolves `addr` and connects to the first address that accepts within `timeout`.
fn connect_with_timeout(addr: &str, timeout: Duration) -> Result<TcpStream, String> {
    let candidates = std::net::ToSocketAddrs::to_socket_addrs(addr)
        .map_err(|e| format!("ingress watcher: connect to {} failed: {}", addr, e))?;
    let mut last_err = None;
    for candidate in candidates {
        match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(e) => last_err = Some(e),
        }
    }
    Err(match last_err {
        Some(e) => format!("ingress watcher: connect to {} failed: {}", addr, e),
        None => format!(
            "ingress watcher: connect to {} failed: no addresses resolved",
            addr
        ),
    })
}

/// Validates a raw HTTP/1.x response (2xx status required) and returns its decoded body.
fn parse_http_response(raw: &[u8]) -> Result<Vec<u8>, String> {
    let header_end = raw
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .ok_or_else(|| "ingress watcher: incomplete HTTP response (no header end)".to_string())?;

    let head = std::str::from_utf8(&raw[..header_end]).unwrap_or("");
    let mut lines = head.lines();
    let status_line = lines.next().unwrap_or("");
    let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
    if parts.len() < 2 {
        return Err(format!(
            "ingress watcher: malformed status line: {}",
            status_line
        ));
    }
    let status: u16 = parts[1].parse().unwrap_or(0);
    if !(200..300).contains(&status) {
        return Err(format!("ingress watcher: API returned status {}", status));
    }

    let mut chunked = false;
    let mut content_length = None;
    for line in lines {
        if let Some((name, value)) = line.split_once(':') {
            let (name, value) = (name.trim(), value.trim());
            if name.eq_ignore_ascii_case("transfer-encoding") {
                chunked = value.to_ascii_lowercase().contains("chunked");
            } else if name.eq_ignore_ascii_case("content-length") {
                content_length = value.parse::<usize>().ok();
            }
        }
    }

    let body = &raw[header_end + 4..];
    // Transfer-Encoding takes precedence over Content-Length (RFC 9112 §6.3).
    if chunked {
        decode_chunked(body)
    } else if let Some(len) = content_length {
        body.get(..len).map(|b| b.to_vec()).ok_or_else(|| {
            format!(
                "ingress watcher: truncated response body ({} of {} bytes)",
                body.len(),
                len
            )
        })
    } else {
        Ok(body.to_vec())
    }
}

/// Decodes an HTTP/1.1 `chunked` message body. Chunk extensions and trailer
/// fields are ignored.
fn decode_chunked(mut data: &[u8]) -> Result<Vec<u8>, String> {
    let mut out = Vec::with_capacity(data.len());
    loop {
        let line_end = data
            .windows(2)
            .position(|w| w == b"\r\n")
            .ok_or_else(|| "ingress watcher: truncated chunked body (missing chunk size)".to_string())?;
        let size_line = std::str::from_utf8(&data[..line_end]).unwrap_or("");
        let size_hex = size_line.split(';').next().unwrap_or("").trim();
        let size = usize::from_str_radix(size_hex, 16)
            .map_err(|_| format!("ingress watcher: invalid chunk size: {:?}", size_line))?;
        data = &data[line_end + 2..];

        if size == 0 {
            // Last chunk; any trailer section that follows is ignored.
            return Ok(out);
        }
        // Each chunk's data is followed by CRLF; `checked_add` guards against a
        // hostile size near `usize::MAX`.
        let chunk_end = size
            .checked_add(2)
            .filter(|&end| end <= data.len())
            .ok_or_else(|| "ingress watcher: truncated chunked body".to_string())?;
        if !data[size..chunk_end].starts_with(b"\r\n") {
            return Err("ingress watcher: malformed chunk (missing CRLF)".to_string());
        }
        out.extend_from_slice(&data[..size]);
        data = &data[chunk_end..];
    }
}
422
423pub struct IngressRouter {
431 watcher: KubernetesIngressWatcher,
432 connect_timeout: Duration,
433 read_timeout: Duration,
434}
435
436impl IngressRouter {
437 pub fn new(watcher: KubernetesIngressWatcher) -> Self {
439 Self {
440 watcher,
441 connect_timeout: Duration::from_secs(5),
442 read_timeout: Duration::from_secs(30),
443 }
444 }
445
446 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
448 self.connect_timeout = Duration::from_millis(ms);
449 self
450 }
451
452 pub fn read_timeout_ms(mut self, ms: u64) -> Self {
454 self.read_timeout = Duration::from_millis(ms);
455 self
456 }
457}
458
459impl Application for IngressRouter {
460 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
461 let host = request
462 .get_header("host".to_string())
463 .map(|h| h.value.as_str())
464 .unwrap_or("");
465
466 let rules = self.watcher.rules();
467 let matched = rules.iter().find(|r| r.matches(host, &request.request_uri));
468
469 match matched {
470 Some(rule) => {
471 let upstream_host = format!(
472 "{}.{}.svc.cluster.local",
473 rule.service_name, rule.namespace
474 );
475 crate::proxy::proxy_http1(
476 request,
477 &connection.client.ip,
478 &upstream_host,
479 rule.service_port,
480 self.connect_timeout,
481 self.read_timeout,
482 )
483 .or_else(|_| Ok(bad_gateway()))
484 }
485 None => Ok(not_found()),
486 }
487 }
488}
489
490fn bad_gateway() -> Response {
491 let cr = Range::get_content_range(
492 b"502 Bad Gateway".to_vec(),
493 MimeType::TEXT_PLAIN.to_string(),
494 );
495 let mut r = Response::new();
496 r.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
497 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n502_bad_gateway.reason_phrase.to_string();
498 r.content_range_list = vec![cr];
499 r
500}
501
502fn not_found() -> Response {
503 let cr = Range::get_content_range(
504 b"404 No matching ingress rule".to_vec(),
505 MimeType::TEXT_PLAIN.to_string(),
506 );
507 let mut r = Response::new();
508 r.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
509 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n404_not_found.reason_phrase.to_string();
510 r.content_range_list = vec![cr];
511 r
512}