1use std::sync::Arc;
4use std::sync::atomic::Ordering;
5
6use axum::{
7 Router,
8 extract::{
9 State, WebSocketUpgrade,
10 ws::{Message, WebSocket},
11 },
12 response::{IntoResponse, Json},
13 routing::get,
14};
15
16#[cfg(feature = "embedded-frontend")]
17use axum::extract::Path as AxumPath;
18
19#[cfg(not(feature = "embedded-frontend"))]
20use axum::response::Html;
21use futures::{SinkExt, StreamExt};
22use tokio::sync::{Mutex as TokioMutex, RwLock};
23use tower_http::cors::CorsLayer;
24use venus_core::execute::ExecutorKillHandle;
25use venus_core::graph::CellId;
26
27use crate::lsp;
28use crate::protocol::{CellState, ClientMessage, ServerMessage};
29use crate::session::{InterruptFlag, NotebookSession};
30
31#[cfg(feature = "embedded-frontend")]
32use crate::embedded_frontend;
33
/// Shared server state handed to every axum handler as `State<Arc<AppState>>`.
pub struct AppState {
    /// The open notebook session. Handlers take a read lock for queries and a write lock
    /// for mutations; cell execution holds the write lock for the whole run.
    pub session: Arc<RwLock<NotebookSession>>,
    /// Kill handle for the worker process of the in-flight execution, used by the
    /// `Interrupt` message. `None` means there is nothing to interrupt (the client is told so).
    /// NOTE(review): populated and cleared outside this file — confirm the lifecycle.
    pub kill_handle: Arc<TokioMutex<Option<ExecutorKillHandle>>>,
    /// Raised by the `Interrupt` handler immediately before it kills the worker.
    /// NOTE(review): presumably read by the execution path to tell a user abort apart from
    /// a crash — confirm.
    pub interrupted: InterruptFlag,
}
46
47pub fn create_router(state: Arc<AppState>) -> Router {
49 let router = Router::new()
50 .route("/health", get(health_handler))
51 .route("/ws", get(ws_handler))
52 .route("/lsp", get(lsp_handler))
53 .route("/api/state", get(state_handler))
54 .route("/api/graph", get(graph_handler));
55
56 #[cfg(feature = "embedded-frontend")]
58 let router = router
59 .route("/", get(frontend_index_handler))
60 .route("/static/{*path}", get(static_handler));
61
62 #[cfg(not(feature = "embedded-frontend"))]
63 let router = router.route("/", get(index_handler));
64
65 router
66 .layer(CorsLayer::permissive())
67 .with_state(state)
68}
69
70#[cfg(not(feature = "embedded-frontend"))]
72async fn index_handler() -> Html<&'static str> {
73 Html(
74 r#"<!DOCTYPE html>
75<html>
76<head>
77 <title>Venus Notebook</title>
78 <style>
79 body { font-family: system-ui, sans-serif; margin: 2rem; }
80 h1 { color: #7c3aed; }
81 pre { background: #f3f4f6; padding: 1rem; border-radius: 0.5rem; }
82 </style>
83</head>
84<body>
85 <h1>Venus Notebook Server</h1>
86 <p>WebSocket endpoint: <code>/ws</code></p>
87 <p>API endpoints:</p>
88 <ul>
89 <li><code>GET /health</code> - Health check</li>
90 <li><code>GET /api/state</code> - Current notebook state</li>
91 <li><code>GET /api/graph</code> - Dependency graph</li>
92 </ul>
93 <p><em>Note: The full UI is available with the <code>embedded-frontend</code> feature.</em></p>
94 <script>
95 const ws = new WebSocket(`ws://${location.host}/ws`);
96 ws.onmessage = (e) => console.log('Server:', JSON.parse(e.data));
97 ws.onopen = () => ws.send(JSON.stringify({ type: 'get_state' }));
98 </script>
99</body>
100</html>"#,
101 )
102}
103
/// `GET /`: serves the bundled frontend's entry page.
#[cfg(feature = "embedded-frontend")]
async fn frontend_index_handler() -> impl IntoResponse {
    embedded_frontend::serve_index()
}

/// `GET /static/{*path}`: serves a bundled asset. `asset` is the remainder of the URL
/// after `/static/`.
#[cfg(feature = "embedded-frontend")]
async fn static_handler(AxumPath(asset): AxumPath<String>) -> impl IntoResponse {
    embedded_frontend::serve_static(asset)
}
115
116async fn health_handler() -> Json<serde_json::Value> {
118 Json(serde_json::json!({
119 "status": "ok",
120 "version": env!("CARGO_PKG_VERSION")
121 }))
122}
123
124async fn state_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
126 let session = state.session.read().await;
127 let notebook_state = session.get_state();
128 Json(notebook_state)
129}
130
131async fn graph_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
133 let session = state.session.read().await;
134
135 let state_msg = session.get_state();
137 match state_msg {
138 ServerMessage::NotebookState {
139 execution_order, ..
140 } => Json(serde_json::json!({
141 "execution_order": execution_order
142 })),
143 _ => Json(serde_json::json!({})),
144 }
145}
146
147async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
149 ws.on_upgrade(|socket| handle_websocket(socket, state))
150}
151
152async fn lsp_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
154 let notebook_path = {
155 let session = state.session.read().await;
156 session.path().to_path_buf()
157 };
158 ws.on_upgrade(move |socket| lsp::handle_lsp_websocket(socket, notebook_path))
159}
160
161async fn handle_websocket(socket: WebSocket, state: Arc<AppState>) {
163 let (mut sender, mut receiver) = socket.split();
164
165 let mut rx = {
167 let session = state.session.read().await;
168 session.subscribe()
169 };
170
171 {
173 let session = state.session.read().await;
174 let initial_state = session.get_state();
175 if let Ok(json) = serde_json::to_string(&initial_state) {
176 let _ = sender.send(Message::Text(json.into())).await;
177 }
178 }
179
180 let sender = Arc::new(tokio::sync::Mutex::new(sender));
182 let sender_clone = sender.clone();
183
184 let forward_task = tokio::spawn(async move {
185 while let Ok(msg) = rx.recv().await {
186 if let Ok(json) = serde_json::to_string(&msg) {
187 let mut sender = sender_clone.lock().await;
188 if sender.send(Message::Text(json.into())).await.is_err() {
189 break;
190 }
191 }
192 }
193 });
194
195 while let Some(result) = receiver.next().await {
197 tracing::debug!("Received WebSocket message");
198 match result {
199 Ok(Message::Text(text)) => {
200 tracing::debug!("Parsing message: {}", &text[..text.len().min(100)]);
201 match serde_json::from_str::<ClientMessage>(&text) {
202 Ok(msg) => {
203 tracing::debug!("Dispatching message: {:?}", std::mem::discriminant(&msg));
204 handle_client_message(msg, &state, &sender).await;
205 }
206 Err(e) => {
207 tracing::warn!("Failed to parse client message: {} (input: {})", e, text);
208 send_message(
209 &sender,
210 &ServerMessage::Error {
211 message: format!("Invalid message format: {}", e),
212 },
213 )
214 .await;
215 }
216 }
217 }
218 Ok(Message::Close(_)) => break,
219 Err(e) => {
220 tracing::warn!("WebSocket error: {}", e);
221 break;
222 }
223 _ => {}
224 }
225 }
226
227 forward_task.abort();
229 let _ = forward_task.await;
230}
231
232async fn send_message(
234 sender: &Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
235 msg: &ServerMessage,
236) {
237 if let Ok(json) = serde_json::to_string(msg) {
238 let mut sender = sender.lock().await;
239 let _ = sender.send(Message::Text(json.into())).await;
240 }
241}
242
243async fn handle_cell_operation<T, F, R>(
257 session: &mut NotebookSession,
258 operation: F,
259 response_constructor: R,
260 sender: &Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
261) where
262 F: FnOnce(&mut NotebookSession) -> crate::error::ServerResult<T>,
263 R: FnOnce(Result<T, String>) -> ServerMessage,
264{
265 let result = operation(session);
266
267 match result {
269 Ok(value) => {
270 let msg = response_constructor(Ok(value));
271 send_message(sender, &msg).await;
272
273 let state_msg = session.get_state();
275 session.broadcast(state_msg);
276 let undo_state = session.get_undo_redo_state();
277 session.broadcast(undo_state);
278 }
279 Err(e) => {
280 let msg = response_constructor(Err(e.to_string()));
281 send_message(sender, &msg).await;
282 }
283 };
284}
285
/// Dispatches one parsed [`ClientMessage`] from a WebSocket client.
///
/// Direct replies go only to this connection through `sender`. Session-wide updates are
/// published with `session.broadcast(..)`, which `handle_websocket` forwards to every
/// subscribed connection. For structural edits (insert/delete/duplicate/move/rename,
/// undo/redo, and the markdown/definition cell operations), success follows one pattern:
/// reply to the caller first, then broadcast the full notebook state, then the undo/redo
/// state.
///
/// Execution requests (`ExecuteCell`, `ExecuteAll`, `ExecuteDirty`) are spawned as
/// background tasks, so this function returns immediately. Each execution runs inside
/// `spawn_blocking` and holds the session write lock for its whole duration.
///
/// NOTE(review): `handle_websocket` awaits this function inline. An arm that waits on
/// `state.session.write()` while an execution holds the lock therefore also stops this
/// connection from reading further messages until the execution ends. `Interrupt` itself
/// only locks `kill_handle`, but it cannot be read if an earlier message from the same
/// client is still waiting on the session lock — confirm this is acceptable.
///
/// NOTE(review): most mutating arms keep the session write lock while awaiting
/// `send_message` (a socket write), so a slow client delays other connections.
async fn handle_client_message(
    msg: ClientMessage,
    state: &Arc<AppState>,
    sender: &Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
) {
    match msg {
        // Reply with a full notebook-state snapshot to this client only.
        ClientMessage::GetState => {
            let session = state.session.read().await;
            let state_msg = session.get_state();
            send_message(sender, &state_msg).await;
        }

        // Run one cell in the background. The blocking-pool thread can call
        // `Handle::current()` because `spawn_blocking` threads carry the runtime context.
        // NOTE(review): presumably done because `execute_cell` does blocking work — confirm.
        ClientMessage::ExecuteCell { cell_id } => {
            let state_clone = state.clone();

            tokio::spawn(async move {
                let state_for_blocking = state_clone.clone();
                let exec_result = tokio::task::spawn_blocking(move || {
                    let rt = tokio::runtime::Handle::current();
                    rt.block_on(async {
                        let mut session = state_for_blocking.session.write().await;
                        session.execute_cell(cell_id).await
                    })
                }).await;

                // Failures are only logged here.
                // NOTE(review): presumably the session reports them to clients itself.
                match exec_result {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => {
                        tracing::debug!("Execution error: {}", e);
                    }
                    Err(e) => {
                        tracing::error!("Task join error: {}", e);
                    }
                }
            });
        }

        // Same background pattern as ExecuteCell, for the whole notebook.
        ClientMessage::ExecuteAll => {
            let state_clone = state.clone();

            tokio::spawn(async move {
                let state_for_blocking = state_clone.clone();
                let exec_result = tokio::task::spawn_blocking(move || {
                    let rt = tokio::runtime::Handle::current();
                    rt.block_on(async {
                        let mut session = state_for_blocking.session.write().await;
                        session.execute_all().await
                    })
                }).await;

                match exec_result {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => {
                        tracing::debug!("Execution error: {}", e);
                    }
                    Err(e) => {
                        tracing::error!("Task join error: {}", e);
                    }
                }
            });
        }

        // Snapshot the dirty cell ids under a read lock, then execute them one at a time,
        // re-acquiring the write lock for each cell.
        // NOTE(review): the loop never consults `state.interrupted`, so after an Interrupt
        // the remaining dirty cells may still run — confirm this is intended.
        ClientMessage::ExecuteDirty => {
            let state_clone = state.clone();

            tokio::spawn(async move {
                let dirty_cells = {
                    let session = state_clone.session.read().await;
                    session.get_dirty_cell_ids()
                };

                for cell_id in dirty_cells {
                    let state_for_blocking = state_clone.clone();
                    let exec_result = tokio::task::spawn_blocking(move || {
                        let rt = tokio::runtime::Handle::current();
                        rt.block_on(async {
                            let mut session = state_for_blocking.session.write().await;
                            session.execute_cell(cell_id).await
                        })
                    }).await;

                    match exec_result {
                        Ok(Ok(())) => {}
                        Ok(Err(e)) => {
                            tracing::debug!("Execution error for {:?}: {}", cell_id, e);
                        }
                        Err(e) => {
                            tracing::error!("Task join error for {:?}: {}", cell_id, e);
                        }
                    }
                }
            });
        }

        // Buffer the edited source in the session; nothing is sent from here.
        ClientMessage::CellEdit { cell_id, source } => {
            let mut session = state.session.write().await;
            session.store_pending_edit(cell_id, source);
        }

        // Kill the running worker if there is one. The session lock is deliberately not
        // taken, so this works while an execution holds it.
        ClientMessage::Interrupt => {
            tracing::debug!("Received interrupt request from client");
            let kill_handle = state.kill_handle.lock().await;
            if let Some(ref handle) = *kill_handle {
                tracing::debug!("Killing worker process via interrupt request");
                // Raise the flag before killing so it is already set when the worker dies.
                state.interrupted.store(true, Ordering::SeqCst);
                handle.kill();
                tracing::debug!("Kill signal sent to worker");
            } else {
                tracing::warn!("Interrupt received but no kill handle available");
                send_message(
                    sender,
                    &ServerMessage::Error {
                        message: "No execution in progress to abort".to_string(),
                    },
                )
                .await;
            }
        }

        // Export the notebook source to a sibling `.ipynb` file (same directory and stem).
        // NOTE(review): `sync_to_ipynb` looks synchronous and runs on the async task while
        // the read lock is held — confirm it is cheap enough.
        ClientMessage::Sync => {
            let session = state.session.read().await;
            let rs_path = session.path();
            let ipynb_path = rs_path.with_extension("ipynb");

            match venus_sync::sync_to_ipynb(rs_path, &ipynb_path, None) {
                Ok(()) => {
                    send_message(
                        sender,
                        &ServerMessage::SyncCompleted {
                            ipynb_path: ipynb_path.display().to_string(),
                        },
                    )
                    .await;
                }
                Err(e) => {
                    tracing::error!("Sync error: {}", e);
                    send_message(
                        sender,
                        &ServerMessage::Error {
                            message: e.to_string(),
                        },
                    )
                    .await;
                }
            }
        }

        // Replies with the full notebook state, exactly like GetState.
        ClientMessage::GetGraph => {
            let session = state.session.read().await;
            let state_msg = session.get_state();
            send_message(sender, &state_msg).await;
        }

        // Record the new widget value; no reply or broadcast from here.
        ClientMessage::WidgetUpdate {
            cell_id,
            widget_id,
            value,
        } => {
            let mut session = state.session.write().await;
            session.update_widget_value(cell_id, widget_id, value);
        }

        // Switch a cell's shown output to a history entry, then broadcast the selection with
        // the currently dirty cells. If no entry is returned, nothing is sent.
        ClientMessage::SelectHistory { cell_id, index } => {
            let mut session = state.session.write().await;

            let output = session.select_history_entry(cell_id, index);

            if let Some(output) = output {
                let dirty_cells: Vec<CellId> = session.cell_states()
                    .iter()
                    .filter(|(_, s)| s.is_dirty())
                    .map(|(id, _)| *id)
                    .collect();

                let count = session.get_history_count(cell_id);

                session.broadcast(ServerMessage::HistorySelected {
                    cell_id,
                    index,
                    count,
                    output: Some(output),
                    dirty_cells,
                });
            }
        }

        // `insert_cell` returns the new cell's name; its id is recovered by name lookup,
        // falling back to `CellId::new(0)` if no cell matches.
        ClientMessage::InsertCell { after_cell_id } => {
            let mut session = state.session.write().await;

            match session.insert_cell(after_cell_id) {
                Ok(new_name) => {
                    let new_cell_id = session.cell_states()
                        .iter()
                        .find(|(_, s)| s.name().unwrap_or("") == new_name)
                        .map(|(id, _)| *id)
                        .unwrap_or(CellId::new(0));

                    send_message(sender, &ServerMessage::CellInserted {
                        cell_id: new_cell_id,
                        error: None,
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::CellInserted {
                        cell_id: CellId::new(0),
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        ClientMessage::DeleteCell { cell_id } => {
            let mut session = state.session.write().await;

            match session.delete_cell(cell_id) {
                Ok(()) => {
                    send_message(sender, &ServerMessage::CellDeleted {
                        cell_id,
                        error: None,
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::CellDeleted {
                        cell_id,
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        // Like InsertCell: the duplicate's id is recovered from its returned name.
        ClientMessage::DuplicateCell { cell_id } => {
            let mut session = state.session.write().await;

            match session.duplicate_cell(cell_id) {
                Ok(new_name) => {
                    let new_cell_id = session.cell_states()
                        .iter()
                        .find(|(_, s)| s.name().unwrap_or("") == new_name)
                        .map(|(id, _)| *id)
                        .unwrap_or(CellId::new(0));

                    send_message(sender, &ServerMessage::CellDuplicated {
                        original_cell_id: cell_id,
                        new_cell_id,
                        error: None,
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::CellDuplicated {
                        original_cell_id: cell_id,
                        new_cell_id: CellId::new(0),
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        ClientMessage::MoveCell { cell_id, direction } => {
            let mut session = state.session.write().await;

            match session.move_cell(cell_id, direction) {
                Ok(()) => {
                    send_message(sender, &ServerMessage::CellMoved {
                        cell_id,
                        error: None,
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::CellMoved {
                        cell_id,
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        // On success, the reply carries the description returned by `undo()`.
        ClientMessage::Undo => {
            let mut session = state.session.write().await;

            match session.undo() {
                Ok(description) => {
                    send_message(sender, &ServerMessage::UndoResult {
                        success: true,
                        error: None,
                        description: Some(description),
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::UndoResult {
                        success: false,
                        error: Some(e.to_string()),
                        description: None,
                    }).await;
                }
            }
        }

        ClientMessage::Redo => {
            let mut session = state.session.write().await;

            match session.redo() {
                Ok(description) => {
                    send_message(sender, &ServerMessage::RedoResult {
                        success: true,
                        error: None,
                        description: Some(description),
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::RedoResult {
                        success: false,
                        error: Some(e.to_string()),
                        description: None,
                    }).await;
                }
            }
        }

        // On success nothing is sent from here (only logged); on failure the error goes to
        // this client only.
        // NOTE(review): presumably the session itself announces a successful restart — confirm.
        ClientMessage::RestartKernel => {
            let mut session = state.session.write().await;

            match session.restart_kernel() {
                Ok(()) => {
                    tracing::info!("Kernel restarted successfully");
                }
                Err(e) => {
                    tracing::error!("Kernel restart failed: {}", e);
                    send_message(sender, &ServerMessage::KernelRestarted {
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        // No reply or broadcast from here.
        ClientMessage::ClearOutputs => {
            let mut session = state.session.write().await;
            session.clear_outputs();
            tracing::info!("All cell outputs cleared");
        }

        ClientMessage::RenameCell { cell_id, new_display_name } => {
            let mut session = state.session.write().await;

            match session.rename_cell(cell_id, new_display_name.clone()) {
                Ok(()) => {
                    send_message(sender, &ServerMessage::CellRenamed {
                        cell_id,
                        new_display_name,
                        error: None,
                    }).await;

                    let state_msg = session.get_state();
                    session.broadcast(state_msg);
                    let undo_state = session.get_undo_redo_state();
                    session.broadcast(undo_state);
                }
                Err(e) => {
                    send_message(sender, &ServerMessage::CellRenamed {
                        cell_id,
                        new_display_name,
                        error: Some(e.to_string()),
                    }).await;
                }
            }
        }

        // The new markdown cell's id is taken as the last `Markdown` entry in
        // `cell_states()` iteration order.
        // NOTE(review): this only identifies the new cell if iteration is in creation/id
        // order (e.g. a BTreeMap with increasing ids). With a HashMap it is arbitrary — confirm.
        ClientMessage::InsertMarkdownCell { content, after_cell_id } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| {
                    s.insert_markdown_cell(content, after_cell_id)?;
                    let new_cell_id = s.cell_states()
                        .iter()
                        .filter_map(|(id, state)| {
                            if matches!(state, CellState::Markdown { .. }) {
                                Some(*id)
                            } else {
                                None
                            }
                        })
                        .last()
                        .unwrap_or(CellId::new(0));
                    Ok(new_cell_id)
                },
                |result| match result {
                    Ok(cell_id) => ServerMessage::MarkdownCellInserted {
                        cell_id,
                        error: None,
                    },
                    Err(e) => ServerMessage::MarkdownCellInserted {
                        cell_id: CellId::new(0),
                        error: Some(e),
                    },
                },
                sender,
            ).await;
        }

        ClientMessage::EditMarkdownCell { cell_id, new_content } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.edit_markdown_cell(cell_id, new_content),
                |result| ServerMessage::MarkdownCellEdited {
                    cell_id,
                    error: result.err(),
                },
                sender,
            ).await;
        }

        ClientMessage::DeleteMarkdownCell { cell_id } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.delete_markdown_cell(cell_id),
                |result| ServerMessage::MarkdownCellDeleted {
                    cell_id,
                    error: result.err(),
                },
                sender,
            ).await;
        }

        ClientMessage::MoveMarkdownCell { cell_id, direction } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.move_markdown_cell(cell_id, direction),
                |result| ServerMessage::MarkdownCellMoved {
                    cell_id,
                    error: result.err(),
                },
                sender,
            ).await;
        }

        ClientMessage::InsertDefinitionCell { content, definition_type, after_cell_id } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.insert_definition_cell(content, definition_type, after_cell_id),
                |result| match result {
                    Ok(cell_id) => ServerMessage::DefinitionCellInserted {
                        cell_id,
                        error: None,
                    },
                    Err(e) => ServerMessage::DefinitionCellInserted {
                        cell_id: CellId::new(0),
                        error: Some(e),
                    },
                },
                sender,
            ).await;
        }

        // On success, the reply carries the cells returned by `edit_definition_cell`
        // (reported as `dirty_cells`); on failure the list is empty.
        ClientMessage::EditDefinitionCell { cell_id, new_content } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.edit_definition_cell(cell_id, new_content),
                |result| match result {
                    Ok(dirty_cells) => ServerMessage::DefinitionCellEdited {
                        cell_id,
                        error: None,
                        dirty_cells,
                    },
                    Err(e) => ServerMessage::DefinitionCellEdited {
                        cell_id,
                        error: Some(e),
                        dirty_cells: vec![],
                    },
                },
                sender,
            ).await;
        }

        ClientMessage::DeleteDefinitionCell { cell_id } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.delete_definition_cell(cell_id),
                |result| ServerMessage::DefinitionCellDeleted {
                    cell_id,
                    error: result.err(),
                },
                sender,
            ).await;
        }

        ClientMessage::MoveDefinitionCell { cell_id, direction } => {
            let mut session = state.session.write().await;

            handle_cell_operation(
                &mut session,
                |s| s.move_definition_cell(cell_id, direction),
                |result| ServerMessage::DefinitionCellMoved {
                    cell_id,
                    error: result.err(),
                },
                sender,
            ).await;
        }
    }
}
863
#[cfg(test)]
mod tests {
    use super::health_handler;

    /// Drives the real `/health` handler. The previous version only asserted on a JSON
    /// literal it had built itself, so it could never fail.
    #[test]
    fn test_health_json() {
        // `health_handler` never awaits anything, so a minimal executor is enough.
        let axum::Json(health) = futures::executor::block_on(health_handler());
        assert_eq!(health["status"], "ok");
        assert_eq!(health["version"], env!("CARGO_PKG_VERSION"));
    }
}