1pub mod auth;
2pub mod inference;
3pub mod routes;
4pub mod types;
5
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Instant;
10
11use axum::middleware;
12use axum::Router;
13use axum::routing::{get, post};
14use tokio::net::TcpListener;
15
16use crate::auth::TokenAuthority;
17use crate::inference::{InferenceHandle, ModelInfo};
18use yule_core::error::Result;
19
20pub struct AppState {
21 pub inference_tx: std::sync::mpsc::Sender<inference::InferenceRequest>,
22 pub auth: Arc<TokenAuthority>,
23 pub start_time: Instant,
24 pub model_info: ModelInfo,
25 pub sandbox_active: bool,
26 pub device_pubkey: [u8; 32],
27 pub merkle_root_bytes: [u8; 32],
28 pub signing_key: ed25519_dalek::SigningKey,
29}
30
/// Configuration for one API server instance; consumed by `ApiServer::run`.
pub struct ApiServer {
    /// Socket address to listen on, kept as text until it is parsed at startup.
    bind: String,
    /// Path of the model file handed to the inference handle.
    model_path: PathBuf,
    /// Caller-supplied auth token; when `None` a fresh token is generated.
    token: Option<String>,
    /// Sandbox flag, forwarded unchanged into the shared application state.
    sandbox_active: bool,
}
37
38impl ApiServer {
39 pub fn new(bind: String, model_path: PathBuf, token: Option<String>, sandbox_active: bool) -> Self {
40 Self { bind, model_path, token, sandbox_active }
41 }
42
43 pub async fn run(self) -> Result<()> {
44 eprintln!("loading model: {}", self.model_path.display());
45
46 let handle = InferenceHandle::spawn(self.model_path)?;
47
48 eprintln!("model loaded: {:?} ({} tensors, merkle: {})",
49 handle.model_info.metadata.architecture,
50 handle.model_info.tensor_count,
51 &handle.model_info.merkle_root[..16],
52 );
53
54 let key_store = yule_verify::keys::KeyStore::open()
56 .map_err(|e| yule_core::error::YuleError::Api(format!("key store: {e}")))?;
57 let signing_key = key_store.device_key()
58 .map_err(|e| yule_core::error::YuleError::Api(format!("device key: {e}")))?;
59 let device_pubkey = signing_key.verifying_key().to_bytes();
60
61 let merkle_root_bytes = parse_merkle_hex(&handle.model_info.merkle_root);
63
64 eprintln!("attestation: device key loaded, pubkey {}",
65 hex::short(&device_pubkey));
66
67 let mut auth = match &self.token {
68 Some(t) => TokenAuthority::from_existing(t),
69 None => TokenAuthority::new(),
70 };
71
72 let token = match &self.token {
73 Some(t) => t.clone(),
74 None => auth.generate_token(),
75 };
76
77 eprintln!();
78 eprintln!(" token: {token}");
79 eprintln!();
80
81 let auth = Arc::new(auth);
82
83 let state = Arc::new(AppState {
84 inference_tx: handle.tx.clone(),
85 auth: auth.clone(),
86 start_time: Instant::now(),
87 model_info: handle.model_info.clone(),
88 sandbox_active: self.sandbox_active,
89 device_pubkey,
90 merkle_root_bytes,
91 signing_key,
92 });
93
94 let app = build_router(state, auth);
95
96 let addr: SocketAddr = self.bind.parse()
97 .map_err(|e| yule_core::error::YuleError::Api(format!("invalid bind addr: {e}")))?;
98
99 let listener = TcpListener::bind(addr).await
100 .map_err(|e| yule_core::error::YuleError::Api(format!("bind failed: {e}")))?;
101
102 eprintln!("listening on {addr}");
103 eprintln!(" yule api: http://{addr}/yule/health");
104 eprintln!(" openai: http://{addr}/v1/chat/completions");
105
106 axum::serve(listener, app).await
107 .map_err(|e| yule_core::error::YuleError::Api(format!("server error: {e}")))?;
108
109 handle.shutdown();
110 Ok(())
111 }
112}
113
/// Decodes up to 32 bytes from the hex string `hex_str`.
///
/// Each consecutive pair of characters becomes one output byte. Decoding is
/// best-effort because the return type cannot carry an error:
///
/// * a pair containing anything other than `0-9`, `a-f` or `A-F` yields `0`
///   (this includes a leading `+`, which `u8::from_str_radix` would accept);
/// * input shorter than 64 characters leaves the remaining bytes `0`, and a
///   trailing unpaired character is ignored;
/// * input longer than 64 characters is truncated.
///
/// Works on raw bytes, so non-ASCII input can never cause a slicing panic.
fn parse_merkle_hex(hex_str: &str) -> [u8; 32] {
    /// Value of one ASCII hex digit, or `None` for any other byte.
    fn nibble(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c - b'0'),
            b'a'..=b'f' => Some(c - b'a' + 10),
            b'A'..=b'F' => Some(c - b'A' + 10),
            _ => None,
        }
    }

    let mut bytes = [0u8; 32];
    // `zip` stops at whichever runs out first: the 32 output slots or the
    // complete two-byte pairs of the input.
    for (byte, pair) in bytes.iter_mut().zip(hex_str.as_bytes().chunks_exact(2)) {
        if let (Some(hi), Some(lo)) = (nibble(pair[0]), nibble(pair[1])) {
            *byte = (hi << 4) | lo;
        }
    }
    bytes
}
123
/// Small hex-formatting helpers for log output.
mod hex {
    /// Renders the first eight bytes of `bytes` as lowercase hex followed by
    /// `"..."`, giving a short fingerprint suitable for logging.
    pub fn short(bytes: &[u8; 32]) -> String {
        let prefix = &bytes[..8];
        let mut rendered = String::new();
        for b in prefix {
            rendered.push_str(&format!("{b:02x}"));
        }
        rendered.push_str("...");
        rendered
    }
}
129
130fn build_router(state: Arc<AppState>, auth: Arc<TokenAuthority>) -> Router {
131 use routes::{native, openai_compat};
132
133 Router::new()
134 .route("/yule/health", get(native::health))
135 .route("/yule/model", get(native::model_info))
136 .route("/yule/chat", post(native::chat))
137 .route("/yule/tokenize", post(native::tokenize))
138 .route("/v1/chat/completions", post(openai_compat::chat_completions))
139 .route("/v1/models", get(openai_compat::models))
140 .layer(middleware::from_fn(auth::require_auth))
141 .layer(axum::Extension(auth))
142 .with_state(state)
143}