surrealdb_server/rpc/
mod.rs1pub mod format;
2pub mod http;
3pub mod response;
4pub mod websocket;
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use futures::stream::FuturesUnordered;
11use opentelemetry::Context as TelemetryContext;
12use surrealdb_core::kvs::Datastore;
13use surrealdb_core::rpc::{DbResponse, DbResult};
14use tokio::sync::RwLock;
15use tokio_stream::StreamExt;
16use tokio_util::sync::CancellationToken;
17use uuid::Uuid;
18
19use crate::rpc::websocket::Websocket;
20use crate::telemetry::metrics::ws::NotificationContext;
21
22static CONN_CLOSED_ERR: &str = "Connection closed normally";
23type WebSocket = Arc<Websocket>;
25type WebSockets = RwLock<HashMap<Uuid, WebSocket>>;
27type LiveQueries = RwLock<HashMap<Uuid, (Uuid, Option<Uuid>)>>;
29
30pub struct RpcState {
31 pub web_sockets: WebSockets,
33 pub live_queries: LiveQueries,
35 pub http: Arc<crate::rpc::http::Http>,
37}
38
39impl RpcState {
40 pub fn new(
41 datastore: Arc<surrealdb_core::kvs::Datastore>,
42 session: surrealdb_core::dbs::Session,
43 ) -> Self {
44 Self {
45 web_sockets: RwLock::new(HashMap::new()),
46 live_queries: RwLock::new(HashMap::new()),
47 http: Arc::new(crate::rpc::http::Http::new(datastore, session)),
48 }
49 }
50}
51
52pub(crate) async fn notifications(
54 ds: Arc<Datastore>,
55 state: Arc<RpcState>,
56 canceller: CancellationToken,
57) {
58 let mut futures = FuturesUnordered::new();
60 if let Some(channel) = ds.notifications() {
62 loop {
64 tokio::select! {
65 biased;
67 _ = canceller.cancelled() => break,
69 Some(_) = futures.next() => continue,
71 Ok(notification) = channel.recv() => {
73 let id = notification.id.as_ref();
75 let websocket = {
77 state.live_queries.read().await.get(id).copied()
78 };
79 if let Some((id, session_id)) = websocket.as_ref() {
81 let websocket = {
83 state.web_sockets.read().await.get(id).cloned()
84 };
85 if let Some(rpc) = websocket {
87 let message = DbResponse::success(None, session_id.map(Into::into), DbResult::Live(notification));
89 let cx = TelemetryContext::new();
91 let not_ctx = NotificationContext::default()
92 .with_live_id(id.to_string());
93 let cx = Arc::new(cx.with_value(not_ctx));
94 let format = rpc.format;
96 let sender = rpc.channel.clone();
98 let future = crate::rpc::response::send(message, cx, format, sender);
101 futures.push(future);
103 }
104 }
105 },
106 }
107 }
108 }
109}
110
111pub(crate) async fn graceful_shutdown(state: Arc<RpcState>) {
113 for (_, rpc) in state.web_sockets.read().await.iter() {
115 rpc.shutdown.cancel();
116 }
117 while !state.web_sockets.read().await.is_empty() {
119 tokio::time::sleep(Duration::from_millis(250)).await;
120 }
121}
122
123pub(crate) fn shutdown(state: Arc<RpcState>) {
125 if let Ok(mut writer) = state.web_sockets.try_write() {
127 writer.drain();
128 }
129}