Skip to main content

venice_e2ee_proxy/
venice.rs

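//! Venice upstream API client and model mapping.
//!
//! Implements authenticated Venice model-list retrieval, filtering down to text models
//! that support E2EE and TEE attestation, and mapping into the OpenAI-compatible
//! `/v1/models` response. Also wraps the encrypted chat-completion stream and the TEE
//! attestation endpoint.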
5
6use std::{fmt, sync::Arc, time::Duration};
7
8use reqwest::{
9    Url,
10    header::{ACCEPT, CONTENT_TYPE},
11};
12use secrecy::ExposeSecret;
13use serde::Deserialize;
14use serde_json::Value;
15use thiserror::Error;
16
17use crate::{
18    config::{ConfigError, ProxyConfig},
19    openai::{
20        ModelCapabilities, ModelListResponse, ModelObject, VeniceModelMetadata,
21        chat::VeniceE2eeChatRequest,
22    },
23};
24
/// Request header carrying the client public key on encrypted chat requests.
///
/// The value is the hex string passed as `client_public_key_hex` to
/// [`VeniceClient::create_chat_completion_stream`].
pub const HEADER_VENICE_TEE_CLIENT_PUB_KEY: &str = "X-Venice-TEE-Client-Pub-Key";
/// Request header carrying the model public key on encrypted chat requests.
///
/// The value is the hex string passed as `model_public_key_hex` to
/// [`VeniceClient::create_chat_completion_stream`].
pub const HEADER_VENICE_TEE_MODEL_PUB_KEY: &str = "X-Venice-TEE-Model-Pub-Key";
/// Request header naming the TEE signing algorithm; this client always sends `ecdsa`.
pub const HEADER_VENICE_TEE_SIGNING_ALGO: &str = "X-Venice-TEE-Signing-Algo";
28
29/// Authenticated HTTP client for Venice model, chat, and attestation endpoints.
30#[derive(Clone)]
31pub struct VeniceClient {
32    http: reqwest::Client,
33    base_url: Url,
34    api_key: Arc<str>,
35    request_timeout: Duration,
36}
37
38impl VeniceClient {
39    /// Builds a Venice client from proxy configuration and the configured API key.
40    pub fn from_config(config: &ProxyConfig) -> Result<Self, VeniceClientError> {
41        let api_key = config.venice_api_key()?;
42        Self::new(
43            &config.venice.base_url,
44            api_key.expose_secret(),
45            config.venice.request_timeout,
46        )
47    }
48
49    /// Builds a Venice client from a base API URL, bearer token, and request timeout.
50    pub fn new(
51        base_url: impl AsRef<str>,
52        api_key: impl Into<String>,
53        timeout: Duration,
54    ) -> Result<Self, VeniceClientError> {
55        let base_url = parse_base_url(base_url.as_ref())?;
56        let http = reqwest::Client::builder()
57            .connect_timeout(timeout)
58            .read_timeout(timeout)
59            .build()
60            .map_err(VeniceClientError::client_build)?;
61
62        Ok(Self {
63            http,
64            base_url,
65            api_key: Arc::from(api_key.into()),
66            request_timeout: timeout,
67        })
68    }
69
70    /// Fetches Venice models and returns only E2EE/TEE-supported models in OpenAI shape.
71    pub async fn list_models(&self) -> Result<ModelListResponse, VeniceClientError> {
72        let url = self.models_url()?;
73        let response = self
74            .http
75            .get(url)
76            .bearer_auth(self.api_key.as_ref())
77            .header(ACCEPT, "application/json")
78            .timeout(self.request_timeout)
79            .send()
80            .await
81            .map_err(VeniceClientError::request_failure)?;
82
83        let response = Self::check_status(response)?;
84
85        let body = response
86            .bytes()
87            .await
88            .map_err(VeniceClientError::request_failure)?;
89        parse_model_list_response(&body)
90    }
91
92    /// Sends an encrypted chat request to Venice and returns the upstream SSE response.
93    pub async fn create_chat_completion_stream(
94        &self,
95        request: &VeniceE2eeChatRequest,
96        client_public_key_hex: &str,
97        model_public_key_hex: &str,
98    ) -> Result<reqwest::Response, VeniceClientError> {
99        let url = self.chat_completions_url()?;
100        let response = self
101            .http
102            .post(url)
103            .bearer_auth(self.api_key.as_ref())
104            .header(ACCEPT, "text/event-stream")
105            .header(CONTENT_TYPE, "application/json")
106            .header(HEADER_VENICE_TEE_CLIENT_PUB_KEY, client_public_key_hex)
107            .header(HEADER_VENICE_TEE_MODEL_PUB_KEY, model_public_key_hex)
108            .header(HEADER_VENICE_TEE_SIGNING_ALGO, "ecdsa")
109            .json(request)
110            .send()
111            .await
112            .map_err(VeniceClientError::request_failure)?;
113
114        Self::check_status(response)
115    }
116
117    /// Fetches attestation evidence for a model and nonce as raw JSON.
118    pub async fn fetch_attestation_evidence(
119        &self,
120        model_id: &str,
121        nonce: &str,
122    ) -> Result<Value, VeniceClientError> {
123        let url = self.attestation_url(model_id, nonce)?;
124        let response = self
125            .http
126            .get(url)
127            .bearer_auth(self.api_key.as_ref())
128            .header(ACCEPT, "application/json")
129            .timeout(self.request_timeout)
130            .send()
131            .await
132            .map_err(VeniceClientError::request_failure)?;
133
134        let response = Self::check_status(response)?;
135
136        response
137            .json::<Value>()
138            .await
139            .map_err(VeniceClientError::malformed_attestation_payload)
140    }
141
142    /// Maps unauthorized/forbidden and other non-success statuses to errors.
143    fn check_status(response: reqwest::Response) -> Result<reqwest::Response, VeniceClientError> {
144        let status = response.status();
145
146        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
147            return Err(VeniceClientError::Authentication {
148                status: status.as_u16(),
149            });
150        }
151
152        if !status.is_success() {
153            return Err(VeniceClientError::UpstreamStatus {
154                status: status.as_u16(),
155            });
156        }
157
158        Ok(response)
159    }
160
161    /// Returns the Venice models endpoint URL.
162    fn models_url(&self) -> Result<Url, VeniceClientError> {
163        self.endpoint_url("models")
164    }
165
166    /// Returns the Venice chat completions endpoint URL.
167    fn chat_completions_url(&self) -> Result<Url, VeniceClientError> {
168        self.endpoint_url("chat/completions")
169    }
170
171    /// Returns the Venice attestation endpoint URL for a model and nonce.
172    fn attestation_url(&self, model_id: &str, nonce: &str) -> Result<Url, VeniceClientError> {
173        let mut url = self.endpoint_url("tee/attestation")?;
174        url.query_pairs_mut()
175            .append_pair("model", model_id)
176            .append_pair("nonce", nonce);
177
178        Ok(url)
179    }
180
181    /// Joins an endpoint path onto the configured Venice base URL.
182    fn endpoint_url(&self, path: &str) -> Result<Url, VeniceClientError> {
183        self.base_url
184            .join(path)
185            .map_err(|source| VeniceClientError::EndpointUrl {
186                message: source.to_string(),
187            })
188    }
189}
190
191impl fmt::Debug for VeniceClient {
192    /// Formats client metadata while redacting the API key.
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        f.debug_struct("VeniceClient")
195            .field("base_url", &self.base_url)
196            .field("api_key", &"[redacted]")
197            .finish_non_exhaustive()
198    }
199}
200
201/// Parses a Venice base URL and ensures relative endpoint joins work consistently.
202fn parse_base_url(value: &str) -> Result<Url, VeniceClientError> {
203    let mut url = Url::parse(value).map_err(|source| VeniceClientError::InvalidBaseUrl {
204        base_url: value.to_owned(),
205        message: source.to_string(),
206    })?;
207
208    if !url.path().ends_with('/') {
209        let path = format!("{}/", url.path());
210        url.set_path(&path);
211    }
212
213    Ok(url)
214}
215
216/// Parses a Venice model-list payload and maps it into an OpenAI-compatible response.
217fn parse_model_list_response(body: &[u8]) -> Result<ModelListResponse, VeniceClientError> {
218    let payload: VeniceModelListPayload =
219        serde_json::from_slice(body).map_err(VeniceClientError::malformed_payload)?;
220    Ok(payload.into_openai_model_list())
221}
222
223/// Errors returned by the Venice upstream client.
224#[derive(Debug, Error)]
225pub enum VeniceClientError {
226    #[error(transparent)]
227    Config(#[from] ConfigError),
228    #[error("invalid Venice base URL {base_url}: {message}")]
229    InvalidBaseUrl { base_url: String, message: String },
230    #[error("failed to build Venice HTTP client: {message}")]
231    ClientBuild { message: String },
232    #[error("failed to build Venice models URL: {message}")]
233    EndpointUrl { message: String },
234    #[error("Venice upstream authentication failed with status {status}")]
235    Authentication { status: u16 },
236    #[error("Venice upstream returned status {status}")]
237    UpstreamStatus { status: u16 },
238    #[error("Venice upstream request timed out")]
239    Timeout,
240    #[error("Venice upstream request failed: {message}")]
241    Request { message: String },
242    #[error("Venice upstream returned malformed model payload: {message}")]
243    MalformedPayload { message: String },
244    #[error("Venice upstream returned malformed attestation payload: {message}")]
245    MalformedAttestationPayload { message: String },
246}
247
248impl VeniceClientError {
249    /// Returns the OpenAI-compatible error type exposed for this Venice client error.
250    pub fn api_error_type(&self) -> &'static str {
251        match self {
252            Self::Config(_)
253            | Self::InvalidBaseUrl { .. }
254            | Self::ClientBuild { .. }
255            | Self::EndpointUrl { .. } => "proxy_configuration_error",
256            Self::Authentication { .. } => "proxy_upstream_authentication_error",
257            Self::UpstreamStatus { .. }
258            | Self::Timeout
259            | Self::Request { .. }
260            | Self::MalformedPayload { .. }
261            | Self::MalformedAttestationPayload { .. } => "proxy_upstream_error",
262        }
263    }
264
265    /// Returns the proxy error code exposed for this Venice client error.
266    pub fn api_error_code(&self) -> &'static str {
267        match self {
268            Self::Config(ConfigError::MissingApiKey) => "venice_api_key_missing",
269            Self::Config(_)
270            | Self::InvalidBaseUrl { .. }
271            | Self::ClientBuild { .. }
272            | Self::EndpointUrl { .. } => "venice_client_configuration_failed",
273            Self::Authentication { .. } => "upstream_authentication_failed",
274            Self::UpstreamStatus { .. } => "upstream_status_error",
275            Self::Timeout => "upstream_timeout",
276            Self::Request { .. } => "upstream_unavailable",
277            Self::MalformedPayload { .. } | Self::MalformedAttestationPayload { .. } => {
278                "upstream_malformed_response"
279            }
280        }
281    }
282
283    /// Converts an HTTP client builder error into a Venice client error.
284    fn client_build(source: reqwest::Error) -> Self {
285        Self::ClientBuild {
286            message: source.to_string(),
287        }
288    }
289
290    /// Converts a request failure into timeout or generic upstream request errors.
291    fn request_failure(source: reqwest::Error) -> Self {
292        if source.is_timeout() {
293            Self::Timeout
294        } else {
295            Self::Request {
296                message: source.to_string(),
297            }
298        }
299    }
300
301    /// Converts a model-list JSON parse error into a malformed-payload error.
302    fn malformed_payload(source: serde_json::Error) -> Self {
303        Self::MalformedPayload {
304            message: source.to_string(),
305        }
306    }
307
308    /// Converts an attestation JSON parse error into a malformed-attestation error.
309    fn malformed_attestation_payload(source: reqwest::Error) -> Self {
310        Self::MalformedAttestationPayload {
311            message: source.to_string(),
312        }
313    }
314}
315
316/// Raw Venice model-list response payload.
317#[derive(Debug, Deserialize)]
318struct VeniceModelListPayload {
319    data: Vec<VeniceModel>,
320}
321
322impl VeniceModelListPayload {
323    /// Converts raw Venice models into a filtered OpenAI-compatible model list.
324    fn into_openai_model_list(self) -> ModelListResponse {
325        let data = self
326            .data
327            .into_iter()
328            .filter_map(VeniceModel::into_openai_model_if_supported)
329            .collect();
330
331        ModelListResponse::new(data)
332    }
333}
334
335/// Raw Venice model object used for OpenAI-compatible model-list mapping.
336#[derive(Debug, Deserialize)]
337struct VeniceModel {
338    id: String,
339    #[serde(default)]
340    created: Option<i64>,
341    #[serde(default)]
342    owned_by: Option<String>,
343    #[serde(rename = "type")]
344    model_type: String,
345    model_spec: VeniceModelSpec,
346}
347
348impl VeniceModel {
349    /// Converts a supported Venice text/E2EE/TEE model into an OpenAI model object.
350    fn into_openai_model_if_supported(self) -> Option<ModelObject> {
351        let capabilities = self.model_spec.capabilities;
352        if self.model_type != "text"
353            || !capabilities.supports_e2ee
354            || !capabilities.supports_tee_attestation
355        {
356            return None;
357        }
358
359        let venice = VeniceModelMetadata::new(
360            self.id.clone(),
361            capabilities.supports_e2ee,
362            capabilities.supports_tee_attestation,
363            capabilities.supports_reasoning.unwrap_or(false),
364            capabilities.supports_reasoning_effort.unwrap_or(false),
365        );
366        let openai_capabilities = capabilities.to_openai_capabilities();
367
368        Some(ModelObject::new(
369            self.id,
370            self.created.unwrap_or(0),
371            self.owned_by.unwrap_or_else(|| "venice.ai".to_owned()),
372            openai_capabilities,
373            venice,
374        ))
375    }
376}
377
378/// Raw Venice model specification containing capability metadata.
379#[derive(Debug, Deserialize)]
380struct VeniceModelSpec {
381    capabilities: VeniceCapabilities,
382}
383
384/// Raw Venice capability flags used to decide model support and OpenAI metadata.
385#[derive(Debug, Deserialize)]
386struct VeniceCapabilities {
387    #[serde(rename = "supportsE2EE")]
388    supports_e2ee: bool,
389    #[serde(rename = "supportsTeeAttestation")]
390    supports_tee_attestation: bool,
391    #[serde(default, rename = "supportsFunctionCalling")]
392    supports_function_calling: Option<bool>,
393    #[serde(default, rename = "supportsBuiltinTools")]
394    supports_builtin_tools: Option<bool>,
395    #[serde(default, rename = "supportsWebSearch")]
396    supports_web_search: Option<bool>,
397    #[serde(default, rename = "supportsCodeInterpreter")]
398    supports_code_interpreter: Option<bool>,
399    #[serde(default, rename = "supportsVision")]
400    supports_vision: Option<bool>,
401    #[serde(default, rename = "supportsReasoning")]
402    supports_reasoning: Option<bool>,
403    #[serde(default, rename = "supportsReasoningEffort")]
404    supports_reasoning_effort: Option<bool>,
405}
406
407impl VeniceCapabilities {
408    /// Maps Venice capability flags into the OpenAI-compatible capability object.
409    fn to_openai_capabilities(&self) -> ModelCapabilities {
410        let web_search = self.supports_web_search.unwrap_or(false);
411        let code_interpreter = self.supports_code_interpreter.unwrap_or(false);
412        let builtin_tools = self
413            .supports_builtin_tools
414            .unwrap_or(web_search || code_interpreter);
415
416        ModelCapabilities {
417            function_calling: self.supports_function_calling.unwrap_or(false),
418            builtin_tools,
419            web_search,
420            code_interpreter,
421            vision: self.supports_vision.unwrap_or(false),
422            reasoning: self.supports_reasoning.unwrap_or(false),
423            reasoning_effort: self.supports_reasoning_effort.unwrap_or(false),
424        }
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    use axum::{Router, body::Body, response::IntoResponse, routing::post};
    use tokio::net::TcpListener;

    // Timeout for tests that build a client but never send a request.
    const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    /// Of three upstream models, only the text model with both E2EE and TEE attestation
    /// support is kept, and its metadata and capability flags are carried through.
    #[test]
    fn maps_supported_venice_text_models_to_openai_shape() {
        let body = br#"
        {
          "data": [
            {
              "id": "e2ee-qwen3-5-122b-a10b",
              "created": 1727966436,
              "owned_by": "venice.ai",
              "type": "text",
              "model_spec": {
                "capabilities": {
                  "supportsE2EE": true,
                  "supportsTeeAttestation": true,
                  "supportsFunctionCalling": true,
                  "supportsBuiltinTools": true,
                  "supportsWebSearch": true,
                  "supportsCodeInterpreter": true,
                  "supportsVision": false,
                  "supportsReasoning": true,
                  "supportsReasoningEffort": true
                }
              }
            },
            {
              "id": "non-e2ee-text",
              "type": "text",
              "model_spec": {
                "capabilities": {
                  "supportsE2EE": false,
                  "supportsTeeAttestation": true
                }
              }
            },
            {
              "id": "e2ee-image",
              "type": "image",
              "model_spec": {
                "capabilities": {
                  "supportsE2EE": true,
                  "supportsTeeAttestation": true
                }
              }
            }
          ]
        }
        "#;

        let response = parse_model_list_response(body).expect("valid model payload should parse");

        assert_eq!(response.object, "list");
        // "non-e2ee-text" fails the E2EE check; "e2ee-image" fails the text-type check.
        assert_eq!(response.data.len(), 1);
        let model = &response.data[0];
        assert_eq!(model.id, "e2ee-qwen3-5-122b-a10b");
        assert_eq!(model.object, "model");
        assert_eq!(model.created, 1727966436);
        assert_eq!(model.owned_by, "venice.ai");
        assert_eq!(model.name, "e2ee-qwen3-5-122b-a10b");
        assert!(model.info.meta.capabilities.function_calling);
        assert!(model.info.meta.capabilities.builtin_tools);
        assert!(model.info.meta.capabilities.web_search);
        assert!(model.info.meta.capabilities.code_interpreter);
        assert!(!model.info.meta.capabilities.vision);
        assert!(model.info.meta.capabilities.reasoning);
        assert!(model.info.meta.capabilities.reasoning_effort);
        assert_eq!(model.venice.id, "e2ee-qwen3-5-122b-a10b");
        assert!(model.venice.supports_e2ee);
        assert!(model.venice.supports_tee_attestation);
        assert!(model.venice.supports_reasoning);
        assert!(model.venice.supports_reasoning_effort);
    }

    /// A model carrying only the two required flags gets `created = 0`, the default
    /// `owned_by`, and capabilities equal to `ModelCapabilities::default()`.
    #[test]
    fn missing_optional_capability_metadata_defaults_to_false() {
        let body = br#"
        {
          "data": [
            {
              "id": "e2ee-minimal",
              "type": "text",
              "model_spec": {
                "capabilities": {
                  "supportsE2EE": true,
                  "supportsTeeAttestation": true
                }
              }
            }
          ]
        }
        "#;

        let response =
            parse_model_list_response(body).expect("minimal capability payload should parse");
        let model = response
            .data
            .first()
            .expect("supported model should be present");

        assert_eq!(model.created, 0);
        assert_eq!(model.owned_by, "venice.ai");
        assert_eq!(model.info.meta.capabilities, ModelCapabilities::default());
    }

    /// A missing required `supportsTeeAttestation` flag fails the whole list parse rather
    /// than silently skipping the one entry.
    #[test]
    fn malformed_model_payload_is_reported() {
        let body = br#"
        {
          "data": [
            {
              "id": "missing-required-attestation-flag",
              "type": "text",
              "model_spec": {
                "capabilities": {
                  "supportsE2EE": true
                }
              }
            }
          ]
        }
        "#;

        let error = parse_model_list_response(body).expect_err("malformed payload should fail");

        assert!(matches!(error, VeniceClientError::MalformedPayload { .. }));
        assert_eq!(error.api_error_code(), "upstream_malformed_response");
    }

    /// Debug output shows the base URL, normalized with a trailing `/`, but never the
    /// API key.
    #[test]
    fn client_debug_output_redacts_api_key() {
        let client = VeniceClient::new(
            "https://api.venice.ai/api/v1",
            "super-secret-test-key",
            DEFAULT_REQUEST_TIMEOUT,
        )
        .expect("client should build");

        let debug = format!("{client:?}");
        assert!(debug.contains("api.venice.ai"));
        assert!(debug.contains("/api/v1/"));
        assert!(debug.contains("[redacted]"));
        assert!(!debug.contains("super-secret-test-key"));
    }

    /// Chat streams get no total deadline. A stream lasting at least 100 ms must finish
    /// even though the client timeout is 50 ms, because each chunk arrives within the
    /// per-read timeout.
    #[tokio::test]
    async fn chat_stream_can_outlive_configured_request_timeout_when_chunks_keep_arriving() {
        // Mock upstream: one SSE event every 20 ms, then the `[DONE]` sentinel.
        async fn slow_streaming_chat() -> impl IntoResponse {
            let stream = async_stream::stream! {
                for index in 0..5 {
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    yield Ok::<_, std::io::Error>(format!("data: {index}\n\n"));
                }
                yield Ok::<_, std::io::Error>("data: [DONE]\n\n".to_owned());
            };

            (
                [
                    ("content-type", "text/event-stream"),
                    ("cache-control", "no-cache"),
                ],
                Body::from_stream(stream),
            )
        }

        // Port 0 lets the OS pick a free port for the mock server.
        let app = Router::new().route("/api/v1/chat/completions", post(slow_streaming_chat));
        let listener = TcpListener::bind(("127.0.0.1", 0))
            .await
            .expect("mock listener should bind");
        let addr = listener.local_addr().expect("listener should have address");
        tokio::spawn(async move {
            axum::serve(listener, app)
                .await
                .expect("mock server should run");
        });

        // 50 ms is shorter than the whole stream but longer than the 20 ms chunk gap.
        let client = VeniceClient::new(
            format!("http://{addr}/api/v1"),
            "test-api-key",
            Duration::from_millis(50),
        )
        .expect("client should build");
        // Minimal request; the mock handler does not inspect the body.
        let request = VeniceE2eeChatRequest {
            model: "e2ee-test".to_owned(),
            messages: Vec::new(),
            stream: true,
            stream_options: crate::openai::chat::VeniceStreamOptions {
                include_usage: false,
            },
            venice_parameters: crate::openai::chat::VeniceParameters::default(),
            temperature: None,
            top_p: None,
            max_tokens: None,
            max_completion_tokens: None,
            stop: None,
            reasoning: None,
            reasoning_effort: None,
        };

        let mut response = client
            .create_chat_completion_stream(&request, "client-key", "model-key")
            .await
            .expect("stream response headers should arrive before timeout");
        let mut chunks = 0;
        while let Some(_chunk) = response
            .chunk()
            .await
            .expect("frequent stream chunks should not hit total timeout")
        {
            chunks += 1;
        }

        // Reaching here means the whole stream was read without a timeout error.
        assert!(chunks > 1);
    }
}
645        }
646
647        assert!(chunks > 1);
648    }
649}