1use std::sync::Arc;
4
5use axum::Router;
6use axum::body::Body;
7use axum::extract::{DefaultBodyLimit, State};
8use axum::http::{Method, Request, StatusCode};
9use axum::middleware::{self, Next};
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12
13use haystack_core::auth::{AuthHeader, parse_auth_header};
14use haystack_core::graph::SharedGraph;
15use haystack_core::ontology::DefNamespace;
16
17use crate::actions::ActionRegistry;
18use crate::auth::AuthManager;
19use crate::his_store::HisStore;
20use crate::ops;
21use crate::state::{AppState, SharedState};
22use crate::ws;
23use crate::ws::WatchManager;
24
25pub struct HaystackServer {
27 graph: SharedGraph,
28 namespace: DefNamespace,
29 auth_manager: AuthManager,
30 actions: ActionRegistry,
31 custom_router: Option<Router<SharedState>>,
32 authenticated_router: Option<Router<SharedState>>,
33 history_provider: Option<Box<dyn crate::his_provider::HistoryProvider>>,
34 port: u16,
35 host: String,
36}
37
38impl HaystackServer {
39 pub fn new(graph: SharedGraph) -> Self {
41 Self {
42 graph,
43 namespace: DefNamespace::new(),
44 auth_manager: AuthManager::empty(),
45 actions: ActionRegistry::new(),
46 custom_router: None,
47 authenticated_router: None,
48 history_provider: None,
49 port: 8080,
50 host: "127.0.0.1".to_string(),
51 }
52 }
53
54 pub fn with_namespace(mut self, ns: DefNamespace) -> Self {
56 self.namespace = ns;
57 self
58 }
59
60 pub fn with_auth(mut self, auth: AuthManager) -> Self {
62 self.auth_manager = auth;
63 self
64 }
65
66 pub fn port(mut self, port: u16) -> Self {
68 self.port = port;
69 self
70 }
71
72 pub fn host(mut self, host: &str) -> Self {
74 self.host = host.to_string();
75 self
76 }
77
78 pub fn with_actions(mut self, actions: ActionRegistry) -> Self {
80 self.actions = actions;
81 self
82 }
83
84 pub fn with_router(mut self, router: Router<SharedState>) -> Self {
93 self.custom_router = Some(router);
94 self
95 }
96
97 pub fn with_authenticated_router(mut self, router: Router<SharedState>) -> Self {
102 self.authenticated_router = Some(router);
103 self
104 }
105
106 pub fn with_history_provider(
108 mut self,
109 provider: Box<dyn crate::his_provider::HistoryProvider>,
110 ) -> Self {
111 self.history_provider = Some(provider);
112 self
113 }
114
115 pub async fn run(self) -> std::io::Result<()> {
117 let his: Box<dyn crate::his_provider::HistoryProvider> = self
118 .history_provider
119 .unwrap_or_else(|| Box::new(HisStore::new()));
120
121 let state: SharedState = Arc::new(AppState {
122 graph: self.graph,
123 namespace: parking_lot::RwLock::new(self.namespace),
124 auth: self.auth_manager,
125 watches: WatchManager::new(),
126 actions: self.actions,
127 his,
128 started_at: std::time::Instant::now(),
129 });
130
131 let mut core_router = Router::new()
132 .route("/api/about", get(ops::about::handle))
134 .route("/api/ops", get(ops::ops_handler::handle))
135 .route("/api/formats", get(ops::formats::handle))
136 .route("/api/ws", get(ws::ws_handler))
137 .route("/api/read", post(ops::read::handle))
139 .route("/api/nav", post(ops::nav::handle))
140 .route("/api/defs", post(ops::defs::handle))
141 .route("/api/libs", post(ops::defs::handle_libs))
142 .route("/api/hisRead", post(ops::his::handle_read))
143 .route("/api/hisWrite", post(ops::his::handle_write))
144 .route("/api/watchSub", post(ops::watch::handle_sub))
145 .route("/api/watchPoll", post(ops::watch::handle_poll))
146 .route("/api/watchUnsub", post(ops::watch::handle_unsub))
147 .route("/api/pointWrite", post(ops::point_write::handle))
148 .route("/api/invokeAction", post(ops::invoke::handle))
149 .route("/api/close", post(ops::about::handle_close))
150 .route("/api/import", post(ops::data::handle_import))
151 .route("/api/export", post(ops::data::handle_export))
152 .route("/api/validate", post(ops::libs::handle_validate))
153 .route("/api/specs", post(ops::libs::handle_specs))
154 .route("/api/spec", post(ops::libs::handle_spec))
155 .route("/api/loadLib", post(ops::libs::handle_load_lib))
156 .route("/api/unloadLib", post(ops::libs::handle_unload_lib))
157 .route("/api/exportLib", post(ops::libs::handle_export_lib))
158 .route("/api/changes", post(ops::changes::handle));
159
160 if let Some(auth_router) = self.authenticated_router {
163 core_router = core_router.merge(auth_router);
164 }
165
166 let mut app = core_router
167 .route_layer(middleware::from_fn_with_state(
168 state.clone(),
169 auth_middleware,
170 ))
171 .layer(DefaultBodyLimit::max(2 * 1024 * 1024))
172 .with_state(state.clone());
173
174 if let Some(custom) = self.custom_router {
175 app = app.merge(custom.with_state(state));
176 }
177
178 log::info!("Starting haystack-server on {}:{}", self.host, self.port);
179
180 let listener =
181 tokio::net::TcpListener::bind(format!("{}:{}", self.host, self.port)).await?;
182 axum::serve(listener, app).await
183 }
184}
185
/// Map an API request path to the permission the caller must hold.
///
/// Paths that change server state (point writes, history writes, action
/// invocation, library load/unload, and data import) require `"write"`.
/// Every other path, including paths not known to this function, requires
/// `"read"`. Matching is exact and case-sensitive. The result is always
/// `Some` at present.
fn required_permission(path: &str) -> Option<&'static str> {
    const WRITE_PATHS: [&str; 6] = [
        "/api/pointWrite",
        "/api/hisWrite",
        "/api/invokeAction",
        "/api/loadLib",
        "/api/unloadLib",
        "/api/import",
    ];

    let needs_write = WRITE_PATHS.iter().any(|&candidate| candidate == path);
    Some(if needs_write { "write" } else { "read" })
}
204
205async fn auth_middleware(
212 State(state): State<SharedState>,
213 mut req: Request<Body>,
214 next: Next,
215) -> Response {
216 let path = req.uri().path().to_string();
217 let method = req.method().clone();
218
219 if path == "/api/about" {
221 return next.run(req).await;
222 }
223
224 if (path == "/api/ops" || path == "/api/formats") && method == Method::GET {
226 return next.run(req).await;
227 }
228
229 if !state.auth.is_enabled() {
231 return next.run(req).await;
232 }
233
234 let auth_header = req
236 .headers()
237 .get("Authorization")
238 .and_then(|v| v.to_str().ok())
239 .map(|s| s.to_string());
240
241 match auth_header {
242 Some(header) => match parse_auth_header(&header) {
243 Ok(AuthHeader::Bearer { auth_token }) => {
244 match state.auth.validate_token(&auth_token) {
245 Some(auth_user) => {
246 if let Some(required) = required_permission(&path)
248 && !AuthManager::check_permission(&auth_user, required)
249 {
250 return crate::error::HaystackError::forbidden(format!(
251 "insufficient '{}' permission",
252 required
253 ))
254 .into_response();
255 }
256
257 req.extensions_mut().insert(auth_user);
259 next.run(req).await
260 }
261 None => crate::error::HaystackError::new(
262 "invalid or expired auth token",
263 StatusCode::UNAUTHORIZED,
264 )
265 .into_response(),
266 }
267 }
268 _ => {
269 crate::error::HaystackError::new("BEARER token required", StatusCode::UNAUTHORIZED)
270 .into_response()
271 }
272 },
273 None => crate::error::HaystackError::new(
274 "Authorization header required",
275 StatusCode::UNAUTHORIZED,
276 )
277 .into_response(),
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Every core op that only reads state, plus the unauthenticated
    /// endpoints, maps to the `read` permission.
    #[test]
    fn required_permission_read_ops() {
        let read_paths = [
            "/api/read",
            "/api/nav",
            "/api/defs",
            "/api/libs",
            "/api/hisRead",
            "/api/watchSub",
            "/api/watchPoll",
            "/api/watchUnsub",
            "/api/close",
            "/api/about",
            "/api/ops",
            "/api/formats",
            "/api/export",
            "/api/validate",
            "/api/specs",
            "/api/spec",
            "/api/exportLib",
            "/api/changes",
        ];
        for path in read_paths {
            assert_eq!(required_permission(path), Some("read"), "path: {path}");
        }
    }

    /// Every state-changing op, including library load/unload, requires `write`.
    #[test]
    fn required_permission_write_ops() {
        let write_paths = [
            "/api/pointWrite",
            "/api/hisWrite",
            "/api/invokeAction",
            "/api/loadLib",
            "/api/unloadLib",
            "/api/import",
        ];
        for path in write_paths {
            assert_eq!(required_permission(path), Some("write"), "path: {path}");
        }
    }

    /// Unknown paths (e.g. routes from an authenticated extension router)
    /// and near-miss spellings fall back to `read`, because matching is exact.
    #[test]
    fn required_permission_defaults_to_read() {
        for path in ["/custom/route", "", "/api/pointwrite", "/api/import/"] {
            assert_eq!(required_permission(path), Some("read"), "path: {path:?}");
        }
    }
}