Skip to main content

surrealdb_server/rpc/
mod.rs

1pub mod format;
2pub mod http;
3pub mod response;
4pub mod websocket;
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use futures::stream::FuturesUnordered;
11use opentelemetry::Context as TelemetryContext;
12use surrealdb_core::kvs::Datastore;
13use surrealdb_core::rpc::{DbResponse, DbResult};
14use tokio::sync::RwLock;
15use tokio_stream::StreamExt;
16use tokio_util::sync::CancellationToken;
17use uuid::Uuid;
18
19use crate::rpc::websocket::Websocket;
20use crate::telemetry::metrics::ws::NotificationContext;
21
22static CONN_CLOSED_ERR: &str = "Connection closed normally";
23/// A type alias for an RPC Connection
24type WebSocket = Arc<Websocket>;
25/// Mapping of WebSocket ID to WebSocket
26type WebSockets = RwLock<HashMap<Uuid, WebSocket>>;
27/// Mapping of LIVE Query ID to WebSocket ID + Session ID
28type LiveQueries = RwLock<HashMap<Uuid, (Uuid, Option<Uuid>)>>;
29
30pub struct RpcState {
31	/// Stores the currently connected WebSockets
32	pub web_sockets: WebSockets,
33	/// Stores the currently initiated LIVE queries
34	pub live_queries: LiveQueries,
35	/// HTTP RPC handler with persistent sessions
36	pub http: Arc<crate::rpc::http::Http>,
37}
38
39impl RpcState {
40	pub fn new(
41		datastore: Arc<surrealdb_core::kvs::Datastore>,
42		session: surrealdb_core::dbs::Session,
43	) -> Self {
44		Self {
45			web_sockets: RwLock::new(HashMap::new()),
46			live_queries: RwLock::new(HashMap::new()),
47			http: Arc::new(crate::rpc::http::Http::new(datastore, session)),
48		}
49	}
50}
51
52/// Performs notification delivery to the WebSockets
53pub(crate) async fn notifications(
54	ds: Arc<Datastore>,
55	state: Arc<RpcState>,
56	canceller: CancellationToken,
57) {
58	// Store messages being delivered
59	let mut futures = FuturesUnordered::new();
60	// Listen to the notifications channel
61	if let Some(channel) = ds.notifications() {
62		// Loop continuously
63		loop {
64			tokio::select! {
65				//
66				biased;
67				// Check if this has shutdown
68				_ = canceller.cancelled() => break,
69				// Process any buffered messages
70				Some(_) = futures.next() => continue,
71				// Receive a notification on the channel
72				Ok(notification) = channel.recv() => {
73					// Get the id for this notification
74					let id = notification.id.as_ref();
75					// Get the WebSocket for this notification
76					let websocket = {
77						state.live_queries.read().await.get(id).copied()
78					};
79					// Ensure the specified WebSocket exists
80					if let Some((id, session_id)) = websocket.as_ref() {
81						// Get the WebSocket for this notification
82						let websocket = {
83							state.web_sockets.read().await.get(id).cloned()
84						};
85						// Ensure the specified WebSocket exists
86						if let Some(rpc) = websocket {
87							// Serialize the message to send
88							let message = DbResponse::success(None, session_id.map(Into::into), DbResult::Live(notification));
89							// Add telemetry metrics
90							let cx = TelemetryContext::new();
91							let not_ctx = NotificationContext::default()
92								  .with_live_id(id.to_string());
93							let cx = Arc::new(cx.with_value(not_ctx));
94							// Get the WebSocket output format
95							let format = rpc.format;
96							// Get the WebSocket sending channel
97							let sender = rpc.channel.clone();
98							// Send the notification to the client
99							// let future = message.send(cx, format, sender);
100							let future = crate::rpc::response::send(message, cx, format, sender);
101							// Pus the future to the pipeline
102							futures.push(future);
103						}
104					}
105				},
106			}
107		}
108	}
109}
110
111/// Closes all WebSocket connections, waiting for graceful shutdown
112pub(crate) async fn graceful_shutdown(state: Arc<RpcState>) {
113	// Close WebSocket connections, ensuring queued messages are processed
114	for (_, rpc) in state.web_sockets.read().await.iter() {
115		rpc.shutdown.cancel();
116	}
117	// Wait for all existing WebSocket connections to finish sending
118	while !state.web_sockets.read().await.is_empty() {
119		tokio::time::sleep(Duration::from_millis(250)).await;
120	}
121}
122
123/// Forces a fast shutdown of all WebSocket connections
124pub(crate) fn shutdown(state: Arc<RpcState>) {
125	// Close all WebSocket connections immediately
126	if let Ok(mut writer) = state.web_sockets.try_write() {
127		writer.drain();
128	}
129}