1use std::{
21 borrow::Cow,
22 convert::TryFrom,
23 future::Future,
24 str::FromStr,
25 sync::{
26 atomic::{AtomicUsize, Ordering},
27 Arc,
28 },
29 time::Duration,
30};
31
32use base64::Engine;
33use futures::StreamExt;
34use http_types::Method;
35use serde::{Deserialize, Serialize};
36use tide::{http::Mime, sse::Sender, Request, Response, Server, StatusCode};
37use tokio::{task::JoinHandle, time::timeout};
38use zenoh::{
39 bytes::{Encoding, ZBytes},
40 internal::{
41 bail,
42 plugins::{RunningPluginTrait, ZenohPlugin},
43 runtime::Runtime,
44 zerror,
45 },
46 key_expr::{keyexpr, KeyExpr},
47 query::{Parameters, QueryConsolidation, Reply, Selector, ZenohParameters},
48 sample::{Sample, SampleKind},
49 session::Session,
50 Result as ZResult,
51};
52use zenoh_plugin_trait::{plugin_long_version, plugin_version, Plugin, PluginControl};
53
54mod config;
55pub use config::Config;
56use zenoh::query::ReplyError;
57
/// Version string from git, prefixed with `v` (`cargo_prefix` gives the same prefix to the
/// Cargo-version fallback used by `git_version!`).
const GIT_VERSION: &str = git_version::git_version!(prefix = "v", cargo_prefix = "v");
lazy_static::lazy_static! {
    // Human-readable version: the git version plus the `RUSTC_VERSION` env var captured at
    // build time. Logged when the plugin starts.
    static ref LONG_VERSION: String = format!("{} built with {}", GIT_VERSION, env!("RUSTC_VERSION"));
}
/// Query parameter that, when present in a GET/POST URL, makes the response carry the payload
/// of the first reply as-is instead of a JSON/HTML listing of all replies.
const RAW_KEY: &str = "_raw";

lazy_static::lazy_static! {
    // Thread counts for the fallback Tokio runtime below. They are only read when
    // `TOKIO_RUNTIME` is first dereferenced, so values stored later (the plugin `start`
    // stores the configured ones) take effect only if that runtime has not been built yet.
    static ref WORKER_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_WORK_THREAD_NUM);
    static ref MAX_BLOCK_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_MAX_BLOCK_THREAD_NUM);
    // Multi-threaded runtime used by `blockon_runtime` / `spawn_runtime` when the calling
    // thread has no current Tokio runtime. Panics on first use if it cannot be built.
    static ref TOKIO_RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(WORKER_THREAD_NUM.load(Ordering::SeqCst))
        .max_blocking_threads(MAX_BLOCK_THREAD_NUM.load(Ordering::SeqCst))
        .enable_all()
        .build()
        .expect("Unable to create runtime");
}
75
76#[inline(always)]
77pub(crate) fn blockon_runtime<F: Future>(task: F) -> F::Output {
78 match tokio::runtime::Handle::try_current() {
80 Ok(rt) => {
81 tokio::task::block_in_place(|| rt.block_on(task))
83 }
84 Err(_) => {
85 tokio::task::block_in_place(|| TOKIO_RUNTIME.block_on(task))
87 }
88 }
89}
90
91pub(crate) fn spawn_runtime<F>(task: F) -> JoinHandle<F::Output>
92where
93 F: Future + Send + 'static,
94 F::Output: Send + 'static,
95{
96 match tokio::runtime::Handle::try_current() {
98 Ok(rt) => {
99 rt.spawn(task)
101 }
102 Err(_) => {
103 TOKIO_RUNTIME.spawn(task)
105 }
106 }
107}
108
/// JSON shape of one reply (sample or reply error) in the REST API's JSON responses and of
/// each event pushed on an SSE stream.
#[derive(Serialize, Deserialize)]
struct JSONSample {
    /// Key expression of the sample; `"ERROR"` for an error reply.
    key: String,
    /// Payload converted by `payload_to_json` according to `encoding`.
    value: serde_json::Value,
    /// String form of the payload encoding.
    encoding: String,
    /// String form of the sample timestamp; `None` when absent or for error replies.
    timestamp: Option<String>,
}
116
117pub fn base64_encode(data: &[u8]) -> String {
118 use base64::engine::general_purpose;
119 general_purpose::STANDARD.encode(data)
120}
121
122fn payload_to_json(payload: &ZBytes, encoding: &Encoding) -> serde_json::Value {
123 if payload.is_empty() {
124 return serde_json::Value::Null;
125 }
126 match encoding {
127 &Encoding::APPLICATION_JSON | &Encoding::TEXT_JSON | &Encoding::TEXT_JSON5 => {
129 let bytes = payload.to_bytes();
130 serde_json::from_slice(&bytes).unwrap_or_else(|e| {
131 tracing::warn!(
132 "Encoding is JSON but data is not JSON, converting to base64, Error: {e:?}"
133 );
134 serde_json::Value::String(base64_encode(&bytes))
135 })
136 }
137 &Encoding::TEXT_PLAIN | &Encoding::ZENOH_STRING => serde_json::Value::String(
138 String::from_utf8(payload.to_bytes().into_owned()).unwrap_or_else(|e| {
139 tracing::warn!(
140 "Encoding is String but data is not String, converting to base64, Error: {e:?}"
141 );
142 base64_encode(e.as_bytes())
143 }),
144 ),
145 _ => serde_json::Value::String(base64_encode(&payload.to_bytes())),
147 }
148}
149
150fn sample_to_json(sample: &Sample) -> JSONSample {
151 JSONSample {
152 key: sample.key_expr().as_str().to_string(),
153 value: payload_to_json(sample.payload(), sample.encoding()),
154 encoding: sample.encoding().to_string(),
155 timestamp: sample.timestamp().map(|ts| ts.to_string()),
156 }
157}
158
159fn result_to_json(sample: Result<&Sample, &ReplyError>) -> JSONSample {
160 match sample {
161 Ok(sample) => sample_to_json(sample),
162 Err(err) => JSONSample {
163 key: "ERROR".into(),
164 value: payload_to_json(err.payload(), err.encoding()),
165 encoding: err.encoding().to_string(),
166 timestamp: None,
167 },
168 }
169}
170
171async fn to_json(results: flume::Receiver<Reply>) -> String {
172 let values = results
173 .stream()
174 .filter_map(move |reply| async move { Some(result_to_json(reply.result())) })
175 .collect::<Vec<JSONSample>>()
176 .await;
177
178 serde_json::to_string(&values).unwrap_or("[]".into())
179}
180
181async fn to_json_response(results: flume::Receiver<Reply>) -> Response {
182 response(StatusCode::Ok, "application/json", &to_json(results).await)
183}
184
185fn sample_to_html(sample: &Sample) -> String {
186 format!(
187 "<dt>{}</dt>\n<dd>{}</dd>\n",
188 sample.key_expr().as_str(),
189 sample.payload().try_to_string().unwrap_or_default()
190 )
191}
192
193fn result_to_html(sample: Result<&Sample, &ReplyError>) -> String {
194 match sample {
195 Ok(sample) => sample_to_html(sample),
196 Err(err) => {
197 format!(
198 "<dt>ERROR</dt>\n<dd>{}</dd>\n",
199 err.payload().try_to_string().unwrap_or_default()
200 )
201 }
202 }
203}
204
205async fn to_html(results: flume::Receiver<Reply>) -> String {
206 let values = results
207 .stream()
208 .filter_map(move |reply| async move { Some(result_to_html(reply.result())) })
209 .collect::<Vec<String>>()
210 .await
211 .join("\n");
212 format!("<dl>\n{values}\n</dl>\n")
213}
214
215async fn to_html_response(results: flume::Receiver<Reply>) -> Response {
216 response(StatusCode::Ok, "text/html", &to_html(results).await)
217}
218
219async fn to_raw_response(results: flume::Receiver<Reply>) -> Response {
220 match results.recv_async().await {
221 Ok(reply) => match reply.result() {
222 Ok(sample) => response(
223 StatusCode::Ok,
224 Cow::from(sample.encoding()).as_ref(),
225 &sample.payload().try_to_string().unwrap_or_default(),
226 ),
227 Err(value) => response(
228 StatusCode::Ok,
229 Cow::from(value.encoding()).as_ref(),
230 &value.payload().try_to_string().unwrap_or_default(),
231 ),
232 },
233 Err(_) => response(StatusCode::Ok, "", ""),
234 }
235}
236
237fn method_to_kind(method: Method) -> SampleKind {
238 match method {
239 Method::Put => SampleKind::Put,
240 Method::Delete => SampleKind::Delete,
241 _ => SampleKind::default(),
242 }
243}
244
245fn response<'a, S: Into<&'a str> + std::fmt::Debug>(
246 status: StatusCode,
247 content_type: S,
248 body: &str,
249) -> Response {
250 tracing::trace!("Outgoing Response: {status} - {content_type:?} - body: {body}");
251 let mut builder = Response::builder(status)
252 .header("content-length", body.len().to_string())
253 .header("Access-Control-Allow-Origin", "*")
254 .body(body);
255 if let Ok(mime) = Mime::from_str(content_type.into()) {
256 builder = builder.content_type(mime);
257 }
258 builder.build()
259}
260
// Declare the plugin entry point only when the `dynamic_plugin` feature is enabled
// (presumably when the plugin is built as a dynamically loaded library — confirm in Cargo.toml).
#[cfg(feature = "dynamic_plugin")]
zenoh_plugin_trait::declare_plugin!(RestPlugin);

/// The zenoh REST plugin. It serves an HTTP API whose GET/POST map to zenoh queries (or an
/// SSE subscription) and whose PUT/PATCH/DELETE map to zenoh puts and deletes.
pub struct RestPlugin {}

impl ZenohPlugin for RestPlugin {}
267
268impl Plugin for RestPlugin {
269 type StartArgs = Runtime;
270 type Instance = zenoh::internal::plugins::RunningPlugin;
271 const DEFAULT_NAME: &'static str = "rest";
272 const PLUGIN_VERSION: &'static str = plugin_version!();
273 const PLUGIN_LONG_VERSION: &'static str = plugin_long_version!();
274
275 fn start(
276 name: &str,
277 runtime: &Self::StartArgs,
278 ) -> ZResult<zenoh::internal::plugins::RunningPlugin> {
279 zenoh::init_log_from_env_or("error");
283 tracing::debug!("REST plugin {}", LONG_VERSION.as_str());
284
285 let runtime_conf = runtime.config().lock();
286 let plugin_conf = runtime_conf
287 .plugin(name)
288 .ok_or_else(|| zerror!("Plugin `{}`: missing config", name))?;
289
290 let conf: Config = serde_json::from_value(plugin_conf.clone())
291 .map_err(|e| zerror!("Plugin `{}` configuration error: {}", name, e))?;
292 WORKER_THREAD_NUM.store(conf.work_thread_num, Ordering::SeqCst);
293 MAX_BLOCK_THREAD_NUM.store(conf.max_block_thread_num, Ordering::SeqCst);
294
295 let task = run(runtime.clone(), conf.clone());
296 let task =
297 blockon_runtime(async { timeout(Duration::from_millis(1), spawn_runtime(task)).await });
298
299 if let Ok(Ok(Err(e))) = task {
302 bail!("REST server failed within 1ms: {e}")
303 }
304
305 Ok(Box::new(RunningPlugin(conf)))
306 }
307}
308
/// Handle returned by `RestPlugin::start`. It keeps the plugin configuration so it can be
/// reported through the admin space.
struct RunningPlugin(Config);

impl PluginControl for RunningPlugin {}
312
313impl RunningPluginTrait for RunningPlugin {
314 fn adminspace_getter<'a>(
315 &'a self,
316 key_expr: &'a KeyExpr<'a>,
317 plugin_status_key: &str,
318 ) -> ZResult<Vec<zenoh::internal::plugins::Response>> {
319 let mut responses = Vec::new();
320 let mut key = String::from(plugin_status_key);
321 with_extended_string(&mut key, &["/version"], |key| {
322 if keyexpr::new(key.as_str()).unwrap().intersects(key_expr) {
323 responses.push(zenoh::internal::plugins::Response::new(
324 key.clone(),
325 GIT_VERSION.into(),
326 ))
327 }
328 });
329 with_extended_string(&mut key, &["/port"], |port_key| {
330 if keyexpr::new(port_key.as_str())
331 .unwrap()
332 .intersects(key_expr)
333 {
334 responses.push(zenoh::internal::plugins::Response::new(
335 port_key.clone(),
336 (&self.0).into(),
337 ))
338 }
339 });
340 Ok(responses)
341 }
342}
343
/// Temporarily appends every suffix to `prefix` and calls `closure` with the extended string.
///
/// After the closure returns, `prefix` is truncated back to its original length (so the
/// caller's buffer can be reused), and the closure's result is returned.
fn with_extended_string<R, F: FnMut(&mut String) -> R>(
    prefix: &mut String,
    suffixes: &[&str],
    mut closure: F,
) -> R {
    let original_len = prefix.len();
    prefix.extend(suffixes.iter().copied());
    let outcome = closure(prefix);
    prefix.truncate(original_len);
    outcome
}
357
358async fn query(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
359 tracing::trace!("Incoming GET request: {:?}", req);
360
361 let first_accept = match req.header("accept") {
362 Some(accept) => accept[0]
363 .to_string()
364 .split(';')
365 .next()
366 .unwrap()
367 .split(',')
368 .next()
369 .unwrap()
370 .to_string(),
371 None => "application/json".to_string(),
372 };
373 if first_accept == "text/event-stream" {
374 Ok(tide::sse::upgrade(
375 req,
376 move |req: Request<(Arc<Session>, String)>, sender: Sender| async move {
377 let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
378 Ok(ke) => ke.into_owned(),
379 Err(e) => {
380 return Err(tide::Error::new(
381 tide::StatusCode::BadRequest,
382 anyhow::anyhow!("{}", e),
383 ))
384 }
385 };
386 spawn_runtime(async move {
387 tracing::debug!("Subscribe to {} for SSE stream", key_expr);
388 let sender = &sender;
389 let sub = req.state().0.declare_subscriber(&key_expr).await.unwrap();
390 loop {
391 let sample = sub.recv_async().await.unwrap();
392 let json_sample =
393 serde_json::to_string(&sample_to_json(&sample)).unwrap_or("{}".into());
394
395 match timeout(
396 std::time::Duration::new(10, 0),
397 sender.send(&sample.kind().to_string(), json_sample, None),
398 )
399 .await
400 {
401 Ok(Ok(_)) => {}
402 Ok(Err(e)) => {
403 tracing::debug!("SSE error ({})! Unsubscribe and terminate", e);
404 if let Err(e) = sub.undeclare().await {
405 tracing::error!("Error undeclaring subscriber: {}", e);
406 }
407 break;
408 }
409 Err(_) => {
410 tracing::debug!("SSE timeout! Unsubscribe and terminate",);
411 if let Err(e) = sub.undeclare().await {
412 tracing::error!("Error undeclaring subscriber: {}", e);
413 }
414 break;
415 }
416 }
417 }
418 });
419 Ok(())
420 },
421 ))
422 } else {
423 let body = req.body_bytes().await.unwrap_or_default();
424 let url = req.url();
425 let key_expr = match path_to_key_expr(url.path(), &req.state().1) {
426 Ok(ke) => ke,
427 Err(e) => {
428 return Ok(response(
429 StatusCode::BadRequest,
430 "text/plain",
431 &e.to_string(),
432 ))
433 }
434 };
435 let query_part = url.query();
436 let parameters = Parameters::from(query_part.unwrap_or_default());
437 let consolidation = if parameters.time_range().is_some() {
438 QueryConsolidation::from(zenoh::query::ConsolidationMode::None)
439 } else {
440 QueryConsolidation::from(zenoh::query::ConsolidationMode::Latest)
441 };
442 let raw = parameters.contains_key(RAW_KEY);
443 let mut query = req
444 .state()
445 .0
446 .get(Selector::borrowed(&key_expr, ¶meters))
447 .consolidation(consolidation)
448 .with(flume::unbounded());
449 if !body.is_empty() {
450 let encoding: Encoding = req
451 .content_type()
452 .map(|m| Encoding::from(m.to_string()))
453 .unwrap_or_default();
454 query = query.payload(body).encoding(encoding);
455 }
456 match query.await {
457 Ok(receiver) => {
458 if raw {
459 Ok(to_raw_response(receiver).await)
460 } else if first_accept == "text/html" {
461 Ok(to_html_response(receiver).await)
462 } else {
463 Ok(to_json_response(receiver).await)
464 }
465 }
466 Err(e) => Ok(response(
467 StatusCode::InternalServerError,
468 "text/plain",
469 &e.to_string(),
470 )),
471 }
472 }
473}
474
475async fn write(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
476 tracing::trace!("Incoming PUT request: {:?}", req);
477 match req.body_bytes().await {
478 Ok(bytes) => {
479 let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
480 Ok(ke) => ke,
481 Err(e) => {
482 return Ok(response(
483 StatusCode::BadRequest,
484 "text/plain",
485 &e.to_string(),
486 ))
487 }
488 };
489
490 let encoding: Encoding = req
491 .content_type()
492 .map(|m| Encoding::from(m.to_string()))
493 .unwrap_or_default();
494
495 let session = &req.state().0;
497 let res = match method_to_kind(req.method()) {
498 SampleKind::Put => session.put(&key_expr, bytes).encoding(encoding).await,
499 SampleKind::Delete => session.delete(&key_expr).await,
500 };
501 match res {
502 Ok(_) => Ok(Response::new(StatusCode::Ok)),
503 Err(e) => Ok(response(
504 StatusCode::InternalServerError,
505 "text/plain",
506 &e.to_string(),
507 )),
508 }
509 }
510 Err(e) => Ok(response(
511 StatusCode::NoContent,
512 "text/plain",
513 &e.to_string(),
514 )),
515 }
516}
517
518pub async fn run(runtime: Runtime, conf: Config) -> ZResult<()> {
519 zenoh::init_log_from_env_or("error");
523
524 let zid = runtime.zid().to_string();
525 let session = zenoh::session::init(runtime).await.unwrap();
526
527 let mut app = Server::with_state((Arc::new(session), zid));
528 app.with(
529 tide::security::CorsMiddleware::new()
530 .allow_methods(
531 "GET, POST, PUT, PATCH, DELETE"
532 .parse::<http_types::headers::HeaderValue>()
533 .unwrap(),
534 )
535 .allow_origin(tide::security::Origin::from("*"))
536 .allow_credentials(false),
537 );
538
539 app.at("/")
540 .get(query)
541 .post(query)
542 .put(write)
543 .patch(write)
544 .delete(write);
545 app.at("*")
546 .get(query)
547 .post(query)
548 .put(write)
549 .patch(write)
550 .delete(write);
551
552 if let Err(e) = app.listen(conf.http_port).await {
553 tracing::error!("Unable to start http server for REST: {:?}", e);
554 return Err(e.into());
555 }
556 Ok(())
557}
558
559fn path_to_key_expr<'a>(path: &'a str, zid: &str) -> ZResult<KeyExpr<'a>> {
560 let path = path.strip_prefix('/').unwrap_or(path);
561 if path == "@/local" {
562 KeyExpr::try_from(format!("@/{zid}"))
563 } else if let Some(suffix) = path.strip_prefix("@/local/") {
564 KeyExpr::try_from(format!("@/{zid}/{suffix}"))
565 } else {
566 KeyExpr::try_from(path)
567 }
568}