use crate::proxy::memory::MemoryProxy;
use crate::proxy::openai_compat::{
    OpenAiChoice, OpenAiError, OpenAiErrorResponse, OpenAiRequest, OpenAiResponse,
    OpenAiResponseMessage, OpenAiUsage,
};
use crate::proxy::LlmRequest;
use anyhow::Result;
use axum::{
    extract::State,
    http::{Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;

type SharedProxy = Arc<RwLock<MemoryProxy>>;
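
/// Start the OpenAI-compatible proxy server, listening on 127.0.0.1:`port`.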
pub async fn start_proxy_server(port: u16) -> Result<()> {
    let proxy: SharedProxy = Arc::new(RwLock::new(MemoryProxy::new()?));

    let app = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/admin/status", get(admin_status))
        .layer(middleware::from_fn(bearer_auth))
        .with_state(proxy);

    let addr = SocketAddr::from(([127, 0, 0, 1], port));
    println!("Smart Tree LLM Proxy on http://{}", addr);
    println!("  POST /v1/chat/completions");
    println!("  GET  /v1/models");
    println!("  GET  /admin/status");
    if std::env::var("ST_PROXY_API_KEY").is_ok() {
        println!("  Auth: bearer (ST_PROXY_API_KEY required)");
    } else {
        println!("  Auth: open (set ST_PROXY_API_KEY to require a Bearer token)");
    }

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
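
/// Middleware: when ST_PROXY_API_KEY is set, every request must carry a
/// matching `Authorization: Bearer <key>` header; otherwise requests pass through.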
async fn bearer_auth(req: Request<axum::body::Body>, next: Next) -> Response {
    let Ok(expected) = std::env::var("ST_PROXY_API_KEY") else {
        return next.run(req).await;
    };

    let provided = req
        .headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
        .map(|s| s.trim().to_string());

    match provided {
        Some(token) if token == expected => next.run(req).await,
        _ => (
            StatusCode::UNAUTHORIZED,
            Json(OpenAiErrorResponse {
                error: OpenAiError {
                    message: "missing or invalid bearer token".into(),
                    error_type: "authentication_error".into(),
                    code: Some("invalid_api_key".into()),
                },
            }),
        )
            .into_response(),
    }
}
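
/// POST /v1/chat/completions.
///
/// Models are addressed as `provider/model`; a bare model name falls back to
/// the `openai` provider. Example request (assuming the server listens on
/// port 8080 with no API key configured):
///
/// ```text
/// curl http://127.0.0.1:8080/v1/chat/completions \
///   -H 'Content-Type: application/json' \
///   -d '{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}'
/// ```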
async fn chat_completions(
    State(proxy): State<SharedProxy>,
    Json(req): Json<OpenAiRequest>,
) -> Response {
    let (provider_name, model_name) = match req.model.split_once('/') {
        Some((p, m)) => (p.to_string(), m.to_string()),
        None => ("openai".to_string(), req.model.clone()),
    };

    let internal_req = LlmRequest {
        model: model_name,
        messages: req.messages.into_iter().map(Into::into).collect(),
        temperature: req.temperature,
        max_tokens: req.max_tokens,
        stream: req.stream.unwrap_or(false),
    };
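
    // Scope memory by the optional OpenAI `user` field, defaulting to a
    // shared "global" scope.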
    let scope_id = req.user.unwrap_or_else(|| "global".to_string());

    let mut proxy_lock = proxy.write().await;
    match proxy_lock
        .complete_with_memory(&provider_name, &scope_id, internal_req)
        .await
    {
        Ok(resp) => (
            StatusCode::OK,
            Json(OpenAiResponse {
                id: format!("st-{}", uuid::Uuid::new_v4()),
                object: "chat.completion".to_string(),
                created: chrono::Utc::now().timestamp() as u64,
                model: req.model,
                choices: vec![OpenAiChoice {
                    index: 0,
                    message: OpenAiResponseMessage {
                        role: "assistant".to_string(),
                        content: resp.content,
                    },
                    finish_reason: "stop".to_string(),
                }],
                usage: resp.usage.map(|u| OpenAiUsage {
                    prompt_tokens: u.prompt_tokens,
                    completion_tokens: u.completion_tokens,
                    total_tokens: u.total_tokens,
                }),
            }),
        )
            .into_response(),
        Err(e) => {
            let msg = e.to_string();
            let status = if msg.contains("not found") || msg.contains("invalid") {
                StatusCode::BAD_REQUEST
            } else if msg.contains("unauthorized") || msg.contains("authentication") {
                StatusCode::UNAUTHORIZED
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            };
            (
                status,
                Json(OpenAiErrorResponse {
                    error: OpenAiError {
                        message: msg,
                        error_type: "api_error".into(),
                        code: None,
                    },
                }),
            )
                .into_response()
        }
    }
}
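
/// A single model entry in the GET /v1/models listing.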
#[derive(Serialize)]
struct ModelEntry {
    id: String,
    object: &'static str,
    owned_by: String,
}

#[derive(Serialize)]
struct ModelList {
    object: &'static str,
    data: Vec<ModelEntry>,
}
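
/// GET /v1/models: list one `<provider>/default` model per configured provider.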
async fn list_models(State(proxy): State<SharedProxy>) -> Response {
    let lock = proxy.read().await;
    let data: Vec<ModelEntry> = lock
        .inner
        .list_providers()
        .into_iter()
        .map(|p| ModelEntry {
            id: format!("{}/default", p.to_lowercase()),
            object: "model",
            owned_by: p.to_string(),
        })
        .collect();

    Json(ModelList {
        object: "list",
        data,
    })
    .into_response()
}
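
/// Payload for GET /admin/status.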
#[derive(Serialize)]
struct AdminStatus {
    running: bool,
    auth_required: bool,
    providers: Vec<&'static str>,
    version: &'static str,
}
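
/// GET /admin/status: report whether auth is required, the configured
/// providers, and the crate version.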
async fn admin_status(State(proxy): State<SharedProxy>) -> Response {
    let lock = proxy.read().await;
    Json(AdminStatus {
        running: true,
        auth_required: std::env::var("ST_PROXY_API_KEY").is_ok(),
        providers: lock.inner.list_providers(),
        version: env!("CARGO_PKG_VERSION"),
    })
    .into_response()
}