zenoh_plugin_rest/
lib.rs

//
// Copyright (c) 2023 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
14
//! ⚠️ WARNING ⚠️
//!
//! This crate is intended for Zenoh's internal use.
//!
//! [Click here for Zenoh's documentation](https://docs.rs/zenoh/latest/zenoh)
20use std::{
21    borrow::Cow,
22    convert::TryFrom,
23    future::Future,
24    str::FromStr,
25    sync::{
26        atomic::{AtomicUsize, Ordering},
27        Arc,
28    },
29    time::Duration,
30};
31
32use base64::Engine;
33use futures::StreamExt;
34use http_types::Method;
35use serde::{Deserialize, Serialize};
36use tide::{http::Mime, sse::Sender, Request, Response, Server, StatusCode};
37use tokio::{task::JoinHandle, time::timeout};
38use zenoh::{
39    bytes::{Encoding, ZBytes},
40    internal::{
41        bail,
42        plugins::{RunningPluginTrait, ZenohPlugin},
43        runtime::Runtime,
44        zerror,
45    },
46    key_expr::{keyexpr, KeyExpr},
47    query::{Parameters, QueryConsolidation, Reply, Selector, ZenohParameters},
48    sample::{Sample, SampleKind},
49    session::Session,
50    Result as ZResult,
51};
52use zenoh_plugin_trait::{plugin_long_version, plugin_version, Plugin, PluginControl};
53
54mod config;
55pub use config::Config;
56use zenoh::query::ReplyError;
57
58const GIT_VERSION: &str = git_version::git_version!(prefix = "v", cargo_prefix = "v");
59lazy_static::lazy_static! {
60    static ref LONG_VERSION: String = format!("{} built with {}", GIT_VERSION, env!("RUSTC_VERSION"));
61}
62const RAW_KEY: &str = "_raw";
63
64lazy_static::lazy_static! {
65    static ref WORKER_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_WORK_THREAD_NUM);
66    static ref MAX_BLOCK_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_MAX_BLOCK_THREAD_NUM);
67    // The global runtime is used in the dynamic plugins, which we can't get the current runtime
68    static ref TOKIO_RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
69               .worker_threads(WORKER_THREAD_NUM.load(Ordering::SeqCst))
70               .max_blocking_threads(MAX_BLOCK_THREAD_NUM.load(Ordering::SeqCst))
71               .enable_all()
72               .build()
73               .expect("Unable to create runtime");
74}
75
76#[inline(always)]
77pub(crate) fn blockon_runtime<F: Future>(task: F) -> F::Output {
78    // Check whether able to get the current runtime
79    match tokio::runtime::Handle::try_current() {
80        Ok(rt) => {
81            // Able to get the current runtime (standalone binary), use the current runtime
82            tokio::task::block_in_place(|| rt.block_on(task))
83        }
84        Err(_) => {
85            // Unable to get the current runtime (dynamic plugins), reuse the global runtime
86            tokio::task::block_in_place(|| TOKIO_RUNTIME.block_on(task))
87        }
88    }
89}
90
91pub(crate) fn spawn_runtime<F>(task: F) -> JoinHandle<F::Output>
92where
93    F: Future + Send + 'static,
94    F::Output: Send + 'static,
95{
96    // Check whether able to get the current runtime
97    match tokio::runtime::Handle::try_current() {
98        Ok(rt) => {
99            // Able to get the current runtime (standalone binary), spawn on the current runtime
100            rt.spawn(task)
101        }
102        Err(_) => {
103            // Unable to get the current runtime (dynamic plugins), spawn on the global runtime
104            TOKIO_RUNTIME.spawn(task)
105        }
106    }
107}
108
109#[derive(Serialize, Deserialize)]
110struct JSONSample {
111    key: String,
112    value: serde_json::Value,
113    encoding: String,
114    timestamp: Option<String>,
115}
116
117pub fn base64_encode(data: &[u8]) -> String {
118    use base64::engine::general_purpose;
119    general_purpose::STANDARD.encode(data)
120}
121
122fn payload_to_json(payload: &ZBytes, encoding: &Encoding) -> serde_json::Value {
123    if payload.is_empty() {
124        return serde_json::Value::Null;
125    }
126    match encoding {
127        // If it is a JSON try to deserialize as json, if it fails fallback to base64
128        &Encoding::APPLICATION_JSON | &Encoding::TEXT_JSON | &Encoding::TEXT_JSON5 => {
129            let bytes = payload.to_bytes();
130            serde_json::from_slice(&bytes).unwrap_or_else(|e| {
131                tracing::warn!(
132                    "Encoding is JSON but data is not JSON, converting to base64, Error: {e:?}"
133                );
134                serde_json::Value::String(base64_encode(&bytes))
135            })
136        }
137        &Encoding::TEXT_PLAIN | &Encoding::ZENOH_STRING => serde_json::Value::String(
138            String::from_utf8(payload.to_bytes().into_owned()).unwrap_or_else(|e| {
139                tracing::warn!(
140                    "Encoding is String but data is not String, converting to base64, Error: {e:?}"
141                );
142                base64_encode(e.as_bytes())
143            }),
144        ),
145        // otherwise convert to JSON string
146        _ => serde_json::Value::String(base64_encode(&payload.to_bytes())),
147    }
148}
149
150fn sample_to_json(sample: &Sample) -> JSONSample {
151    JSONSample {
152        key: sample.key_expr().as_str().to_string(),
153        value: payload_to_json(sample.payload(), sample.encoding()),
154        encoding: sample.encoding().to_string(),
155        timestamp: sample.timestamp().map(|ts| ts.to_string()),
156    }
157}
158
159fn result_to_json(sample: Result<&Sample, &ReplyError>) -> JSONSample {
160    match sample {
161        Ok(sample) => sample_to_json(sample),
162        Err(err) => JSONSample {
163            key: "ERROR".into(),
164            value: payload_to_json(err.payload(), err.encoding()),
165            encoding: err.encoding().to_string(),
166            timestamp: None,
167        },
168    }
169}
170
171async fn to_json(results: flume::Receiver<Reply>) -> String {
172    let values = results
173        .stream()
174        .filter_map(move |reply| async move { Some(result_to_json(reply.result())) })
175        .collect::<Vec<JSONSample>>()
176        .await;
177
178    serde_json::to_string(&values).unwrap_or("[]".into())
179}
180
181async fn to_json_response(results: flume::Receiver<Reply>) -> Response {
182    response(StatusCode::Ok, "application/json", &to_json(results).await)
183}
184
185fn sample_to_html(sample: &Sample) -> String {
186    format!(
187        "<dt>{}</dt>\n<dd>{}</dd>\n",
188        sample.key_expr().as_str(),
189        sample.payload().try_to_string().unwrap_or_default()
190    )
191}
192
193fn result_to_html(sample: Result<&Sample, &ReplyError>) -> String {
194    match sample {
195        Ok(sample) => sample_to_html(sample),
196        Err(err) => {
197            format!(
198                "<dt>ERROR</dt>\n<dd>{}</dd>\n",
199                err.payload().try_to_string().unwrap_or_default()
200            )
201        }
202    }
203}
204
205async fn to_html(results: flume::Receiver<Reply>) -> String {
206    let values = results
207        .stream()
208        .filter_map(move |reply| async move { Some(result_to_html(reply.result())) })
209        .collect::<Vec<String>>()
210        .await
211        .join("\n");
212    format!("<dl>\n{values}\n</dl>\n")
213}
214
215async fn to_html_response(results: flume::Receiver<Reply>) -> Response {
216    response(StatusCode::Ok, "text/html", &to_html(results).await)
217}
218
219async fn to_raw_response(results: flume::Receiver<Reply>) -> Response {
220    match results.recv_async().await {
221        Ok(reply) => match reply.result() {
222            Ok(sample) => response(
223                StatusCode::Ok,
224                Cow::from(sample.encoding()).as_ref(),
225                &sample.payload().try_to_string().unwrap_or_default(),
226            ),
227            Err(value) => response(
228                StatusCode::Ok,
229                Cow::from(value.encoding()).as_ref(),
230                &value.payload().try_to_string().unwrap_or_default(),
231            ),
232        },
233        Err(_) => response(StatusCode::Ok, "", ""),
234    }
235}
236
237fn method_to_kind(method: Method) -> SampleKind {
238    match method {
239        Method::Put => SampleKind::Put,
240        Method::Delete => SampleKind::Delete,
241        _ => SampleKind::default(),
242    }
243}
244
245fn response<'a, S: Into<&'a str> + std::fmt::Debug>(
246    status: StatusCode,
247    content_type: S,
248    body: &str,
249) -> Response {
250    tracing::trace!("Outgoing Response: {status} - {content_type:?} - body: {body}");
251    let mut builder = Response::builder(status)
252        .header("content-length", body.len().to_string())
253        .header("Access-Control-Allow-Origin", "*")
254        .body(body);
255    if let Ok(mime) = Mime::from_str(content_type.into()) {
256        builder = builder.content_type(mime);
257    }
258    builder.build()
259}
260
261#[cfg(feature = "dynamic_plugin")]
262zenoh_plugin_trait::declare_plugin!(RestPlugin);
263
264pub struct RestPlugin {}
265
266impl ZenohPlugin for RestPlugin {}
267
268impl Plugin for RestPlugin {
269    type StartArgs = Runtime;
270    type Instance = zenoh::internal::plugins::RunningPlugin;
271    const DEFAULT_NAME: &'static str = "rest";
272    const PLUGIN_VERSION: &'static str = plugin_version!();
273    const PLUGIN_LONG_VERSION: &'static str = plugin_long_version!();
274
275    fn start(
276        name: &str,
277        runtime: &Self::StartArgs,
278    ) -> ZResult<zenoh::internal::plugins::RunningPlugin> {
279        // Try to initiate login.
280        // Required in case of dynamic lib, otherwise no logs.
281        // But cannot be done twice in case of static link.
282        zenoh::init_log_from_env_or("error");
283        tracing::debug!("REST plugin {}", LONG_VERSION.as_str());
284
285        let runtime_conf = runtime.config().lock();
286        let plugin_conf = runtime_conf
287            .plugin(name)
288            .ok_or_else(|| zerror!("Plugin `{}`: missing config", name))?;
289
290        let conf: Config = serde_json::from_value(plugin_conf.clone())
291            .map_err(|e| zerror!("Plugin `{}` configuration error: {}", name, e))?;
292        WORKER_THREAD_NUM.store(conf.work_thread_num, Ordering::SeqCst);
293        MAX_BLOCK_THREAD_NUM.store(conf.max_block_thread_num, Ordering::SeqCst);
294
295        let task = run(runtime.clone(), conf.clone());
296        let task =
297            blockon_runtime(async { timeout(Duration::from_millis(1), spawn_runtime(task)).await });
298
299        // The spawn task (spawn_runtime(task)).await) should not return immediately. The server should block inside.
300        // If it returns immediately (for example, address already in use), we can get the error inside Ok
301        if let Ok(Ok(Err(e))) = task {
302            bail!("REST server failed within 1ms: {e}")
303        }
304
305        Ok(Box::new(RunningPlugin(conf)))
306    }
307}
308
309struct RunningPlugin(Config);
310
311impl PluginControl for RunningPlugin {}
312
313impl RunningPluginTrait for RunningPlugin {
314    fn adminspace_getter<'a>(
315        &'a self,
316        key_expr: &'a KeyExpr<'a>,
317        plugin_status_key: &str,
318    ) -> ZResult<Vec<zenoh::internal::plugins::Response>> {
319        let mut responses = Vec::new();
320        let mut key = String::from(plugin_status_key);
321        with_extended_string(&mut key, &["/version"], |key| {
322            if keyexpr::new(key.as_str()).unwrap().intersects(key_expr) {
323                responses.push(zenoh::internal::plugins::Response::new(
324                    key.clone(),
325                    GIT_VERSION.into(),
326                ))
327            }
328        });
329        with_extended_string(&mut key, &["/port"], |port_key| {
330            if keyexpr::new(port_key.as_str())
331                .unwrap()
332                .intersects(key_expr)
333            {
334                responses.push(zenoh::internal::plugins::Response::new(
335                    port_key.clone(),
336                    (&self.0).into(),
337                ))
338            }
339        });
340        Ok(responses)
341    }
342}
343
/// Temporarily appends every suffix in `suffixes` to `prefix` and calls `closure` with the
/// extended string. Afterwards `prefix` is truncated back to its original length, and the
/// closure's result is returned.
///
/// Note: `prefix` is not restored if `closure` panics.
fn with_extended_string<R, F: FnMut(&mut String) -> R>(
    prefix: &mut String,
    suffixes: &[&str],
    mut closure: F,
) -> R {
    let original_len = prefix.len();
    prefix.extend(suffixes.iter().copied());
    let output = closure(prefix);
    prefix.truncate(original_len);
    output
}
357
358async fn query(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
359    tracing::trace!("Incoming GET request: {:?}", req);
360
361    let first_accept = match req.header("accept") {
362        Some(accept) => accept[0]
363            .to_string()
364            .split(';')
365            .next()
366            .unwrap()
367            .split(',')
368            .next()
369            .unwrap()
370            .to_string(),
371        None => "application/json".to_string(),
372    };
373    if first_accept == "text/event-stream" {
374        Ok(tide::sse::upgrade(
375            req,
376            move |req: Request<(Arc<Session>, String)>, sender: Sender| async move {
377                let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
378                    Ok(ke) => ke.into_owned(),
379                    Err(e) => {
380                        return Err(tide::Error::new(
381                            tide::StatusCode::BadRequest,
382                            anyhow::anyhow!("{}", e),
383                        ))
384                    }
385                };
386                spawn_runtime(async move {
387                    tracing::debug!("Subscribe to {} for SSE stream", key_expr);
388                    let sender = &sender;
389                    let sub = req.state().0.declare_subscriber(&key_expr).await.unwrap();
390                    loop {
391                        let sample = sub.recv_async().await.unwrap();
392                        let json_sample =
393                            serde_json::to_string(&sample_to_json(&sample)).unwrap_or("{}".into());
394
395                        match timeout(
396                            std::time::Duration::new(10, 0),
397                            sender.send(&sample.kind().to_string(), json_sample, None),
398                        )
399                        .await
400                        {
401                            Ok(Ok(_)) => {}
402                            Ok(Err(e)) => {
403                                tracing::debug!("SSE error ({})! Unsubscribe and terminate", e);
404                                if let Err(e) = sub.undeclare().await {
405                                    tracing::error!("Error undeclaring subscriber: {}", e);
406                                }
407                                break;
408                            }
409                            Err(_) => {
410                                tracing::debug!("SSE timeout! Unsubscribe and terminate",);
411                                if let Err(e) = sub.undeclare().await {
412                                    tracing::error!("Error undeclaring subscriber: {}", e);
413                                }
414                                break;
415                            }
416                        }
417                    }
418                });
419                Ok(())
420            },
421        ))
422    } else {
423        let body = req.body_bytes().await.unwrap_or_default();
424        let url = req.url();
425        let key_expr = match path_to_key_expr(url.path(), &req.state().1) {
426            Ok(ke) => ke,
427            Err(e) => {
428                return Ok(response(
429                    StatusCode::BadRequest,
430                    "text/plain",
431                    &e.to_string(),
432                ))
433            }
434        };
435        let query_part = url.query();
436        let parameters = Parameters::from(query_part.unwrap_or_default());
437        let consolidation = if parameters.time_range().is_some() {
438            QueryConsolidation::from(zenoh::query::ConsolidationMode::None)
439        } else {
440            QueryConsolidation::from(zenoh::query::ConsolidationMode::Latest)
441        };
442        let raw = parameters.contains_key(RAW_KEY);
443        let mut query = req
444            .state()
445            .0
446            .get(Selector::borrowed(&key_expr, &parameters))
447            .consolidation(consolidation)
448            .with(flume::unbounded());
449        if !body.is_empty() {
450            let encoding: Encoding = req
451                .content_type()
452                .map(|m| Encoding::from(m.to_string()))
453                .unwrap_or_default();
454            query = query.payload(body).encoding(encoding);
455        }
456        match query.await {
457            Ok(receiver) => {
458                if raw {
459                    Ok(to_raw_response(receiver).await)
460                } else if first_accept == "text/html" {
461                    Ok(to_html_response(receiver).await)
462                } else {
463                    Ok(to_json_response(receiver).await)
464                }
465            }
466            Err(e) => Ok(response(
467                StatusCode::InternalServerError,
468                "text/plain",
469                &e.to_string(),
470            )),
471        }
472    }
473}
474
475async fn write(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
476    tracing::trace!("Incoming PUT request: {:?}", req);
477    match req.body_bytes().await {
478        Ok(bytes) => {
479            let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
480                Ok(ke) => ke,
481                Err(e) => {
482                    return Ok(response(
483                        StatusCode::BadRequest,
484                        "text/plain",
485                        &e.to_string(),
486                    ))
487                }
488            };
489
490            let encoding: Encoding = req
491                .content_type()
492                .map(|m| Encoding::from(m.to_string()))
493                .unwrap_or_default();
494
495            // @TODO: Define the right congestion control value
496            let session = &req.state().0;
497            let res = match method_to_kind(req.method()) {
498                SampleKind::Put => session.put(&key_expr, bytes).encoding(encoding).await,
499                SampleKind::Delete => session.delete(&key_expr).await,
500            };
501            match res {
502                Ok(_) => Ok(Response::new(StatusCode::Ok)),
503                Err(e) => Ok(response(
504                    StatusCode::InternalServerError,
505                    "text/plain",
506                    &e.to_string(),
507                )),
508            }
509        }
510        Err(e) => Ok(response(
511            StatusCode::NoContent,
512            "text/plain",
513            &e.to_string(),
514        )),
515    }
516}
517
518pub async fn run(runtime: Runtime, conf: Config) -> ZResult<()> {
519    // Try to initiate login.
520    // Required in case of dynamic lib, otherwise no logs.
521    // But cannot be done twice in case of static link.
522    zenoh::init_log_from_env_or("error");
523
524    let zid = runtime.zid().to_string();
525    let session = zenoh::session::init(runtime).await.unwrap();
526
527    let mut app = Server::with_state((Arc::new(session), zid));
528    app.with(
529        tide::security::CorsMiddleware::new()
530            .allow_methods(
531                "GET, POST, PUT, PATCH, DELETE"
532                    .parse::<http_types::headers::HeaderValue>()
533                    .unwrap(),
534            )
535            .allow_origin(tide::security::Origin::from("*"))
536            .allow_credentials(false),
537    );
538
539    app.at("/")
540        .get(query)
541        .post(query)
542        .put(write)
543        .patch(write)
544        .delete(write);
545    app.at("*")
546        .get(query)
547        .post(query)
548        .put(write)
549        .patch(write)
550        .delete(write);
551
552    if let Err(e) = app.listen(conf.http_port).await {
553        tracing::error!("Unable to start http server for REST: {:?}", e);
554        return Err(e.into());
555    }
556    Ok(())
557}
558
559fn path_to_key_expr<'a>(path: &'a str, zid: &str) -> ZResult<KeyExpr<'a>> {
560    let path = path.strip_prefix('/').unwrap_or(path);
561    if path == "@/local" {
562        KeyExpr::try_from(format!("@/{zid}"))
563    } else if let Some(suffix) = path.strip_prefix("@/local/") {
564        KeyExpr::try_from(format!("@/{zid}/{suffix}"))
565    } else {
566        KeyExpr::try_from(path)
567    }
568}