st/proxy/server.rs

//! OpenAI-Compatible Proxy Server
//!
//! HTTP server implementing the OpenAI Chat Completions API plus a small admin
//! surface for status and provider listing.
//!
//! Endpoints:
//!   POST /v1/chat/completions  - OpenAI-compatible chat
//!   GET  /v1/models            - List available models (across providers)
//!   GET  /admin/status         - Proxy status: providers + auth state
//!
//! Optional bearer auth: set `ST_PROXY_API_KEY` in the env. When set, every
//! request must carry `Authorization: Bearer <key>`. When unset, the proxy is
//! open (the server binds to loopback only).
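//!
//! Model names use the `<provider>/<model>` form; a bare model name is routed
//! to the `openai` provider. The OpenAI `user` field, when present, selects the
//! per-caller memory scope.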

use crate::proxy::memory::MemoryProxy;
use crate::proxy::openai_compat::{
    OpenAiChoice, OpenAiError, OpenAiErrorResponse, OpenAiRequest, OpenAiResponse,
    OpenAiResponseMessage, OpenAiUsage,
};
use crate::proxy::LlmRequest;
use anyhow::Result;
use axum::{
    extract::State,
    http::{Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;

type SharedProxy = Arc<RwLock<MemoryProxy>>;

/// Start the OpenAI-compatible proxy server.
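///
/// Binds to `127.0.0.1:<port>` and serves requests until the server returns an
/// error. Illustrative usage (the port is an arbitrary example, and the caller
/// is assumed to already be inside a Tokio runtime):
///
/// ```ignore
/// // Sketch: call from an async context with this function in scope.
/// start_proxy_server(8080).await?;
/// ```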
pub async fn start_proxy_server(port: u16) -> Result<()> {
    let proxy: SharedProxy = Arc::new(RwLock::new(MemoryProxy::new()?));

    let app = Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/admin/status", get(admin_status))
        .layer(middleware::from_fn(bearer_auth))
        .with_state(proxy);

    let addr = SocketAddr::from(([127, 0, 0, 1], port));
    println!("Smart Tree LLM Proxy on http://{}", addr);
    println!("  POST /v1/chat/completions");
    println!("  GET  /v1/models");
    println!("  GET  /admin/status");
    if std::env::var("ST_PROXY_API_KEY").is_ok() {
        println!("  Auth: bearer (ST_PROXY_API_KEY required)");
    } else {
        println!("  Auth: open (set ST_PROXY_API_KEY to require Bearer token)");
    }

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}

/// Bearer-token middleware. No-op when ST_PROXY_API_KEY is unset.
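/// When the variable is set, requests must send `Authorization: Bearer <token>`
/// with a token that matches `ST_PROXY_API_KEY` exactly (surrounding whitespace
/// in the token is trimmed before comparison).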
async fn bearer_auth(req: Request<axum::body::Body>, next: Next) -> Response {
    let Ok(expected) = std::env::var("ST_PROXY_API_KEY") else {
        return next.run(req).await;
    };

    let provided = req
        .headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
        .map(|s| s.trim().to_string());

    match provided {
        Some(token) if token == expected => next.run(req).await,
        _ => (
            StatusCode::UNAUTHORIZED,
            Json(OpenAiErrorResponse {
                error: OpenAiError {
                    message: "missing or invalid bearer token".into(),
                    error_type: "authentication_error".into(),
                    code: Some("invalid_api_key".into()),
                },
            }),
        )
            .into_response(),
    }
}

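/// Handler for `POST /v1/chat/completions`.
///
/// The requested model is split on the first `/` into `<provider>/<model>`;
/// bare model names fall back to the `openai` provider. The optional `user`
/// field selects the memory scope (defaulting to `"global"`), and provider
/// errors are mapped to OpenAI-style error bodies with a best-effort status
/// code.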
async fn chat_completions(
    State(proxy): State<SharedProxy>,
    Json(req): Json<OpenAiRequest>,
) -> Response {
    let (provider_name, model_name) = match req.model.split_once('/') {
        Some((p, m)) => (p.to_string(), m.to_string()),
        None => ("openai".to_string(), req.model.clone()),
    };

    let internal_req = LlmRequest {
        model: model_name,
        messages: req.messages.into_iter().map(Into::into).collect(),
        temperature: req.temperature,
        max_tokens: req.max_tokens,
        stream: req.stream.unwrap_or(false),
    };

    let scope_id = req.user.unwrap_or_else(|| "global".to_string());

    let mut proxy_lock = proxy.write().await;
    match proxy_lock
        .complete_with_memory(&provider_name, &scope_id, internal_req)
        .await
    {
        Ok(resp) => (
            StatusCode::OK,
            Json(OpenAiResponse {
                id: format!("st-{}", uuid::Uuid::new_v4()),
                object: "chat.completion".to_string(),
                created: chrono::Utc::now().timestamp() as u64,
                model: req.model,
                choices: vec![OpenAiChoice {
                    index: 0,
                    message: OpenAiResponseMessage {
                        role: "assistant".to_string(),
                        content: resp.content,
                    },
                    finish_reason: "stop".to_string(),
                }],
                usage: resp.usage.map(|u| OpenAiUsage {
                    prompt_tokens: u.prompt_tokens,
                    completion_tokens: u.completion_tokens,
                    total_tokens: u.total_tokens,
                }),
            }),
        )
            .into_response(),
        Err(e) => {
            let msg = e.to_string();
            let status = if msg.contains("not found") || msg.contains("invalid") {
                StatusCode::BAD_REQUEST
            } else if msg.contains("unauthorized") || msg.contains("authentication") {
                StatusCode::UNAUTHORIZED
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            };
            (
                status,
                Json(OpenAiErrorResponse {
                    error: OpenAiError {
                        message: msg,
                        error_type: "api_error".into(),
                        code: None,
                    },
                }),
            )
                .into_response()
        }
    }
}

/// One entry in the OpenAI-style model list.
#[derive(Serialize)]
struct ModelEntry {
    id: String,
    object: &'static str,
    owned_by: String,
}

/// Envelope for `GET /v1/models`.
#[derive(Serialize)]
struct ModelList {
    object: &'static str,
    data: Vec<ModelEntry>,
}

/// GET /v1/models: returns an OpenAI-style list.
/// Each provider contributes a single placeholder entry `<provider>/default`;
/// callers can request specific models with the `<provider>/<model>` syntax.
async fn list_models(State(proxy): State<SharedProxy>) -> Response {
    let lock = proxy.read().await;
    let data: Vec<ModelEntry> = lock
        .inner
        .list_providers()
        .into_iter()
        .map(|p| ModelEntry {
            id: format!("{}/default", p.to_lowercase()),
            object: "model",
            owned_by: p.to_string(),
        })
        .collect();

    Json(ModelList {
        object: "list",
        data,
    })
    .into_response()
}

/// Payload for `GET /admin/status`.
#[derive(Serialize)]
struct AdminStatus {
    running: bool,
    auth_required: bool,
    providers: Vec<&'static str>,
    version: &'static str,
}

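/// Handler for `GET /admin/status`: reports whether bearer auth is enforced,
/// the configured providers, and the crate version.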
async fn admin_status(State(proxy): State<SharedProxy>) -> Response {
    let lock = proxy.read().await;
    Json(AdminStatus {
        running: true,
        auth_required: std::env::var("ST_PROXY_API_KEY").is_ok(),
        providers: lock.inner.list_providers(),
        version: env!("CARGO_PKG_VERSION"),
    })
    .into_response()
}
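
// A minimal test sketch for the auth middleware in isolation. It assumes
// `tower` (with its `util` feature) and tokio's `macros` feature are available
// for tests, and that nothing else reads `ST_PROXY_API_KEY` concurrently; the
// `/ping` route is a placeholder. On edition 2024, `std::env::set_var` must be
// wrapped in an `unsafe` block.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use tower::ServiceExt; // brings `Router::oneshot` into scope

    #[tokio::test]
    async fn missing_token_is_rejected_when_key_is_set() {
        std::env::set_var("ST_PROXY_API_KEY", "test-secret");

        // A throwaway router wrapped with the same middleware the proxy uses.
        let app = Router::new()
            .route("/ping", get(|| async { "pong" }))
            .layer(middleware::from_fn(bearer_auth));

        // No Authorization header: the middleware should short-circuit with 401.
        let res = app
            .oneshot(Request::builder().uri("/ping").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        std::env::remove_var("ST_PROXY_API_KEY");
    }
}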