Skip to main content

tangram_core/
lib.rs

1pub mod bbox;
2#[cfg(feature = "channel")]
3pub mod channel;
4pub mod shutdown;
5pub mod stream;
6
7#[cfg(feature = "channel")]
8use crate::channel::{
9    utils::{generate_jwt, random_string},
10    websocket::{axum_on_connected, launch_channel_redis_listen_task, State},
11    ChannelControl,
12};
13#[cfg(feature = "channel")]
14use axum::{
15    extract::{Query, State as AxumState, WebSocketUpgrade},
16    response::IntoResponse,
17    routing::get,
18    Json, Router,
19};
20#[cfg(feature = "pyo3")]
21use pyo3::{
22    exceptions::{PyConnectionError, PyOSError},
23    prelude::*,
24};
25
26#[cfg(feature = "stubgen")]
27use pyo3_stub_gen::derive::*;
28
29#[cfg(feature = "channel")]
30use serde::{Deserialize, Serialize};
31#[cfg(feature = "channel")]
32use std::sync::Arc;
33#[cfg(feature = "channel")]
34use thiserror::Error;
35#[cfg(feature = "channel")]
36use tokio::sync::Mutex;
37#[cfg(feature = "channel")]
38use tower_http::{
39    cors::{Any, CorsLayer},
40    trace::TraceLayer,
41};
42
43#[cfg(feature = "pyo3")]
44use tracing_subscriber::{fmt, prelude::*, EnvFilter};
45
/// Installs a process-wide `tracing` subscriber that writes formatted events to stderr.
///
/// `filter_str` is handed to `EnvFilter::new`, so it uses `EnvFilter` directive syntax.
/// Any failure from `try_init` (for example a global subscriber already being set)
/// is raised to Python as `OSError`.
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "stubgen", gen_stub_pyfunction)]
#[pyfunction]
fn init_tracing_stderr(filter_str: String) -> PyResult<()> {
    let filter = EnvFilter::new(filter_str);
    let stderr_layer = fmt::layer().with_writer(std::io::stderr);

    let outcome = tracing_subscriber::registry()
        .with(filter)
        .with(stderr_layer)
        .try_init();

    outcome.map_err(|init_err| PyOSError::new_err(init_err.to_string()))
}
56
/// Settings for the channel service.
///
/// With the `pyo3` feature enabled this is also exposed to Python as a class whose
/// fields are readable and writable attributes.
#[cfg(feature = "channel")]
#[cfg_attr(feature = "stubgen", gen_stub_pyclass)]
#[cfg_attr(feature = "pyo3", pyclass(get_all, set_all))]
#[derive(Clone, Debug)]
pub struct ChannelConfig {
    /// Host part of the address the HTTP/WebSocket server binds to.
    pub host: String,
    /// TCP port the server binds to.
    pub port: u16,
    /// Connection URL passed to `redis::Client::open`.
    pub redis_url: String,
    /// Secret passed to `generate_jwt` when issuing tokens.
    pub jwt_secret: String,
    /// Token expiration passed to `generate_jwt` (in seconds, per the name).
    pub jwt_expiration_secs: i64,
    /// Copied into the server state as `id_length` (presumably the length of
    /// generated ids — confirm against the channel module).
    pub id_length: u8,
}
69
#[cfg(all(feature = "channel", feature = "pyo3"))]
#[cfg_attr(feature = "stubgen", gen_stub_pymethods)]
#[pymethods]
impl ChannelConfig {
    /// Python constructor: every field is a required argument, in declaration order.
    #[new]
    fn new(
        host: String,
        port: u16,
        redis_url: String,
        jwt_secret: String,
        jwt_expiration_secs: i64,
        id_length: u8,
    ) -> Self {
        ChannelConfig {
            host,
            port,
            redis_url,
            jwt_secret,
            jwt_expiration_secs,
            id_length,
        }
    }
}
94
/// Starts the channel server from Python and returns an awaitable for it.
///
/// The server future is driven by the `pyo3_async_runtimes` tokio runtime. A
/// [`ChannelError`] is converted into the matching Python exception via its
/// `From<ChannelError> for PyErr` impl.
#[cfg(all(feature = "channel", feature = "pyo3"))]
#[cfg_attr(feature = "stubgen", gen_stub_pyfunction)]
#[pyfunction]
fn run(py: Python<'_>, config: ChannelConfig) -> PyResult<Bound<'_, PyAny>> {
    let server = async move { run_server(config).await.map_err(PyErr::from) };
    pyo3_async_runtimes::tokio::future_into_py(py, server)
}
104
/// Registers `name` as a persistent channel and starts its Redis listener task.
///
/// Unlike dynamic channels, persistent channels must listen for Redis events from
/// startup. They relay backend-initiated messages (e.g. system time) regardless of
/// whether any client is connected.
#[cfg(feature = "channel")]
async fn setup_persistent_channel(name: &str, state: &Arc<State>, redis_client: &redis::Client) {
    // Register the channel with the controller first; the lock guard is a
    // temporary and is released at the end of this statement.
    state.ctl.lock().await.channel_add(name.into(), None).await;

    let listener_state = Arc::clone(state);
    let channel_name = String::from(name);
    let listener_client = redis_client.clone();
    launch_channel_redis_listen_task(listener_state, &state.ctl, channel_name, listener_client)
        .await;
}
119
/// Resolves when the process receives CTRL+C.
///
/// Used by `run_server` as the graceful-shutdown trigger for the HTTP server.
///
/// # Panics
///
/// Panics if the CTRL+C handler cannot be installed.
#[cfg(feature = "channel")]
async fn shutdown_signal() {
    let ctrl_c = tokio::signal::ctrl_c();
    ctrl_c
        .await
        .expect("failed to install CTRL+C signal handler");
}
126
/// Runs the channel service until CTRL+C is received.
///
/// Startup sequence:
/// 1. build a Redis client from `config.redis_url` and the shared [`State`];
/// 2. register the `phoenix` channel (no Redis listener) and the persistent
///    `system` and `admin` channels (each with a Redis listener task);
/// 3. serve `POST /token` and `GET /websocket` on `config.host:config.port`
///    with permissive CORS and HTTP request tracing.
///
/// # Errors
///
/// Returns [`ChannelError::Redis`] if `redis::Client::open` rejects
/// `config.redis_url`. Returns [`ChannelError::ServerBind`] if binding the TCP
/// listener or serving connections fails with an I/O error.
#[cfg(feature = "channel")]
pub async fn run_server(config: ChannelConfig) -> Result<(), ChannelError> {
    // The URL is not needed afterwards, so move it rather than cloning; `?` uses
    // the `#[from] redis::RedisError` conversion on `ChannelError`.
    let redis_client = redis::Client::open(config.redis_url)?;

    let channel_control = ChannelControl::new(Arc::new(redis_client.clone()));
    let state = Arc::new(State {
        ctl: Mutex::new(channel_control),
        redis_client: redis_client.clone(),
        id_length: config.id_length,
        jwt_secret: config.jwt_secret,
        jwt_expiration_secs: config.jwt_expiration_secs,
    });

    // `phoenix` is only registered; unlike the persistent channels below, no
    // Redis listener task is started for it here.
    state
        .ctl
        .lock()
        .await
        .channel_add("phoenix".into(), None)
        .await;
    setup_persistent_channel("system", &state, &redis_client).await;
    setup_persistent_channel("admin", &state, &redis_client).await;

    // TODO: allow this to be configurable
    let cors = CorsLayer::new()
        .allow_methods(Any)
        .allow_headers(Any)
        .allow_origin(Any);

    let app = Router::new()
        .route("/token", axum::routing::post(generate_token_handler))
        .route("/websocket", get(websocket_handler))
        .with_state(state)
        .layer(TraceLayer::new_for_http())
        .layer(cors);

    let addr = format!("{}:{}", config.host, config.port);
    let listener = tokio::net::TcpListener::bind(&addr).await?;

    // Log the address the socket is actually bound to. With port 0 the OS picks
    // the port, and the configured string would only show `:0`.
    let shown_addr = listener
        .local_addr()
        .map(|bound| bound.to_string())
        .unwrap_or(addr);
    tracing::info!("channel service listening on {}", shown_addr);

    axum::serve(listener, app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    Ok(())
}
173
/// JSON body accepted by `POST /token`.
#[cfg(feature = "channel")]
#[derive(Clone, Debug, Deserialize)]
struct TokenRequest {
    /// Channel name passed to `generate_jwt`.
    channel: String,
    /// Client id to embed in the token; when absent the handler generates one.
    id: Option<String>,
}
180
/// JSON body returned by `POST /token` on success.
#[cfg(feature = "channel")]
#[derive(Clone, Debug, Serialize)]
struct TokenResponse {
    /// Channel the token was issued for.
    channel: String,
    /// Client id embedded in the token (supplied by the caller or generated).
    id: String,
    /// JWT produced by `generate_jwt`.
    token: String,
}
188
/// `POST /token`: issues a JWT for the requested channel and client id.
///
/// The client id is `payload.id` when supplied. Otherwise it is generated with
/// `random_string` using the server's configured `id_length`.
///
/// On success the handler responds with the JSON `{ "channel", "id", "token" }`.
/// If `generate_jwt` fails, it responds with `500 Internal Server Error` and a
/// plain-text message.
#[cfg(feature = "channel")]
async fn generate_token_handler(
    AxumState(state): AxumState<Arc<State>>,
    Json(payload): Json<TokenRequest>,
) -> impl IntoResponse {
    let TokenRequest { channel, id } = payload;
    // Use the configured id length (not a hard-coded one) so ids minted here agree
    // with `ChannelConfig::id_length`. `.into()` widens the `u8` to whatever
    // integer type `random_string` takes.
    let client_id = id.unwrap_or_else(|| random_string(state.id_length.into()));
    match generate_jwt(
        client_id.clone(),
        channel.clone(),
        state.jwt_secret.clone(),
        state.jwt_expiration_secs,
    )
    .await
    {
        Ok(token) => Json(TokenResponse {
            channel,
            id: client_id,
            token,
        })
        .into_response(),
        Err(e) => (
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("failed to generate token: {e}"),
        )
            .into_response(),
    }
}
216
/// Query-string parameters accepted by `GET /websocket`.
#[cfg(feature = "channel")]
#[derive(Clone, Debug, Deserialize)]
struct WebSocketParams {
    /// Value of the `userToken` query parameter, if present.
    #[serde(rename = "userToken")]
    user_token: Option<String>,
}
223
/// `GET /websocket`: upgrades the connection to a WebSocket.
///
/// The upgraded socket, the shared state and the optional `userToken` query value
/// are handed to `axum_on_connected`.
#[cfg(feature = "channel")]
async fn websocket_handler(
    ws: WebSocketUpgrade,
    AxumState(state): AxumState<Arc<State>>,
    Query(WebSocketParams { user_token }): Query<WebSocketParams>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| axum_on_connected(socket, state, user_token))
}
233
/// Errors that can end [`run_server`].
#[cfg(feature = "channel")]
#[derive(Debug, Error)]
pub enum ChannelError {
    /// Error reported by the `redis` crate (e.g. from `redis::Client::open`).
    #[error("redis error: {0}")]
    Redis(#[from] redis::RedisError),
    /// I/O error while binding the TCP listener or serving connections.
    #[error("server bind error: {0}")]
    ServerBind(#[from] std::io::Error),
}
242
/// Maps server errors onto Python exceptions.
///
/// Redis failures become `ConnectionError` and I/O failures become `OSError`.
/// Only the inner error's message is forwarded; the `ChannelError` display prefix
/// is not included.
#[cfg(all(feature = "channel", feature = "pyo3"))]
impl From<ChannelError> for PyErr {
    fn from(err: ChannelError) -> Self {
        match err {
            ChannelError::Redis(redis_err) => {
                let msg = redis_err.to_string();
                PyConnectionError::new_err(msg)
            }
            ChannelError::ServerBind(io_err) => {
                let msg = io_err.to_string();
                PyOSError::new_err(msg)
            }
        }
    }
}
253
/// Python extension module `_core`.
///
/// Always exposes `init_tracing_stderr`. When built with the `channel` feature it
/// also exposes `run` and the `ChannelConfig` class.
#[cfg(feature = "pyo3")]
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let init_tracing = wrap_pyfunction!(init_tracing_stderr, m)?;
    m.add_function(init_tracing)?;

    // Channel-server bindings only exist when the `channel` feature is compiled in.
    #[cfg(feature = "channel")]
    {
        let run_fn = wrap_pyfunction!(run, m)?;
        m.add_function(run_fn)?;
        m.add_class::<ChannelConfig>()?;
    }

    Ok(())
}
266
/// Collects `pyo3_stub_gen` stub metadata for the Python package.
///
/// This does not use the `define_stub_info_gatherer!` macro because the package's
/// `pyproject.toml` lives one directory above this crate's manifest:
/// `packages/tangram_core` rather than `packages/tangram_core/rust`.
#[cfg(feature = "stubgen")]
pub fn stub_info() -> pyo3_stub_gen::Result<pyo3_stub_gen::StubInfo> {
    let manifest_dir = ::std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
    let package_dir = manifest_dir
        .parent()
        .expect("CARGO_MANIFEST_DIR is a crate directory path and always has a parent");
    pyo3_stub_gen::StubInfo::from_pyproject_toml(package_dir.join("pyproject.toml"))
}