1use super::{api, assets, collab, state_sync, voice, websocket, DashboardState, SharedState};
4use crate::in_memory_logger::InMemoryLogStore;
5use anyhow::Result;
6use axum::{
7 body::Body,
8 extract::ConnectInfo,
9 http::{header, Request, StatusCode},
10 middleware::Next,
11 response::{Html, IntoResponse, Response},
12 routing::{get, post},
13 Router,
14};
15use tower_http::cors::CorsLayer;
16use ipnet::IpNet;
17use std::net::{IpAddr, SocketAddr};
18use std::path::PathBuf;
19use std::sync::Arc;
20use tokio::sync::RwLock;
21
22#[derive(Clone)]
24struct AllowedNetworks {
25 networks: Vec<IpNet>,
26 allow_all: bool,
27}
28
29impl AllowedNetworks {
30 fn new(cidrs: Vec<String>) -> Self {
31 if cidrs.is_empty() {
32 return Self {
34 networks: vec!["127.0.0.0/8".parse().unwrap(), "::1/128".parse().unwrap()],
35 allow_all: false,
36 };
37 }
38
39 let mut networks = Vec::new();
40 let mut allow_all = false;
41
42 for cidr in cidrs {
43 if cidr == "0.0.0.0/0" || cidr == "::/0" || cidr == "any" {
44 allow_all = true;
45 break;
46 }
47 if let Ok(net) = cidr.parse::<IpNet>() {
48 networks.push(net);
49 } else {
50 eprintln!("Warning: Invalid CIDR '{}', ignoring", cidr);
51 }
52 }
53
54 Self {
55 networks,
56 allow_all,
57 }
58 }
59
60 fn is_allowed(&self, ip: IpAddr) -> bool {
61 if self.allow_all {
62 return true;
63 }
64 if ip.is_loopback() {
66 return true;
67 }
68 self.networks.iter().any(|net| net.contains(&ip))
69 }
70}
71
72async fn check_allowed_network(
74 ConnectInfo(addr): ConnectInfo<SocketAddr>,
75 req: Request<Body>,
76 next: Next,
77) -> Result<Response, StatusCode> {
78 let allowed = req
79 .extensions()
80 .get::<AllowedNetworks>()
81 .map(|nets| nets.is_allowed(addr.ip()))
82 .unwrap_or(true);
83
84 if allowed {
85 Ok(next.run(req).await)
86 } else {
87 eprintln!("Rejected connection from {}", addr.ip());
88 Err(StatusCode::FORBIDDEN)
89 }
90}
91
92pub async fn start_server(
94 port: u16,
95 open_browser: bool,
96 allow_networks: Vec<String>,
97 log_store: InMemoryLogStore,
98) -> Result<()> {
99 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
100 let state: SharedState = Arc::new(RwLock::new(DashboardState::new(cwd, log_store)));
101
102 let has_explicit_networks = !allow_networks.is_empty();
103 let allowed = AllowedNetworks::new(allow_networks.clone());
104 let bind_all = has_explicit_networks || allowed.allow_all;
105
106 let app = Router::new()
107 .route("/", get(serve_index))
109 .route("/style.css", get(serve_css))
110 .route("/app.js", get(serve_js))
111 .route("/xterm.min.js", get(serve_xterm_js))
112 .route("/xterm.css", get(serve_xterm_css))
113 .route("/xterm-addon-fit.min.js", get(serve_xterm_fit_js))
114 .route("/marked.min.js", get(serve_marked_js))
115 .route("/api/health", get(api::health))
117 .route("/api/files", get(api::list_files))
118 .route("/api/file", get(api::read_file))
119 .route("/api/file", post(api::write_file))
120 .route("/api/tree", get(api::get_tree))
121 .route("/api/markdown", get(api::render_markdown))
122 .route("/api/logs", get(api::get_logs))
123 .route(
125 "/api/config/layout",
126 get(api::get_layout_config).post(api::save_layout_config),
127 )
128 .route(
129 "/api/config/theme",
130 get(api::get_theme_config).post(api::save_theme_config),
131 )
132 .route("/api/prompt", get(api::get_active_prompts).post(api::ask_prompt))
134 .route("/api/prompt/:prompt_id/answer", post(api::answer_prompt))
135 .route("/ws/terminal", get(websocket::terminal_handler))
137 .route("/ws/state", get(state_sync::state_handler))
138 .route("/ws/collab", get(collab::collab_handler))
139 .route("/api/voice/transcribe", post(voice::transcribe))
141 .route("/api/voice/register", post(voice::register_speaker))
142 .route("/api/voice/speak", post(voice::speak))
143 .layer(axum::Extension(allowed.clone()))
144 .layer(CorsLayer::permissive())
145 .with_state(state);
146
147 let bind_addr: IpAddr = if bind_all {
149 [0, 0, 0, 0].into()
150 } else {
151 [127, 0, 0, 1].into()
152 };
153 let addr = SocketAddr::from((bind_addr, port));
154
155 println!("\x1b[32m");
156 println!(" ╔══════════════════════════════════════════════════════╗");
157 println!(" ║ Smart Tree Web Dashboard ║");
158 println!(" ╠══════════════════════════════════════════════════════╣");
159 if bind_all {
160 println!(
161 " ║ http://0.0.0.0:{} ║",
162 port
163 );
164 println!(" ║ ║");
165 println!(" ║ Allowed networks: ║");
166 if allowed.allow_all {
167 println!(" ║ ANY (0.0.0.0/0) ║");
168 } else {
169 for net in &allowed.networks {
170 if !net.addr().is_loopback() {
171 println!(" ║ {} ║", net);
172 }
173 }
174 }
175 } else {
176 println!(
177 " ║ http://127.0.0.1:{} ║",
178 port
179 );
180 println!(" ║ ║");
181 println!(" ║ Localhost only (use --allow for network access) ║");
182 }
183 println!(" ║ ║");
184 println!(" ║ Terminal: Real PTY with bash/zsh ║");
185 println!(" ║ Files: Browse and edit ║");
186 println!(" ║ Preview: Markdown rendering ║");
187 println!(" ╚══════════════════════════════════════════════════════╝");
188 println!("\x1b[0m");
189
190 if open_browser {
191 let url = format!("http://127.0.0.1:{}", port);
192 if let Err(e) = open::that(&url) {
193 eprintln!("Failed to open browser: {}", e);
194 }
195 }
196
197 let listener = tokio::net::TcpListener::bind(addr).await?;
198 axum::serve(
199 listener,
200 app.into_make_service_with_connect_info::<SocketAddr>(),
201 )
202 .await?;
203
204 Ok(())
205}
206
207async fn serve_index() -> Html<&'static str> {
209 Html(assets::INDEX_HTML)
210}
211
212async fn serve_css() -> impl IntoResponse {
213 ([(header::CONTENT_TYPE, "text/css")], assets::STYLE_CSS)
214}
215
216async fn serve_js() -> impl IntoResponse {
217 (
218 [(header::CONTENT_TYPE, "application/javascript")],
219 assets::APP_JS,
220 )
221}
222
223async fn serve_xterm_js() -> impl IntoResponse {
224 (
225 [(header::CONTENT_TYPE, "application/javascript")],
226 assets::XTERM_JS,
227 )
228}
229
230async fn serve_xterm_css() -> impl IntoResponse {
231 ([(header::CONTENT_TYPE, "text/css")], assets::XTERM_CSS)
232}
233
234async fn serve_xterm_fit_js() -> impl IntoResponse {
235 (
236 [(header::CONTENT_TYPE, "application/javascript")],
237 assets::XTERM_FIT_JS,
238 )
239}
240
241async fn serve_marked_js() -> impl IntoResponse {
242 (
243 [(header::CONTENT_TYPE, "application/javascript")],
244 assets::MARKED_JS,
245 )
246}