1pub mod bbox;
2#[cfg(feature = "channel")]
3pub mod channel;
4pub mod shutdown;
5pub mod stream;
6
7#[cfg(feature = "channel")]
8use crate::channel::{
9 utils::{generate_jwt, random_string},
10 websocket::{axum_on_connected, launch_channel_redis_listen_task, State},
11 ChannelControl,
12};
13#[cfg(feature = "channel")]
14use axum::{
15 extract::{Query, State as AxumState, WebSocketUpgrade},
16 response::IntoResponse,
17 routing::get,
18 Json, Router,
19};
20#[cfg(feature = "pyo3")]
21use pyo3::{
22 exceptions::{PyConnectionError, PyOSError},
23 prelude::*,
24};
25
26#[cfg(feature = "stubgen")]
27use pyo3_stub_gen::derive::*;
28
29#[cfg(feature = "channel")]
30use serde::{Deserialize, Serialize};
31#[cfg(feature = "channel")]
32use std::sync::Arc;
33#[cfg(feature = "channel")]
34use thiserror::Error;
35#[cfg(feature = "channel")]
36use tokio::sync::Mutex;
37#[cfg(feature = "channel")]
38use tower_http::{
39 cors::{Any, CorsLayer},
40 trace::TraceLayer,
41};
42
43#[cfg(feature = "pyo3")]
44use tracing_subscriber::{fmt, prelude::*, EnvFilter};
45
/// Install a global `tracing` subscriber that writes formatted events to stderr.
///
/// `filter_str` uses `EnvFilter` directive syntax (e.g. `"info,my_crate=debug"`);
/// `EnvFilter::new` ignores directives it cannot parse rather than failing.
///
/// Raises `OSError` if the subscriber cannot be installed, e.g. because a global
/// subscriber has already been set.
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "stubgen", gen_stub_pyfunction)]
#[pyfunction]
fn init_tracing_stderr(filter_str: String) -> PyResult<()> {
    let filter = EnvFilter::new(filter_str);
    let stderr_layer = fmt::layer().with_writer(std::io::stderr);

    let install_result = tracing_subscriber::registry()
        .with(filter)
        .with(stderr_layer)
        .try_init();

    install_result.map_err(|err| PyOSError::new_err(err.to_string()))
}
56
/// Configuration for the channel (websocket) service started by `run_server`.
///
/// With the `pyo3` feature enabled it is exposed to Python as a class whose fields are
/// readable and writable attributes. `Debug` is implemented by hand (below) so that
/// `jwt_secret` is masked instead of being printed into logs.
#[cfg(feature = "channel")]
#[cfg_attr(feature = "stubgen", gen_stub_pyclass)]
#[cfg_attr(feature = "pyo3", pyclass(get_all, set_all))]
#[derive(Clone)]
pub struct ChannelConfig {
    /// Host name or IP address the HTTP listener binds to.
    pub host: String,
    /// TCP port the HTTP listener binds to.
    pub port: u16,
    /// Connection URL handed to `redis::Client::open`.
    pub redis_url: String,
    /// Secret passed to `generate_jwt` (presumably the token signing key).
    pub jwt_secret: String,
    /// Token lifetime passed to `generate_jwt`; seconds per the name — TODO confirm.
    pub jwt_expiration_secs: i64,
    /// Copied into the server `State` as `id_length`; presumably the length of
    /// generated ids — verify against its users.
    pub id_length: u8,
}

#[cfg(feature = "channel")]
impl std::fmt::Debug for ChannelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output, but the JWT secret is never rendered.
        // NOTE(review): `redis_url` is printed as-is; mask it too if URLs carry passwords.
        f.debug_struct("ChannelConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("redis_url", &self.redis_url)
            .field("jwt_secret", &format_args!("<redacted>"))
            .field("jwt_expiration_secs", &self.jwt_expiration_secs)
            .field("id_length", &self.id_length)
            .finish()
    }
}
69
#[cfg(feature = "channel")]
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "stubgen", gen_stub_pymethods)]
#[pymethods]
impl ChannelConfig {
    /// Python constructor; every argument is stored unchanged in the field of the same name.
    #[new]
    fn new(
        host: String,
        port: u16,
        redis_url: String,
        jwt_secret: String,
        jwt_expiration_secs: i64,
        id_length: u8,
    ) -> Self {
        ChannelConfig {
            id_length,
            jwt_expiration_secs,
            jwt_secret,
            redis_url,
            port,
            host,
        }
    }
}
94
/// Start the channel server and return a Python awaitable for it.
///
/// The server runs on pyo3-async-runtimes' tokio runtime and the awaitable completes
/// when `run_server` returns (graceful shutdown on Ctrl+C). Redis client errors are
/// raised as `ConnectionError`, bind/serve I/O errors as `OSError`.
#[cfg(feature = "channel")]
#[cfg(feature = "pyo3")]
#[cfg_attr(feature = "stubgen", gen_stub_pyfunction)]
#[pyfunction]
fn run(py: Python<'_>, config: ChannelConfig) -> PyResult<Bound<'_, PyAny>> {
    let server = async move {
        // `?` converts `ChannelError` into `PyErr` via the `From` impl in this file.
        run_server(config).await?;
        Ok::<(), PyErr>(())
    };
    pyo3_async_runtimes::tokio::future_into_py(py, server)
}
104
/// Register the channel `name` with the shared `ChannelControl` and launch its Redis
/// listener task.
///
/// The mutex guard below is a temporary of its own statement: the lock is held while
/// `channel_add` runs and is released before `launch_channel_redis_listen_task`, which is
/// handed `&state.ctl` (and may lock it itself).
#[cfg(feature = "channel")]
async fn setup_persistent_channel(name: &str, state: &Arc<State>, redis_client: &redis::Client) {
    state.ctl.lock().await.channel_add(name.into(), None).await;

    let listener_state = Arc::clone(state);
    let listener_client = redis_client.clone();
    launch_channel_redis_listen_task(listener_state, &state.ctl, name.to_owned(), listener_client)
        .await;
}
119
/// Resolve once the process receives Ctrl+C; used as the graceful-shutdown trigger for
/// the axum server.
///
/// # Panics
///
/// Panics if the Ctrl+C handler cannot be installed.
#[cfg(feature = "channel")]
async fn shutdown_signal() {
    let signal_outcome = tokio::signal::ctrl_c().await;
    signal_outcome.expect("failed to install CTRL+C signal handler");
}
126
/// Start the channel service and serve HTTP/websocket traffic until Ctrl+C is received.
///
/// Builds the shared [`State`] from `config` and registers the built-in channels:
/// `"phoenix"` (no Redis listener), plus `"system"` and `"admin"` (each with a Redis
/// listener task). It then binds `config.host:config.port` and serves `POST /token` and
/// `GET /websocket`.
///
/// # Errors
///
/// * [`ChannelError::Redis`] if `redis::Client::open` rejects `config.redis_url`.
/// * [`ChannelError::ServerBind`] if binding the listener fails or the server stops with
///   an I/O error.
#[cfg(feature = "channel")]
pub async fn run_server(config: ChannelConfig) -> Result<(), ChannelError> {
    // `#[from]` on `ChannelError::Redis` lets `?` do the conversion.
    let redis_client = redis::Client::open(config.redis_url.as_str())?;

    let channel_control = ChannelControl::new(Arc::new(redis_client.clone()));
    let state = Arc::new(State {
        ctl: Mutex::new(channel_control),
        redis_client: redis_client.clone(),
        id_length: config.id_length,
        jwt_secret: config.jwt_secret,
        jwt_expiration_secs: config.jwt_expiration_secs,
    });

    // Registered without a Redis listener task, unlike the persistent channels below.
    state
        .ctl
        .lock()
        .await
        .channel_add("phoenix".into(), None)
        .await;
    setup_persistent_channel("system", &state, &redis_client).await;
    setup_persistent_channel("admin", &state, &redis_client).await;

    let app = build_router(state);

    // Bind with a (host, port) pair instead of a formatted "host:port" string: the string
    // form cannot express a bare IPv6 literal such as "::1". Surrounding brackets are
    // stripped so a host given as "[::1]" keeps working too.
    let host = config
        .host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(config.host.as_str());
    let listener = tokio::net::TcpListener::bind((host, config.port)).await?;

    // Report the address actually bound (resolved host, real port when 0 was requested).
    let local_addr = listener.local_addr()?;
    tracing::info!("channel service listening on {}", local_addr);

    axum::serve(listener, app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    Ok(())
}

/// Build the service router: `POST /token` issues JWTs, `GET /websocket` upgrades to a
/// websocket, with request tracing and a fully permissive CORS policy.
#[cfg(feature = "channel")]
fn build_router(state: Arc<State>) -> Router {
    // Any origin, method and header is accepted.
    let cors = CorsLayer::new()
        .allow_methods(Any)
        .allow_headers(Any)
        .allow_origin(Any);

    Router::new()
        .route("/token", axum::routing::post(generate_token_handler))
        .route("/websocket", get(websocket_handler))
        .with_state(state)
        .layer(TraceLayer::new_for_http())
        .layer(cors)
}
173
/// JSON request body for `POST /token`.
#[cfg(feature = "channel")]
#[derive(Debug, Clone, Deserialize)]
struct TokenRequest {
    /// Name of the channel the token is requested for.
    channel: String,
    /// Client id to embed in the token; when absent the handler generates a random one.
    id: Option<String>,
}
180
/// JSON response body for a successful `POST /token`.
#[cfg(feature = "channel")]
#[derive(Debug, Clone, Serialize)]
struct TokenResponse {
    /// Channel the token was issued for (echoed from the request).
    channel: String,
    /// Client id embedded in the token (supplied by the caller or generated).
    id: String,
    /// The token string returned by `generate_jwt`.
    token: String,
}
188
/// `POST /token`: issue a JWT that binds a client id to a channel.
///
/// The body is a [`TokenRequest`]. When `id` is missing or empty, a random id of
/// `state.id_length` characters is generated. Responses:
/// * `200` with a JSON [`TokenResponse`] on success;
/// * `400` if `channel` is empty or whitespace;
/// * `500` with the error text if `generate_jwt` fails.
///
/// NOTE(review): nothing here authenticates the caller. Any client can mint a token for
/// any channel (including the "system" and "admin" channels created at startup) under an
/// id of its choosing. Confirm that this endpoint is protected upstream.
#[cfg(feature = "channel")]
async fn generate_token_handler(
    AxumState(state): AxumState<Arc<State>>,
    Json(payload): Json<TokenRequest>,
) -> impl IntoResponse {
    if payload.channel.trim().is_empty() {
        return (
            axum::http::StatusCode::BAD_REQUEST,
            "channel must not be empty",
        )
            .into_response();
    }

    // Use the configured id length (previously a hard-coded 8), and treat an explicit
    // empty id the same as a missing one.
    let client_id = payload
        .id
        .filter(|id| !id.is_empty())
        .unwrap_or_else(|| random_string(state.id_length.into()));

    match generate_jwt(
        client_id.clone(),
        payload.channel.clone(),
        state.jwt_secret.clone(),
        state.jwt_expiration_secs,
    )
    .await
    {
        Ok(token) => Json(TokenResponse {
            channel: payload.channel,
            id: client_id,
            token,
        })
        .into_response(),
        Err(e) => {
            // Keep a server-side record; the client only sees the 500 body.
            tracing::error!(
                "failed to generate token for channel {}: {e}",
                payload.channel
            );
            (
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                format!("failed to generate token: {e}"),
            )
                .into_response()
        }
    }
}
216
/// Query-string parameters accepted by `GET /websocket`.
#[cfg(feature = "channel")]
#[derive(Debug, Clone, Deserialize)]
struct WebSocketParams {
    /// Token from the `userToken` query parameter; optional, and forwarded unchecked to
    /// the connection handler.
    #[serde(rename = "userToken")]
    user_token: Option<String>,
}
223
/// `GET /websocket`: upgrade the connection and hand the socket, the shared state and the
/// optional `userToken` query parameter to `axum_on_connected`.
///
/// The token is not inspected here; any validation is up to `axum_on_connected`.
#[cfg(feature = "channel")]
async fn websocket_handler(
    ws: WebSocketUpgrade,
    AxumState(state): AxumState<Arc<State>>,
    Query(WebSocketParams { user_token }): Query<WebSocketParams>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| axum_on_connected(socket, state, user_token))
}
233
/// Errors that can abort [`run_server`].
#[cfg(feature = "channel")]
#[derive(Error, Debug)]
pub enum ChannelError {
    /// A Redis error; in this file it arises when `redis::Client::open` rejects the
    /// configured URL.
    #[error("redis error: {0}")]
    Redis(#[from] redis::RedisError),
    /// Any `std::io::Error` propagated with `?`. Despite the name (and the
    /// "server bind error" message), this also covers failures of the running server,
    /// not only of binding the listener.
    #[error("server bind error: {0}")]
    ServerBind(#[from] std::io::Error),
}
242
/// Map server errors onto Python exceptions.
///
/// Redis failures become `ConnectionError` and I/O failures become `OSError`. Only the
/// inner error's message is forwarded, without the `ChannelError` display prefix.
#[cfg(feature = "channel")]
#[cfg(feature = "pyo3")]
impl From<ChannelError> for PyErr {
    fn from(err: ChannelError) -> Self {
        use ChannelError::{Redis, ServerBind};

        match err {
            Redis(redis_err) => PyConnectionError::new_err(redis_err.to_string()),
            ServerBind(io_err) => PyOSError::new_err(io_err.to_string()),
        }
    }
}
253
/// Native extension module: always exposes `init_tracing_stderr`; with the `channel`
/// feature it also exposes `run` and `ChannelConfig`.
#[cfg(feature = "pyo3")]
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let init_tracing = wrap_pyfunction!(init_tracing_stderr, m)?;
    m.add_function(init_tracing)?;

    #[cfg(feature = "channel")]
    register_channel_api(m)?;

    Ok(())
}

/// Add the channel-server entry point and its config class to the module.
#[cfg(all(feature = "pyo3", feature = "channel"))]
fn register_channel_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let run_fn = wrap_pyfunction!(run, m)?;
    m.add_function(run_fn)?;
    m.add_class::<ChannelConfig>()
}
266
/// Build the `pyo3_stub_gen` metadata from the project's `pyproject.toml`, which is
/// expected one directory above this crate's manifest directory.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` has no parent directory.
#[cfg(feature = "stubgen")]
pub fn stub_info() -> pyo3_stub_gen::Result<pyo3_stub_gen::StubInfo> {
    let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
    let pyproject_path = manifest_dir
        .parent()
        .expect("CARGO_MANIFEST_DIR should have a parent directory holding pyproject.toml")
        .join("pyproject.toml");
    pyo3_stub_gen::StubInfo::from_pyproject_toml(pyproject_path)
}