Skip to main content

walrus_model/
http.rs

1//! Shared HTTP transport for OpenAI-compatible LLM providers.
2//!
3//! `HttpProvider` wraps a `reqwest::Client` with pre-configured headers and
4//! endpoint URL. Provides `send()` for non-streaming and `stream_sse()` for
5//! Server-Sent Events streaming. Used by DeepSeek, OpenAI, and Mistral —
6//! Claude uses its own transport (different SSE format).
7
8use anyhow::Result;
9use async_stream::try_stream;
10use futures_core::Stream;
11use futures_util::StreamExt;
12use reqwest::{
13    Client, Method,
14    header::{self, HeaderMap, HeaderName, HeaderValue},
15};
16use serde::Serialize;
17use wcore::model::{Response, StreamChunk};
18
19/// Shared HTTP transport for OpenAI-compatible providers.
20///
21/// Holds a `reqwest::Client`, pre-built headers (auth + content-type),
22/// and the target endpoint URL.
23#[derive(Clone)]
24pub struct HttpProvider {
25    client: Client,
26    headers: HeaderMap,
27    endpoint: String,
28}
29
30impl HttpProvider {
31    /// Create a provider with Bearer token authentication.
32    pub fn bearer(client: Client, key: &str, endpoint: &str) -> Result<Self> {
33        let mut headers = HeaderMap::new();
34        headers.insert(
35            header::CONTENT_TYPE,
36            HeaderValue::from_static("application/json"),
37        );
38        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
39        headers.insert(header::AUTHORIZATION, format!("Bearer {key}").parse()?);
40        Ok(Self {
41            client,
42            headers,
43            endpoint: endpoint.to_owned(),
44        })
45    }
46
47    /// Create a provider without authentication (e.g. Ollama).
48    pub fn no_auth(client: Client, endpoint: &str) -> Self {
49        let mut headers = HeaderMap::new();
50        headers.insert(
51            header::CONTENT_TYPE,
52            HeaderValue::from_static("application/json"),
53        );
54        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
55        Self {
56            client,
57            headers,
58            endpoint: endpoint.to_owned(),
59        }
60    }
61
62    /// Create a provider with a custom header for authentication.
63    ///
64    /// Used by providers that don't use Bearer tokens (e.g. Anthropic
65    /// uses `x-api-key`).
66    pub fn custom_header(
67        client: Client,
68        header_name: &str,
69        header_value: &str,
70        endpoint: &str,
71    ) -> Result<Self> {
72        let mut headers = HeaderMap::new();
73        headers.insert(
74            header::CONTENT_TYPE,
75            HeaderValue::from_static("application/json"),
76        );
77        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
78        headers.insert(
79            header_name.parse::<HeaderName>()?,
80            header_value.parse::<HeaderValue>()?,
81        );
82        Ok(Self {
83            client,
84            headers,
85            endpoint: endpoint.to_owned(),
86        })
87    }
88
89    /// Send a non-streaming request and deserialize the response as JSON.
90    pub async fn send(&self, body: &impl Serialize) -> Result<Response> {
91        tracing::trace!("request: {}", serde_json::to_string(body)?);
92        let response = self
93            .client
94            .request(Method::POST, &self.endpoint)
95            .headers(self.headers.clone())
96            .json(body)
97            .send()
98            .await?;
99
100        let status = response.status();
101        let text = response.text().await?;
102        if !status.is_success() {
103            anyhow::bail!("API error ({status}): {text}");
104        }
105
106        serde_json::from_str(&text).map_err(Into::into)
107    }
108
109    /// Stream an SSE response (OpenAI-compatible format).
110    ///
111    /// Parses `data: ` prefixed lines, skips `[DONE]` sentinel, and
112    /// deserializes each chunk as [`StreamChunk`].
113    pub fn stream_sse(
114        &self,
115        body: &impl Serialize,
116    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
117        if let Ok(body) = serde_json::to_string(body) {
118            tracing::trace!("request: {}", body);
119        }
120        let request = self
121            .client
122            .request(Method::POST, &self.endpoint)
123            .headers(self.headers.clone())
124            .json(body);
125
126        try_stream! {
127            let response = request.send().await?;
128            let mut stream = response.bytes_stream();
129            while let Some(next) = stream.next().await {
130                let bytes = next?;
131                let text = String::from_utf8_lossy(&bytes);
132                tracing::trace!("chunk: {}", text);
133                for data in text.split("data: ").skip(1).filter(|s| !s.starts_with("[DONE]")) {
134                    let trimmed = data.trim();
135                    if trimmed.is_empty() {
136                        continue;
137                    }
138                    match serde_json::from_str::<StreamChunk>(trimmed) {
139                        Ok(chunk) => yield chunk,
140                        Err(e) => tracing::warn!("failed to parse chunk: {e}, data: {trimmed}"),
141                    }
142                }
143            }
144        }
145    }
146
147    /// Get the endpoint URL.
148    pub fn endpoint(&self) -> &str {
149        &self.endpoint
150    }
151
152    /// Get a reference to the headers.
153    pub fn headers(&self) -> &HeaderMap {
154        &self.headers
155    }
156}