Skip to main content

yule_api/
lib.rs

1pub mod auth;
2pub mod inference;
3pub mod routes;
4pub mod types;
5
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Instant;
10
11use axum::middleware;
12use axum::Router;
13use axum::routing::{get, post};
14use tokio::net::TcpListener;
15
16use crate::auth::TokenAuthority;
17use crate::inference::{InferenceHandle, ModelInfo};
18use yule_core::error::Result;
19
20pub struct AppState {
21    pub inference_tx: std::sync::mpsc::Sender<inference::InferenceRequest>,
22    pub auth: Arc<TokenAuthority>,
23    pub start_time: Instant,
24    pub model_info: ModelInfo,
25    pub sandbox_active: bool,
26    pub device_pubkey: [u8; 32],
27    pub merkle_root_bytes: [u8; 32],
28    pub signing_key: ed25519_dalek::SigningKey,
29}
30
/// Configuration for one API server instance.
///
/// Holds only plain settings; nothing is loaded or bound until `run` is called.
pub struct ApiServer {
    /// Address to listen on, kept as text and parsed into a `SocketAddr` by `run`.
    bind: String,
    /// Location of the model passed to `InferenceHandle::spawn`.
    model_path: PathBuf,
    /// Caller-supplied token; when `None`, `run` generates a fresh one.
    token: Option<String>,
    /// Sandbox flag, copied verbatim into `AppState`.
    sandbox_active: bool,
}
37
38impl ApiServer {
39    pub fn new(bind: String, model_path: PathBuf, token: Option<String>, sandbox_active: bool) -> Self {
40        Self { bind, model_path, token, sandbox_active }
41    }
42
43    pub async fn run(self) -> Result<()> {
44        eprintln!("loading model: {}", self.model_path.display());
45
46        let handle = InferenceHandle::spawn(self.model_path)?;
47
48        eprintln!("model loaded: {:?} ({} tensors, merkle: {})",
49            handle.model_info.metadata.architecture,
50            handle.model_info.tensor_count,
51            &handle.model_info.merkle_root[..16],
52        );
53
54        // load device signing key for attestation
55        let key_store = yule_verify::keys::KeyStore::open()
56            .map_err(|e| yule_core::error::YuleError::Api(format!("key store: {e}")))?;
57        let signing_key = key_store.device_key()
58            .map_err(|e| yule_core::error::YuleError::Api(format!("device key: {e}")))?;
59        let device_pubkey = signing_key.verifying_key().to_bytes();
60
61        // parse merkle root hex → bytes
62        let merkle_root_bytes = parse_merkle_hex(&handle.model_info.merkle_root);
63
64        eprintln!("attestation: device key loaded, pubkey {}",
65            hex::short(&device_pubkey));
66
67        let mut auth = match &self.token {
68            Some(t) => TokenAuthority::from_existing(t),
69            None => TokenAuthority::new(),
70        };
71
72        let token = match &self.token {
73            Some(t) => t.clone(),
74            None => auth.generate_token(),
75        };
76
77        eprintln!();
78        eprintln!("  token: {token}");
79        eprintln!();
80
81        let auth = Arc::new(auth);
82
83        let state = Arc::new(AppState {
84            inference_tx: handle.tx.clone(),
85            auth: auth.clone(),
86            start_time: Instant::now(),
87            model_info: handle.model_info.clone(),
88            sandbox_active: self.sandbox_active,
89            device_pubkey,
90            merkle_root_bytes,
91            signing_key,
92        });
93
94        let app = build_router(state, auth);
95
96        let addr: SocketAddr = self.bind.parse()
97            .map_err(|e| yule_core::error::YuleError::Api(format!("invalid bind addr: {e}")))?;
98
99        let listener = TcpListener::bind(addr).await
100            .map_err(|e| yule_core::error::YuleError::Api(format!("bind failed: {e}")))?;
101
102        eprintln!("listening on {addr}");
103        eprintln!("  yule api:  http://{addr}/yule/health");
104        eprintln!("  openai:    http://{addr}/v1/chat/completions");
105
106        axum::serve(listener, app).await
107            .map_err(|e| yule_core::error::YuleError::Api(format!("server error: {e}")))?;
108
109        handle.shutdown();
110        Ok(())
111    }
112}
113
/// Decodes up to 64 hex characters into a 32-byte array.
///
/// Decoding is lenient: output byte `i` comes from the input characters at
/// positions `2*i` and `2*i + 1`. If that pair is missing (short input) or is
/// not two hex digits, the byte is left as `0`. Input beyond 64 characters is
/// ignored. Both upper- and lower-case digits are accepted.
///
/// The input is processed as raw bytes, so non-ASCII text can never cause a
/// char-boundary panic, and sign characters such as `+` are rejected rather
/// than being treated as part of a number.
fn parse_merkle_hex(hex_str: &str) -> [u8; 32] {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c - b'0'),
            b'a'..=b'f' => Some(c - b'a' + 10),
            b'A'..=b'F' => Some(c - b'A' + 10),
            _ => None,
        }
    }

    let mut bytes = [0u8; 32];
    // `chunks_exact` drops a trailing odd character; `zip` stops after 32 pairs.
    for (byte, pair) in bytes.iter_mut().zip(hex_str.as_bytes().chunks_exact(2)) {
        if let (Some(hi), Some(lo)) = (nibble(pair[0]), nibble(pair[1])) {
            *byte = (hi << 4) | lo;
        }
    }
    bytes
}
123
/// Small hex-formatting helper for log output.
mod hex {
    /// Renders the first 8 bytes of `bytes` as 16 lower-case hex digits
    /// followed by a literal `...`.
    pub fn short(bytes: &[u8; 32]) -> String {
        // 16 hex digits plus the three-character ellipsis.
        let mut rendered = String::with_capacity(19);
        for b in bytes.iter().take(8) {
            rendered.push_str(&format!("{b:02x}"));
        }
        rendered.push_str("...");
        rendered
    }
}
129
130fn build_router(state: Arc<AppState>, auth: Arc<TokenAuthority>) -> Router {
131    use routes::{native, openai_compat};
132
133    Router::new()
134        .route("/yule/health", get(native::health))
135        .route("/yule/model", get(native::model_info))
136        .route("/yule/chat", post(native::chat))
137        .route("/yule/tokenize", post(native::tokenize))
138        .route("/v1/chat/completions", post(openai_compat::chat_completions))
139        .route("/v1/models", get(openai_compat::models))
140        .layer(middleware::from_fn(auth::require_auth))
141        .layer(axum::Extension(auth))
142        .with_state(state)
143}