1use std::collections::{BTreeMap, HashMap};
7use std::convert::Infallible;
8use std::future::Future;
9use std::panic::{catch_unwind, AssertUnwindSafe};
10use std::path::PathBuf;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15
16use axum::body::{Body, Bytes};
17use axum::extract::multipart::MultipartRejection;
18use axum::extract::ws::{CloseFrame, Message, WebSocket as AxumWebSocket, WebSocketUpgrade};
19use axum::extract::DefaultBodyLimit;
20use axum::extract::{Multipart, Path, Query, State};
21use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
22use axum::response::{IntoResponse, Response};
23use axum::routing::{get, post};
24use axum::{Json, Router};
25use futures::FutureExt;
26use serde_json::{json, Map, Value};
27use tokio::sync::mpsc;
28use zynk_runtime::{
29 inventory, EndpointKind, EndpointMeta, Handler, HandlerKey, JsonErrorEnvelope,
30 JsonResultEnvelope, ParamMeta, SseFrame, WsMessage, ZynkError, COMMAND_NOT_FOUND,
31 EXECUTION_ERROR, INTERNAL_ERROR, STATIC_HANDLER_NOT_FOUND, UPLOAD_HANDLER_NOT_FOUND,
32 UPLOAD_VALIDATION_ERROR, VALIDATION_ERROR, WEBSOCKET_ERROR,
33};
34
/// Cloneable handle to the bridge configuration.
///
/// Clones share a single [`BridgeState`] through an `Arc`; the builder
/// methods mutate it through `Arc::make_mut`.
#[derive(Clone)]
pub struct ZynkBridge {
    /// Shared bridge state handed to every route as axum `State`.
    state: Arc<BridgeState>,
}
40
/// Everything the routes need at request time: endpoint metadata, the
/// registered handlers for each endpoint kind, schema definitions and
/// runtime settings.
struct BridgeState {
    /// Title reported as `bridge` by the root route.
    title: String,
    /// Endpoint metadata keyed by endpoint name.
    endpoints: BTreeMap<String, &'static EndpointMeta>,
    /// Command handlers keyed by handler key.
    handlers: HashMap<HandlerKey, Arc<dyn Handler>>,
    /// Streaming channel handlers keyed by handler key.
    channel_handlers: HashMap<HandlerKey, Arc<dyn ChannelHandler>>,
    /// Multipart upload handlers keyed by handler key.
    upload_handlers: HashMap<HandlerKey, Arc<dyn UploadHandler>>,
    /// Static file handlers keyed by handler key.
    static_handlers: HashMap<HandlerKey, Arc<dyn StaticHandler>>,
    /// WebSocket handlers keyed by handler key.
    ws_handlers: HashMap<HandlerKey, Arc<dyn WsHandler>>,
    /// Model definitions keyed by model name, exported through the API graph.
    models: BTreeMap<String, zynk_runtime::zynk_schema::ModelDef>,
    /// Enum definitions keyed by enum name, exported through the API graph.
    enums: BTreeMap<String, zynk_runtime::zynk_schema::EnumDef>,
    /// When true, the upload and static routes return the panic message of a
    /// panicking handler instead of a generic one.
    debug: bool,
    /// Idle time after which a channel stream emits a `keepalive` frame.
    keepalive_interval: Duration,
}
54
/// Boxed future returned by [`ChannelHandler::call`].
type BoxChannelFuture = Pin<Box<dyn Future<Output = Result<(), ZynkError>> + Send>>;
/// Boxed future returned by [`UploadHandler::call`]; resolves to the JSON result.
type BoxUploadFuture = Pin<Box<dyn Future<Output = Result<Value, ZynkError>> + Send>>;
/// Boxed future returned by [`StaticHandler::call`]; resolves to the file to serve.
type BoxStaticFuture = Pin<Box<dyn Future<Output = Result<StaticFile, ZynkError>> + Send>>;
/// Boxed future returned by [`WsHandler::call`].
type BoxWsFuture = Pin<Box<dyn Future<Output = Result<(), ZynkError>> + Send>>;
59
/// Handler for a streaming channel endpoint.
pub trait ChannelHandler: Send + Sync + 'static {
    /// Starts the handler with the validated request `payload` and the
    /// [`Channel`] it uses to push events to the client.
    fn call(&self, payload: Value, channel: Channel) -> BoxChannelFuture;
}
65
66impl<F, Fut> ChannelHandler for F
67where
68 F: Fn(Value, Channel) -> Fut + Send + Sync + 'static,
69 Fut: Future<Output = Result<(), ZynkError>> + Send + 'static,
70{
71 fn call(&self, payload: Value, channel: Channel) -> BoxChannelFuture {
72 Box::pin(self(payload, channel))
73 }
74}
75
/// One file received through a multipart upload, held fully in memory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UploadFile {
    /// File name taken from the multipart field (`"upload"` when the client sent none).
    filename: String,
    /// Content type taken from the multipart field
    /// (`"application/octet-stream"` when the client sent none).
    content_type: String,
    /// Raw file contents.
    bytes: Bytes,
}
83
84impl UploadFile {
85 fn new(filename: String, content_type: String, bytes: Bytes) -> Self {
86 Self {
87 filename,
88 content_type,
89 bytes,
90 }
91 }
92
93 pub fn filename(&self) -> &str {
95 &self.filename
96 }
97
98 pub fn content_type(&self) -> &str {
100 &self.content_type
101 }
102
103 pub fn size(&self) -> usize {
105 self.bytes.len()
106 }
107
108 pub fn bytes(&self) -> &[u8] {
110 &self.bytes
111 }
112}
113
/// Description of a file on disk that a static handler wants served.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticFile {
    /// Location of the file to serve.
    path: PathBuf,
    /// Explicit content type; when `None` it is guessed from the file extension.
    content_type: Option<String>,
    /// Extra response headers, applied after the built-in ones.
    headers: HeaderMap,
}
121
122impl StaticFile {
123 pub fn new(path: impl Into<PathBuf>) -> Self {
125 Self {
126 path: path.into(),
127 content_type: None,
128 headers: HeaderMap::new(),
129 }
130 }
131
132 pub fn with_content_type(mut self, content_type: impl Into<String>) -> Self {
134 self.content_type = Some(content_type.into());
135 self
136 }
137
138 pub fn with_header(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
140 if let (Ok(name), Ok(value)) = (
141 HeaderName::from_bytes(name.as_ref().as_bytes()),
142 HeaderValue::from_str(value.as_ref()),
143 ) {
144 self.headers.insert(name, value);
145 }
146 self
147 }
148
149 pub fn path(&self) -> &std::path::Path {
151 &self.path
152 }
153
154 pub fn content_type(&self) -> String {
156 self.content_type
157 .clone()
158 .unwrap_or_else(|| guess_content_type(&self.path))
159 }
160
161 pub fn headers(&self) -> &HeaderMap {
163 &self.headers
164 }
165}
166
/// Handler for a multipart upload endpoint.
pub trait UploadHandler: Send + Sync + 'static {
    /// Processes the validated `payload` together with the uploaded `files`
    /// and resolves to the JSON result returned to the client.
    fn call(&self, payload: Value, files: Vec<UploadFile>) -> BoxUploadFuture;
}
172
173impl<F, Fut> UploadHandler for F
174where
175 F: Fn(Value, Vec<UploadFile>) -> Fut + Send + Sync + 'static,
176 Fut: Future<Output = Result<Value, ZynkError>> + Send + 'static,
177{
178 fn call(&self, payload: Value, files: Vec<UploadFile>) -> BoxUploadFuture {
179 Box::pin(self(payload, files))
180 }
181}
182
/// Handler for a static file endpoint.
pub trait StaticHandler: Send + Sync + 'static {
    /// Resolves the coerced query `payload` to the [`StaticFile`] to serve.
    fn call(&self, payload: Value) -> BoxStaticFuture;
}
188
189impl<F, Fut> StaticHandler for F
190where
191 F: Fn(Value) -> Fut + Send + Sync + 'static,
192 Fut: Future<Output = Result<StaticFile, ZynkError>> + Send + 'static,
193{
194 fn call(&self, payload: Value) -> BoxStaticFuture {
195 Box::pin(self(payload))
196 }
197}
198
/// Handler for a WebSocket endpoint; invoked once per accepted client message.
pub trait WsHandler: Send + Sync + 'static {
    /// Handles one client message. `payload` is an object with `event` and
    /// `data` keys; `socket` is used to send messages back to the client.
    fn call(&self, payload: Value, socket: WebSocket) -> BoxWsFuture;
}
204
205impl<F, Fut> WsHandler for F
206where
207 F: Fn(Value, WebSocket) -> Fut + Send + Sync + 'static,
208 Fut: Future<Output = Result<(), ZynkError>> + Send + 'static,
209{
210 fn call(&self, payload: Value, socket: WebSocket) -> BoxWsFuture {
211 Box::pin(self(payload, socket))
212 }
213}
214
/// Item queued from a [`Channel`] to the streaming response.
#[derive(Debug, Clone)]
enum ChannelEvent {
    /// A payload to deliver as a `message` frame.
    Data(Value),
    /// Marks the end of the stream; delivered as a `close` frame.
    Close,
}
220
/// Sending half handed to a channel handler for pushing events to the client.
///
/// Clones share the same id, queue and closed flag.
#[derive(Clone, Debug)]
pub struct Channel {
    /// Identifier of this channel; reported in the `close` frame.
    id: Arc<str>,
    /// Queue feeding the streaming response.
    sender: mpsc::UnboundedSender<ChannelEvent>,
    /// Set once `close` has been called on any clone.
    closed: Arc<AtomicBool>,
}
228
229impl Channel {
230 fn new(id: String, sender: mpsc::UnboundedSender<ChannelEvent>) -> Self {
231 Self {
232 id: Arc::from(id),
233 sender,
234 closed: Arc::new(AtomicBool::new(false)),
235 }
236 }
237
238 pub fn id(&self) -> &str {
240 &self.id
241 }
242
243 pub fn send(&self, data: Value) -> Result<(), ZynkError> {
245 if self.closed.load(Ordering::SeqCst) {
246 return Err(ZynkError::new(
247 EXECUTION_ERROR,
248 format!("Cannot send on closed channel {}", self.id),
249 ));
250 }
251 self.sender.send(ChannelEvent::Data(data)).map_err(|_| {
252 ZynkError::new(
253 EXECUTION_ERROR,
254 format!("Cannot send on closed channel {}", self.id),
255 )
256 })
257 }
258
259 pub fn close(&self) -> Result<(), ZynkError> {
261 if !self.closed.swap(true, Ordering::SeqCst) {
262 let _ = self.sender.send(ChannelEvent::Close);
263 }
264 Ok(())
265 }
266}
267
/// Sending half handed to a WebSocket handler for pushing messages to the client.
#[derive(Clone, Debug)]
pub struct WebSocket {
    /// Queue of outbound messages drained by the socket loop.
    sender: mpsc::UnboundedSender<WsMessage>,
}
273
274impl WebSocket {
275 fn new(sender: mpsc::UnboundedSender<WsMessage>) -> Self {
276 Self { sender }
277 }
278
279 pub async fn send(&self, event: impl Into<String>, data: Value) -> Result<(), ZynkError> {
281 self.sender
282 .send(WsMessage::new(event, data))
283 .map_err(|_| ZynkError::new(WEBSOCKET_ERROR, "Cannot send on closed WebSocket"))
284 }
285}
286
287impl ZynkBridge {
288 pub fn new() -> Self {
290 Self::from_inventory()
291 }
292
293 pub fn with_debug(debug: bool) -> Self {
295 Self::from_parts(collect_inventory_endpoints(), HashMap::new(), debug)
296 }
297
298 pub fn title(mut self, title: impl Into<String>) -> Self {
300 Arc::make_mut(&mut self.state).title = title.into();
301 self
302 }
303
304 pub fn debug(mut self, debug: bool) -> Self {
306 Arc::make_mut(&mut self.state).debug = debug;
307 self
308 }
309
310 pub fn register_endpoint_meta(mut self, endpoint: &'static EndpointMeta) -> Self {
312 Arc::make_mut(&mut self.state)
313 .endpoints
314 .insert(endpoint.name.to_string(), endpoint);
315 self
316 }
317
318 pub fn register_handler<H>(mut self, key: HandlerKey, handler: H) -> Self
324 where
325 H: Handler,
326 {
327 Arc::make_mut(&mut self.state)
328 .handlers
329 .insert(key, Arc::new(handler));
330 self
331 }
332
333 pub fn register_command<H>(mut self, name: &str, handler: H) -> Self
335 where
336 H: Handler,
337 {
338 if let Some(endpoint) = self.state.endpoints.get(name) {
339 if let Some(key) = endpoint.handler_key {
340 Arc::make_mut(&mut self.state)
341 .handlers
342 .insert(key, Arc::new(handler));
343 }
344 }
345 self
346 }
347
348 pub fn register_channel<H>(mut self, key: HandlerKey, handler: H) -> Self
350 where
351 H: ChannelHandler,
352 {
353 Arc::make_mut(&mut self.state)
354 .channel_handlers
355 .insert(key, Arc::new(handler));
356 self
357 }
358
359 pub fn register_upload<H>(mut self, key: HandlerKey, handler: H) -> Self
361 where
362 H: UploadHandler,
363 {
364 Arc::make_mut(&mut self.state)
365 .upload_handlers
366 .insert(key, Arc::new(handler));
367 self
368 }
369
370 pub fn register_static<H>(mut self, key: HandlerKey, handler: H) -> Self
372 where
373 H: StaticHandler,
374 {
375 Arc::make_mut(&mut self.state)
376 .static_handlers
377 .insert(key, Arc::new(handler));
378 self
379 }
380
381 pub fn register_ws<H>(mut self, key: HandlerKey, handler: H) -> Self
383 where
384 H: WsHandler,
385 {
386 Arc::make_mut(&mut self.state)
387 .ws_handlers
388 .insert(key, Arc::new(handler));
389 self
390 }
391
392 pub fn register_model(mut self, model: zynk_runtime::zynk_schema::ModelDef) -> Self {
394 Arc::make_mut(&mut self.state)
395 .models
396 .insert(model.name.clone(), model);
397 self
398 }
399
400 pub fn register_enum(mut self, enum_def: zynk_runtime::zynk_schema::EnumDef) -> Self {
402 Arc::make_mut(&mut self.state)
403 .enums
404 .insert(enum_def.name.clone(), enum_def);
405 self
406 }
407
408 pub fn keepalive_interval(mut self, interval: Duration) -> Self {
410 Arc::make_mut(&mut self.state).keepalive_interval = interval;
411 self
412 }
413
414 pub fn configure(self, router: Router) -> Router {
416 router
417 .route("/", get(root_route).with_state(self.state.clone()))
418 .route(
419 "/commands",
420 get(commands_route).with_state(self.state.clone()),
421 )
422 .route(
423 "/command/{name}",
424 post(command_route).with_state(self.state.clone()),
425 )
426 .route(
427 "/channel/{name}",
428 post(channel_route).with_state(self.state.clone()),
429 )
430 .route(
431 "/upload/{name}",
432 post(upload_route)
433 .layer(DefaultBodyLimit::max(64 * 1024 * 1024))
434 .with_state(self.state.clone()),
435 )
436 .route(
437 "/static/{name}",
438 get(static_get_route)
439 .head(static_head_route)
440 .with_state(self.state.clone()),
441 )
442 .route("/ws/{name}", get(ws_route).with_state(self.state))
443 }
444
445 pub fn dump_schema_json(&self) -> String {
447 serde_json::to_string(&self.api_graph()).expect("ApiGraph serialization cannot fail")
448 }
449
450 pub fn api_graph(&self) -> zynk_runtime::zynk_schema::ApiGraph {
452 let mut graph = zynk_runtime::zynk_schema::ApiGraph::new();
453 for endpoint in self.state.endpoints.values() {
454 graph.insert_endpoint(endpoint_to_schema(endpoint));
455 }
456 for model in self.state.models.values() {
457 graph.insert_model(model.clone());
458 }
459 for enum_def in self.state.enums.values() {
460 graph.insert_enum(enum_def.clone());
461 }
462 graph
463 }
464
465 fn from_inventory() -> Self {
466 Self::from_parts(collect_inventory_endpoints(), HashMap::new(), false)
467 }
468
469 fn from_parts(
470 endpoints: BTreeMap<String, &'static EndpointMeta>,
471 handlers: HashMap<HandlerKey, Arc<dyn Handler>>,
472 debug: bool,
473 ) -> Self {
474 Self {
475 state: Arc::new(BridgeState {
476 title: "Zynk API".to_string(),
477 endpoints,
478 handlers,
479 channel_handlers: HashMap::new(),
480 upload_handlers: HashMap::new(),
481 static_handlers: HashMap::new(),
482 ws_handlers: HashMap::new(),
483 models: BTreeMap::new(),
484 enums: BTreeMap::new(),
485 debug,
486 keepalive_interval: Duration::from_secs(30),
487 }),
488 }
489 }
490}
491
/// The default bridge is the same as [`ZynkBridge::new`].
impl Default for ZynkBridge {
    fn default() -> Self {
        Self::new()
    }
}
497
498impl Clone for BridgeState {
499 fn clone(&self) -> Self {
500 Self {
501 title: self.title.clone(),
502 endpoints: self.endpoints.clone(),
503 handlers: self.handlers.clone(),
504 channel_handlers: self.channel_handlers.clone(),
505 upload_handlers: self.upload_handlers.clone(),
506 static_handlers: self.static_handlers.clone(),
507 ws_handlers: self.ws_handlers.clone(),
508 models: self.models.clone(),
509 enums: self.enums.clone(),
510 debug: self.debug,
511 keepalive_interval: self.keepalive_interval,
512 }
513 }
514}
515
516fn collect_inventory_endpoints() -> BTreeMap<String, &'static EndpointMeta> {
517 inventory::iter::<EndpointMeta>
518 .into_iter()
519 .map(|endpoint| (endpoint.name.to_string(), endpoint))
520 .collect()
521}
522
523async fn root_route(State(state): State<Arc<BridgeState>>) -> Response {
524 let commands: Vec<_> = state
525 .endpoints
526 .values()
527 .filter(|endpoint| matches!(endpoint.kind, EndpointKind::Rpc | EndpointKind::Channel))
528 .map(|endpoint| endpoint.name)
529 .collect();
530
531 Json(json!({
532 "status": "ok",
533 "bridge": state.title,
534 "commands": commands,
535 }))
536 .into_response()
537}
538
539async fn commands_route(State(state): State<Arc<BridgeState>>) -> Response {
540 let commands: Vec<_> = state
541 .endpoints
542 .values()
543 .filter(|endpoint| matches!(endpoint.kind, EndpointKind::Rpc | EndpointKind::Channel))
544 .map(|endpoint| {
545 json!({
546 "name": endpoint.name,
547 "module": endpoint.module.unwrap_or_default(),
548 "has_channel": endpoint.kind == EndpointKind::Channel,
549 "params": endpoint
550 .params
551 .iter()
552 .map(|param| param.source_name)
553 .collect::<Vec<_>>(),
554 })
555 })
556 .collect();
557
558 Json(json!({ "commands": commands })).into_response()
559}
560
561async fn channel_route(
562 State(state): State<Arc<BridgeState>>,
563 Path(name): Path<String>,
564 body: Bytes,
565) -> Response {
566 let Some(endpoint) = state.endpoints.get(&name).copied() else {
567 return command_not_found_response(&name);
568 };
569
570 if endpoint.kind != EndpointKind::Channel {
571 return command_not_found_response(&name);
572 }
573
574 let payload = match parse_json_body(&body) {
575 Ok(payload) => payload,
576 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
577 };
578
579 let payload = match validate_params(endpoint.params, payload) {
580 Ok(payload) => payload,
581 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
582 };
583
584 let Some(handler_key) = endpoint.handler_key else {
585 return error_response(
586 StatusCode::INTERNAL_SERVER_ERROR,
587 JsonErrorEnvelope::new(
588 INTERNAL_ERROR,
589 "Registered channel is missing a handler key",
590 ),
591 );
592 };
593
594 let Some(handler) = state.channel_handlers.get(&handler_key).cloned() else {
595 return error_response(
596 StatusCode::INTERNAL_SERVER_ERROR,
597 JsonErrorEnvelope::new(
598 INTERNAL_ERROR,
599 format!("No handler registered for channel '{name}'"),
600 ),
601 );
602 };
603
604 let channel_id = next_channel_id();
605 let (sender, receiver) = mpsc::unbounded_channel();
606 let channel = Channel::new(channel_id, sender);
607 let handler_channel = channel.clone();
608 let handler_task = tokio::spawn(async move {
609 match AssertUnwindSafe(handler.call(payload, handler_channel))
610 .catch_unwind()
611 .await
612 {
613 Ok(result) => result,
614 Err(panic) => Err(ZynkError::new(INTERNAL_ERROR, panic_message(panic))),
615 }
616 });
617
618 let stream = futures::stream::unfold(
619 ChannelStreamState {
620 channel,
621 receiver,
622 handler_task: Some(handler_task),
623 keepalive_interval: state.keepalive_interval,
624 pending_error: None,
625 emitted_terminal: false,
626 },
627 next_channel_chunk,
628 );
629
630 let mut response = Body::from_stream(stream).into_response();
631 let headers = response.headers_mut();
632 headers.insert(
633 header::CONTENT_TYPE,
634 HeaderValue::from_static("text/event-stream"),
635 );
636 headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
637 headers.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
638 headers.insert("x-accel-buffering", HeaderValue::from_static("no"));
639 response
640}
641
/// State threaded through the channel event stream by `next_channel_chunk`.
struct ChannelStreamState {
    /// Stream-side handle to the channel; used to close it once the handler
    /// succeeds and to report its id in the `close` frame.
    channel: Channel,
    /// Queue of events produced by the handler.
    receiver: mpsc::UnboundedReceiver<ChannelEvent>,
    /// Running handler task; set to `None` once it has finished.
    handler_task: Option<tokio::task::JoinHandle<Result<(), ZynkError>>>,
    /// Idle time after which a `keepalive` frame is emitted.
    keepalive_interval: Duration,
    /// Error message from a failed handler, emitted after queued events are drained.
    pending_error: Option<String>,
    /// Set once a `close` or `error` frame has been produced; ends the stream.
    emitted_terminal: bool,
}
650
/// Produces the next encoded frame of a channel stream, or `None` when the
/// stream is finished.
///
/// While the handler task is running, this waits for whichever comes first:
/// a queued event, handler completion, or the keepalive timeout. After the
/// handler has finished, queued events are drained first; a handler error is
/// then reported as a terminal `error` frame. A `close` or `error` frame
/// marks the stream terminal, so the following call returns `None`.
async fn next_channel_chunk(
    mut state: ChannelStreamState,
) -> Option<(Result<Bytes, Infallible>, ChannelStreamState)> {
    // A terminal frame was already sent: end the stream.
    if state.emitted_terminal {
        return None;
    }

    loop {
        if let Some(handler_task) = state.handler_task.as_mut() {
            tokio::select! {
                Some(event) = state.receiver.recv() => {
                    match event {
                        ChannelEvent::Data(data) => {
                            return Some((Ok(Bytes::from(SseFrame::new("message", data).encode())), state));
                        }
                        ChannelEvent::Close => {
                            state.emitted_terminal = true;
                            let frame = close_frame(state.channel.id());
                            return Some((Ok(Bytes::from(frame)), state));
                        }
                    }
                }
                result = handler_task => {
                    // The handler finished; loop again to drain what it queued.
                    state.handler_task = None;
                    match result {
                        Ok(Ok(())) => {
                            // Queue a close event unless the channel was already closed.
                            let _ = state.channel.close();
                            continue;
                        }
                        Ok(Err(error)) => {
                            state.pending_error = Some(error.message);
                            continue;
                        }
                        Err(error) => {
                            // The task itself failed to join.
                            state.pending_error = Some(error.to_string());
                            continue;
                        }
                    }
                }
                () = tokio::time::sleep(state.keepalive_interval) => {
                    return Some((Ok(Bytes::from(SseFrame::new("keepalive", json!({})).encode())), state));
                }
            }
        } else if let Ok(event) = state.receiver.try_recv() {
            // Handler is done: flush already-queued events without waiting.
            match event {
                ChannelEvent::Data(data) => {
                    return Some((
                        Ok(Bytes::from(SseFrame::new("message", data).encode())),
                        state,
                    ));
                }
                ChannelEvent::Close => {
                    state.emitted_terminal = true;
                    let frame = close_frame(state.channel.id());
                    return Some((Ok(Bytes::from(frame)), state));
                }
            }
        } else if let Some(message) = state.pending_error.take() {
            // Queue is empty: report the handler failure as the terminal frame.
            state.emitted_terminal = true;
            let frame = SseFrame::new("error", json!({ "error": message })).encode();
            return Some((Ok(Bytes::from(frame)), state));
        } else {
            // Handler is done, nothing queued and no error: wait for further events.
            match state.receiver.recv().await {
                Some(ChannelEvent::Data(data)) => {
                    return Some((
                        Ok(Bytes::from(SseFrame::new("message", data).encode())),
                        state,
                    ));
                }
                Some(ChannelEvent::Close) => {
                    state.emitted_terminal = true;
                    let frame = close_frame(state.channel.id());
                    return Some((Ok(Bytes::from(frame)), state));
                }
                None => return None,
            }
        }
    }
}
730
731fn close_frame(channel_id: &str) -> String {
732 SseFrame::new("close", json!({ "channelId": channel_id })).encode()
733}
734
735async fn upload_route(
736 State(state): State<Arc<BridgeState>>,
737 Path(name): Path<String>,
738 multipart: Result<Multipart, MultipartRejection>,
739) -> Response {
740 let Some(endpoint) = state.endpoints.get(&name).copied() else {
741 return upload_not_found_response(&name);
742 };
743
744 if endpoint.kind != EndpointKind::Upload {
745 return upload_not_found_response(&name);
746 }
747
748 let mut multipart = match multipart {
749 Ok(multipart) => multipart,
750 Err(error) => {
751 return error_response(
752 StatusCode::BAD_REQUEST,
753 JsonErrorEnvelope::new(
754 VALIDATION_ERROR,
755 format!("Invalid multipart body: {error}"),
756 ),
757 )
758 }
759 };
760
761 let (args, files) = match parse_upload_multipart(endpoint, &mut multipart).await {
762 Ok(parsed) => parsed,
763 Err(error) => return error_response(status_for_error(error.code), error.into_envelope()),
764 };
765
766 let payload = match validate_params(endpoint.params, args) {
767 Ok(payload) => payload,
768 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
769 };
770
771 if !endpoint.multi_file && files.is_empty() {
772 return error_response(
773 StatusCode::BAD_REQUEST,
774 JsonErrorEnvelope::new(VALIDATION_ERROR, "No file provided"),
775 );
776 }
777
778 let Some(handler_key) = endpoint.handler_key else {
779 return error_response(
780 StatusCode::INTERNAL_SERVER_ERROR,
781 JsonErrorEnvelope::new(INTERNAL_ERROR, "Registered upload is missing a handler key"),
782 );
783 };
784
785 let Some(handler) = state.upload_handlers.get(&handler_key).cloned() else {
786 return error_response(
787 StatusCode::INTERNAL_SERVER_ERROR,
788 JsonErrorEnvelope::new(
789 INTERNAL_ERROR,
790 format!("No handler registered for upload '{name}'"),
791 ),
792 );
793 };
794
795 match AssertUnwindSafe(handler.call(payload, files))
796 .catch_unwind()
797 .await
798 {
799 Ok(Ok(value)) => (StatusCode::OK, Json(JsonResultEnvelope::new(value))).into_response(),
800 Ok(Err(error)) => error_response(status_for_error(error.code), error.into_envelope()),
801 Err(panic) => {
802 let message = if state.debug {
803 panic_message(panic)
804 } else {
805 "An internal error occurred".to_string()
806 };
807 error_response(
808 StatusCode::INTERNAL_SERVER_ERROR,
809 JsonErrorEnvelope::new(INTERNAL_ERROR, message),
810 )
811 }
812 }
813}
814
815async fn parse_upload_multipart(
816 endpoint: &EndpointMeta,
817 multipart: &mut Multipart,
818) -> Result<(Value, Vec<UploadFile>), ZynkError> {
819 let mut args = json!({});
820 let mut files = Vec::new();
821
822 while let Some(field) = multipart.next_field().await.map_err(|error| {
823 ZynkError::new(VALIDATION_ERROR, format!("Invalid multipart body: {error}"))
824 })? {
825 let Some(field_name) = field.name().map(str::to_string) else {
826 continue;
827 };
828
829 if field_name == "_args" {
830 let text = field.text().await.map_err(|error| {
831 ZynkError::new(VALIDATION_ERROR, format!("Invalid _args field: {error}"))
832 })?;
833 args = serde_json::from_str(&text).map_err(|error| {
834 ZynkError::new(VALIDATION_ERROR, format!("Invalid _args JSON: {error}"))
835 })?;
836 } else if field_name == "files" {
837 let filename = field.file_name().unwrap_or("upload").to_string();
838 let content_type = field
839 .content_type()
840 .unwrap_or("application/octet-stream")
841 .to_string();
842 if !content_type_allowed(&content_type, endpoint.allowed_types) {
843 return Err(upload_validation_error(
844 format!("File '{filename}' has disallowed content type {content_type}"),
845 filename,
846 ));
847 }
848 let bytes = read_limited_field(field, endpoint.max_size, &filename).await?;
849 files.push(UploadFile::new(filename, content_type, bytes));
850 }
851 }
852
853 Ok((args, files))
854}
855
856async fn read_limited_field(
857 mut field: axum::extract::multipart::Field<'_>,
858 max_size: Option<u64>,
859 filename: &str,
860) -> Result<Bytes, ZynkError> {
861 let mut bytes = Vec::new();
862 let mut size: u64 = 0;
863
864 while let Some(chunk) = field.chunk().await.map_err(|error| {
865 ZynkError::new(VALIDATION_ERROR, format!("Invalid upload file: {error}"))
866 })? {
867 size += chunk.len() as u64;
868 if let Some(max_size) = max_size {
869 if size > max_size {
870 return Err(upload_validation_error(
871 format!("File '{filename}' exceeds maximum size of {max_size} bytes"),
872 filename.to_string(),
873 ));
874 }
875 }
876 bytes.extend_from_slice(&chunk);
877 }
878
879 Ok(Bytes::from(bytes))
880}
881
882fn upload_validation_error(message: String, filename: String) -> ZynkError {
883 ZynkError::with_details(
884 UPLOAD_VALIDATION_ERROR,
885 message,
886 json!({ "filename": filename }),
887 )
888}
889
/// Reports whether `content_type` is permitted by `allowed_types`.
///
/// An empty allow-list permits everything. Otherwise an entry matches when it
/// equals the content type, or is a `type/*` wildcard for the content type's
/// top-level type, or is `*/*`. Matching is ASCII case-insensitive and
/// ignores parameters on the content type, so `Text/CSV; charset=utf-8` is
/// accepted by `text/csv`.
fn content_type_allowed(content_type: &str, allowed_types: &[&str]) -> bool {
    if allowed_types.is_empty() {
        return true;
    }

    // Compare on the bare `type/subtype`, without `; param=value` suffixes.
    let essence = content_type
        .split(';')
        .next()
        .unwrap_or(content_type)
        .trim();

    allowed_types.iter().any(|allowed| {
        // A verbatim match always counts, including entries that carry parameters.
        if *allowed == content_type {
            return true;
        }
        match allowed.strip_suffix("/*") {
            Some(prefix) => {
                prefix == "*"
                    || essence
                        .split_once('/')
                        .is_some_and(|(top_level, _)| top_level.eq_ignore_ascii_case(prefix))
            }
            None => allowed.eq_ignore_ascii_case(essence),
        }
    })
}
899
/// `GET /static/{name}`: serves the file chosen by the static handler,
/// including its contents. Delegates to `static_route` with `Method::GET`.
async fn static_get_route(
    State(state): State<Arc<BridgeState>>,
    Path(name): Path<String>,
    Query(query): Query<HashMap<String, String>>,
) -> Response {
    static_route(state, name, query, Method::GET).await
}
907
/// `HEAD /static/{name}`: same as the GET route but without a response body.
/// Delegates to `static_route` with `Method::HEAD`.
async fn static_head_route(
    State(state): State<Arc<BridgeState>>,
    Path(name): Path<String>,
    Query(query): Query<HashMap<String, String>>,
) -> Response {
    static_route(state, name, query, Method::HEAD).await
}
915
916async fn static_route(
917 state: Arc<BridgeState>,
918 name: String,
919 query: HashMap<String, String>,
920 method: Method,
921) -> Response {
922 let Some(endpoint) = state.endpoints.get(&name).copied() else {
923 return static_not_found_response(&name);
924 };
925
926 if endpoint.kind != EndpointKind::Static {
927 return static_not_found_response(&name);
928 }
929
930 let payload = match coerce_query_params(endpoint.params, &query) {
931 Ok(payload) => payload,
932 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
933 };
934
935 let Some(handler_key) = endpoint.handler_key else {
936 return error_response(
937 StatusCode::INTERNAL_SERVER_ERROR,
938 JsonErrorEnvelope::new(
939 INTERNAL_ERROR,
940 "Registered static handler is missing a handler key",
941 ),
942 );
943 };
944
945 let Some(handler) = state.static_handlers.get(&handler_key).cloned() else {
946 return error_response(
947 StatusCode::INTERNAL_SERVER_ERROR,
948 JsonErrorEnvelope::new(
949 INTERNAL_ERROR,
950 format!("No handler registered for static '{name}'"),
951 ),
952 );
953 };
954
955 match AssertUnwindSafe(handler.call(payload)).catch_unwind().await {
956 Ok(Ok(file)) => static_file_response(file, method).await,
957 Ok(Err(error)) => error_response(status_for_error(error.code), error.into_envelope()),
958 Err(panic) => {
959 let message = if state.debug {
960 panic_message(panic)
961 } else {
962 "An internal error occurred".to_string()
963 };
964 error_response(
965 StatusCode::INTERNAL_SERVER_ERROR,
966 JsonErrorEnvelope::new(INTERNAL_ERROR, message),
967 )
968 }
969 }
970}
971
972fn coerce_query_params(
973 params: &[ParamMeta],
974 query: &HashMap<String, String>,
975) -> Result<Value, ZynkError> {
976 let mut object = Map::new();
977 for param in params {
978 let Some(raw) = query
979 .get(param.source_name)
980 .or_else(|| query.get(param.wire_name))
981 else {
982 if param.required {
983 return Err(ZynkError::with_details(
984 VALIDATION_ERROR,
985 format!("Missing required parameter: {}", param.source_name),
986 json!({ "parameter": param.source_name }),
987 ));
988 }
989 continue;
990 };
991 object.insert(
992 param.source_name.to_string(),
993 coerce_query_value(raw, param)?,
994 );
995 }
996 Ok(Value::Object(object))
997}
998
999fn coerce_query_value(raw: &str, param: &ParamMeta) -> Result<Value, ZynkError> {
1000 match param.ty.name {
1001 Some("number") => {
1002 if let Ok(value) = raw.parse::<i64>() {
1003 Ok(json!(value))
1004 } else if let Ok(value) = raw.parse::<f64>() {
1005 Ok(json!(value))
1006 } else {
1007 Err(ZynkError::new(
1008 VALIDATION_ERROR,
1009 format!("Invalid value for parameter '{}': {raw}", param.source_name),
1010 ))
1011 }
1012 }
1013 Some("boolean") => Ok(Value::Bool(matches!(
1014 raw.to_ascii_lowercase().as_str(),
1015 "true" | "1" | "yes"
1016 ))),
1017 _ => Ok(Value::String(raw.to_string())),
1018 }
1019}
1020
1021async fn static_file_response(file: StaticFile, method: Method) -> Response {
1022 let metadata = match tokio::fs::metadata(file.path()).await {
1023 Ok(metadata) => metadata,
1024 Err(error) => {
1025 return error_response(
1026 StatusCode::BAD_REQUEST,
1027 JsonErrorEnvelope::new(VALIDATION_ERROR, error.to_string()),
1028 )
1029 }
1030 };
1031 let body = if method == Method::HEAD {
1032 Body::empty()
1033 } else {
1034 match tokio::fs::read(file.path()).await {
1035 Ok(bytes) => Body::from(bytes),
1036 Err(error) => {
1037 return error_response(
1038 StatusCode::BAD_REQUEST,
1039 JsonErrorEnvelope::new(VALIDATION_ERROR, error.to_string()),
1040 )
1041 }
1042 }
1043 };
1044
1045 let mut response = Response::new(body);
1046 *response.status_mut() = StatusCode::OK;
1047 let headers = response.headers_mut();
1048 headers.insert(
1049 header::CONTENT_TYPE,
1050 HeaderValue::from_str(&file.content_type())
1051 .unwrap_or_else(|_| HeaderValue::from_static("application/octet-stream")),
1052 );
1053 if let Ok(value) = HeaderValue::from_str(&metadata.len().to_string()) {
1054 headers.insert(header::CONTENT_LENGTH, value);
1055 }
1056 if let Ok(modified) = metadata.modified() {
1057 headers.insert(header::LAST_MODIFIED, http_date(modified));
1058 }
1059 headers.insert(
1060 "x-content-type-options",
1061 HeaderValue::from_static("nosniff"),
1062 );
1063 for (name, value) in file.headers() {
1064 headers.insert(name.clone(), value.clone());
1065 }
1066 response
1067}
1068
1069fn http_date(time: SystemTime) -> HeaderValue {
1070 let seconds = time
1071 .duration_since(UNIX_EPOCH)
1072 .unwrap_or_default()
1073 .as_secs() as i64;
1074 let days = seconds.div_euclid(86_400);
1075 let seconds_of_day = seconds.rem_euclid(86_400);
1076 let (year, month, day) = civil_from_days(days);
1077 let hour = seconds_of_day / 3_600;
1078 let minute = (seconds_of_day % 3_600) / 60;
1079 let second = seconds_of_day % 60;
1080 let weekday = ["Thu", "Fri", "Sat", "Sun", "Mon", "Tue", "Wed"][days.rem_euclid(7) as usize];
1081 let month_name = [
1082 "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
1083 ][(month - 1) as usize];
1084 HeaderValue::from_str(&format!(
1085 "{weekday}, {day:02} {month_name} {year:04} {hour:02}:{minute:02}:{second:02} GMT"
1086 ))
1087 .unwrap_or_else(|_| HeaderValue::from_static("Thu, 01 Jan 1970 00:00:00 GMT"))
1088}
1089
/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Internally the calendar is shifted to start on 1 March so that the leap
/// day falls at the end of the computational year.
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    const DAYS_PER_ERA: i64 = 146_097; // one 400-year Gregorian cycle

    let shifted = days + 719_468; // days since 0000-03-01
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let shifted_month = (5 * day_of_year + 2) / 153; // 0 = March .. 11 = February
    let day = day_of_year - (153 * shifted_month + 2) / 5 + 1;
    let month = if shifted_month < 10 {
        shifted_month + 3
    } else {
        shifted_month - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + if month <= 2 { 1 } else { 0 };
    (year, month as u32, day as u32)
}
1102
/// Guesses a content type from the file extension of `path`.
///
/// The extension is matched ASCII case-insensitively, so `photo.JPG` and
/// `photo.jpg` give the same result. Unknown or missing extensions map to
/// `application/octet-stream`.
fn guess_content_type(path: &std::path::Path) -> String {
    let extension = path
        .extension()
        .and_then(|extension| extension.to_str())
        .map(str::to_ascii_lowercase);

    match extension.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("txt") => "text/plain; charset=utf-8",
        Some("json") => "application/json",
        Some("png") => "image/png",
        Some("jpg" | "jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("svg") => "image/svg+xml",
        Some("css") => "text/css; charset=utf-8",
        Some("js") => "application/javascript",
        Some("pdf") => "application/pdf",
        _ => "application/octet-stream",
    }
    .to_string()
}
1119
/// Process-wide counter backing `next_channel_id`; the first id issued is 1.
static NEXT_CHANNEL_ID: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh channel id of the form `channel-<n>`, where `<n>`
/// increases by one with every call.
fn next_channel_id() -> String {
    let sequence = NEXT_CHANNEL_ID.fetch_add(1, Ordering::SeqCst);
    format!("channel-{sequence}")
}
1125
/// `GET /ws/{name}`: upgrades the connection to a WebSocket and hands it to
/// `handle_ws_socket`.
///
/// The endpoint is only looked up after the upgrade, so an unknown name is
/// reported through a WebSocket close frame rather than an HTTP status.
async fn ws_route(
    State(state): State<Arc<BridgeState>>,
    Path(name): Path<String>,
    ws: WebSocketUpgrade,
) -> Response {
    ws.on_upgrade(move |socket| handle_ws_socket(state, name, socket))
}
1133
/// Drives one upgraded WebSocket connection for the endpoint called `name`.
///
/// The socket is closed with code 4004 when the endpoint is unknown or is not
/// a WebSocket endpoint, and with code 1011 when it has no handler key or no
/// registered handler. Otherwise the loop forwards messages queued by the
/// handler to the client and invokes the handler for each accepted client
/// text message. The handler is awaited inside the loop, so messages it
/// queues are written to the socket only after it returns. A handler error or
/// panic closes the socket with code 1011.
// NOTE(review): panic messages are sent to the client regardless of the
// `debug` flag, unlike the upload and static routes — confirm this is intended.
async fn handle_ws_socket(state: Arc<BridgeState>, name: String, mut socket: AxumWebSocket) {
    let Some(endpoint) = state.endpoints.get(&name).copied() else {
        close_ws(&mut socket, 4004, format!("Handler '{name}' not found")).await;
        return;
    };

    if endpoint.kind != EndpointKind::Ws {
        close_ws(&mut socket, 4004, format!("Handler '{name}' not found")).await;
        return;
    }

    let Some(handler_key) = endpoint.handler_key else {
        close_ws(
            &mut socket,
            1011,
            "Registered WebSocket is missing a handler key",
        )
        .await;
        return;
    };

    let Some(handler) = state.ws_handlers.get(&handler_key).cloned() else {
        close_ws(
            &mut socket,
            1011,
            format!("No handler registered for WebSocket '{name}'"),
        )
        .await;
        return;
    };

    // Handlers queue outbound messages on `sender`; this loop drains `receiver`.
    let (sender, mut receiver) = mpsc::unbounded_channel();
    let ws_sender = WebSocket::new(sender);

    loop {
        tokio::select! {
            outbound = receiver.recv() => {
                match outbound {
                    Some(message) => {
                        let text = serde_json::to_string(&message)
                            .expect("WebSocket message serialization cannot fail");
                        // A failed write means the client is gone.
                        if socket.send(Message::Text(text.into())).await.is_err() {
                            return;
                        }
                    }
                    None => return,
                }
            }
            inbound = socket.recv() => {
                // `None` means the client closed the stream.
                let Some(inbound) = inbound else {
                    return;
                };
                match inbound {
                    Ok(Message::Text(text)) => {
                        // Messages that fail to parse are ignored.
                        let message = match WsMessage::from_json(text.as_str()) {
                            Ok(message) => message,
                            Err(_) => continue,
                        };
                        // Events rejected by `client_event_known` are ignored.
                        if !client_event_known(endpoint, &message.event) {
                            continue;
                        }
                        let payload = json!({ "event": message.event, "data": message.data });
                        // Guard the synchronous call that builds the handler future...
                        let future = match catch_unwind(AssertUnwindSafe(|| {
                            handler.call(payload, ws_sender.clone())
                        })) {
                            Ok(future) => future,
                            Err(panic) => {
                                close_ws(&mut socket, 1011, panic_message(panic)).await;
                                return;
                            }
                        };
                        // ...and the polling of that future.
                        match AssertUnwindSafe(future).catch_unwind().await {
                            Ok(Ok(())) => {}
                            Ok(Err(error)) => {
                                close_ws(&mut socket, 1011, error.message).await;
                                return;
                            }
                            Err(panic) => {
                                close_ws(&mut socket, 1011, panic_message(panic)).await;
                                return;
                            }
                        }
                    }
                    Ok(Message::Close(_)) => return,
                    Ok(Message::Ping(payload)) => {
                        let _ = socket.send(Message::Pong(payload)).await;
                    }
                    // Pong and binary frames are ignored.
                    Ok(Message::Pong(_)) | Ok(Message::Binary(_)) => {}
                    Err(error) => {
                        close_ws(&mut socket, 1011, error.to_string()).await;
                        return;
                    }
                }
            }
        }
    }
}
1231
1232fn client_event_known(endpoint: &EndpointMeta, event: &str) -> bool {
1233 endpoint.client_events.is_empty()
1234 || endpoint
1235 .client_events
1236 .iter()
1237 .any(|param| param.source_name == event)
1238}
1239
1240async fn close_ws(socket: &mut AxumWebSocket, code: u16, reason: impl Into<String>) {
1241 let _ = socket
1242 .send(Message::Close(Some(CloseFrame {
1243 code,
1244 reason: reason.into().into(),
1245 })))
1246 .await;
1247}
1248
1249async fn command_route(
1250 State(state): State<Arc<BridgeState>>,
1251 Path(name): Path<String>,
1252 body: Bytes,
1253) -> Response {
1254 let Some(endpoint) = state.endpoints.get(&name).copied() else {
1255 return command_not_found_response(&name);
1256 };
1257
1258 if endpoint.kind != EndpointKind::Rpc {
1259 return command_not_found_response(&name);
1260 }
1261
1262 let payload = match parse_json_body(&body) {
1263 Ok(payload) => payload,
1264 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
1265 };
1266
1267 let payload = match validate_params(endpoint.params, payload) {
1268 Ok(payload) => payload,
1269 Err(error) => return error_response(StatusCode::BAD_REQUEST, error.into_envelope()),
1270 };
1271
1272 let Some(handler_key) = endpoint.handler_key else {
1273 return error_response(
1274 StatusCode::INTERNAL_SERVER_ERROR,
1275 JsonErrorEnvelope::new(
1276 INTERNAL_ERROR,
1277 "Registered command is missing a handler key",
1278 ),
1279 );
1280 };
1281
1282 let Some(handler) = state.handlers.get(&handler_key).cloned() else {
1283 return error_response(
1284 StatusCode::INTERNAL_SERVER_ERROR,
1285 JsonErrorEnvelope::new(
1286 INTERNAL_ERROR,
1287 format!("No handler registered for command '{name}'"),
1288 ),
1289 );
1290 };
1291
1292 let result = catch_unwind(AssertUnwindSafe(|| handler.call(payload)));
1293 match result {
1294 Ok(Ok(value)) => (StatusCode::OK, Json(JsonResultEnvelope::new(value))).into_response(),
1295 Ok(Err(error)) => {
1296 let status = status_for_error(error.code);
1297 if error.code == INTERNAL_ERROR && !state.debug {
1298 error_response(
1299 status,
1300 JsonErrorEnvelope::new(INTERNAL_ERROR, "An internal error occurred"),
1301 )
1302 } else {
1303 error_response(status, error.into_envelope())
1304 }
1305 }
1306 Err(panic) => {
1307 let message = if state.debug {
1308 panic_message(panic)
1309 } else {
1310 "An internal error occurred".to_string()
1311 };
1312 error_response(
1313 StatusCode::INTERNAL_SERVER_ERROR,
1314 JsonErrorEnvelope::new(INTERNAL_ERROR, message),
1315 )
1316 }
1317 }
1318}
1319
1320fn parse_json_body(body: &Bytes) -> Result<Value, ZynkError> {
1321 if body.is_empty() {
1322 return Ok(Value::Object(Default::default()));
1323 }
1324
1325 serde_json::from_slice(body)
1326 .map_err(|error| ZynkError::new(VALIDATION_ERROR, format!("Invalid JSON body: {error}")))
1327}
1328
1329fn validate_params(params: &[ParamMeta], payload: Value) -> Result<Value, ZynkError> {
1330 let object = payload
1331 .as_object()
1332 .ok_or_else(|| ZynkError::new(VALIDATION_ERROR, "Request body must be a JSON object"))?;
1333
1334 for param in params.iter().filter(|param| param.required) {
1335 if !object.contains_key(param.source_name) && !object.contains_key(param.wire_name) {
1336 return Err(ZynkError::with_details(
1337 VALIDATION_ERROR,
1338 format!("Missing required parameter '{}'", param.source_name),
1339 json!({ "parameter": param.source_name }),
1340 ));
1341 }
1342 }
1343
1344 Ok(Value::Object(object.clone()))
1345}
1346
1347fn status_for_error(code: &'static str) -> StatusCode {
1348 match code {
1349 VALIDATION_ERROR | UPLOAD_VALIDATION_ERROR => StatusCode::BAD_REQUEST,
1350 COMMAND_NOT_FOUND | UPLOAD_HANDLER_NOT_FOUND | STATIC_HANDLER_NOT_FOUND => {
1351 StatusCode::NOT_FOUND
1352 }
1353 EXECUTION_ERROR | INTERNAL_ERROR => StatusCode::INTERNAL_SERVER_ERROR,
1354 _ => StatusCode::INTERNAL_SERVER_ERROR,
1355 }
1356}
1357
1358fn command_not_found_response(name: &str) -> Response {
1359 error_response(
1360 StatusCode::NOT_FOUND,
1361 JsonErrorEnvelope::new(COMMAND_NOT_FOUND, format!("Command '{name}' not found"))
1362 .with_details(json!({ "command": name })),
1363 )
1364}
1365
1366fn upload_not_found_response(name: &str) -> Response {
1367 error_response(
1368 StatusCode::NOT_FOUND,
1369 JsonErrorEnvelope::new(
1370 UPLOAD_HANDLER_NOT_FOUND,
1371 format!("Upload handler '{name}' not found"),
1372 )
1373 .with_details(json!({ "handler": name })),
1374 )
1375}
1376
1377fn static_not_found_response(name: &str) -> Response {
1378 error_response(
1379 StatusCode::NOT_FOUND,
1380 JsonErrorEnvelope::new(
1381 STATIC_HANDLER_NOT_FOUND,
1382 format!("Static handler '{name}' not found"),
1383 )
1384 .with_details(json!({ "handler": name })),
1385 )
1386}
1387
1388fn error_response(status: StatusCode, error: JsonErrorEnvelope) -> Response {
1389 (status, Json(error)).into_response()
1390}
1391
/// Extracts a human-readable message from a captured panic payload.
///
/// Payloads of type `&str` or `String` are returned as text; any other payload
/// type yields the fixed fallback `"handler panicked"`.
fn panic_message(panic: Box<dyn std::any::Any + Send>) -> String {
    panic
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "handler panicked".to_string())
}
1401
1402fn endpoint_to_schema(meta: &EndpointMeta) -> zynk_runtime::zynk_schema::Endpoint {
1403 let mut endpoint = zynk_runtime::zynk_schema::Endpoint::new(
1404 meta.name,
1405 meta.kind,
1406 meta.returns.to_schema_type_ref(),
1407 );
1408 endpoint.module = meta.module.map(str::to_string);
1409 endpoint.doc = meta.doc.map(str::to_string);
1410 endpoint.params = params_to_schema(meta.params);
1411 endpoint.channel_item = meta
1412 .channel_item
1413 .as_ref()
1414 .map(zynk_runtime::TypeRefStatic::to_schema_type_ref);
1415 endpoint.file_param = meta.file_param.map(str::to_string);
1416 endpoint.multi_file = meta.multi_file;
1417 endpoint.max_size = meta.max_size;
1418 endpoint.allowed_types = meta
1419 .allowed_types
1420 .iter()
1421 .map(|value| (*value).to_string())
1422 .collect();
1423 endpoint.server_events = params_to_schema(meta.server_events);
1424 endpoint.client_events = params_to_schema(meta.client_events);
1425 endpoint
1426}
1427
1428fn params_to_schema(params: &[ParamMeta]) -> Vec<zynk_runtime::zynk_schema::Param> {
1429 params
1430 .iter()
1431 .map(|param| zynk_runtime::zynk_schema::Param {
1432 source_name: param.source_name.to_string(),
1433 wire_name: param.wire_name.to_string(),
1434 ty: param.ty.to_schema_type_ref(),
1435 required: param.required,
1436 default: param
1437 .default
1438 .as_ref()
1439 .map(zynk_runtime::StaticValue::to_json),
1440 })
1441 .collect()
1442}
1443
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-length request body is treated as `{}`.
    #[test]
    fn empty_body_defaults_to_json_object() {
        assert_eq!(parse_json_body(&Bytes::new()).expect("valid"), json!({}));
    }

    /// A required parameter may be supplied under its source name or its wire
    /// name; supplying neither is a validation error.
    #[test]
    fn validates_required_params_against_snake_or_camel_keys() {
        let params = [ParamMeta {
            source_name: "user_id",
            wire_name: "userId",
            ty: zynk_runtime::TypeRefStatic::primitive("number"),
            required: true,
            default: None,
        }];

        validate_params(&params, json!({"user_id": 7})).expect("snake case accepted");
        validate_params(&params, json!({"userId": 7})).expect("camel case accepted");
        let error = validate_params(&params, json!({})).expect_err("missing param rejected");
        assert_eq!(error.code, VALIDATION_ERROR);
    }

    /// The Unix epoch formats as the expected HTTP date string.
    #[test]
    fn last_modified_header_uses_http_date_format() {
        assert_eq!(
            http_date(UNIX_EPOCH),
            HeaderValue::from_static("Thu, 01 Jan 1970 00:00:00 GMT")
        );
    }
}