Skip to main content

walrus_model/
http.rs

//! Shared HTTP transport for OpenAI-compatible LLM providers.
//!
//! `HttpProvider` wraps a `reqwest::Client` with pre-configured headers and
//! an endpoint URL. It provides `send()` for non-streaming requests and
//! `stream_sse()` for Server-Sent Events streaming. Used by DeepSeek, OpenAI,
//! and Mistral — Claude uses its own transport (different SSE format).
7
8use anyhow::Result;
9use async_stream::try_stream;
10use futures_core::Stream;
11use futures_util::StreamExt;
12use reqwest::{
13    Client, Method,
14    header::{self, HeaderMap, HeaderName, HeaderValue},
15};
16use serde::Serialize;
17use wcore::model::{Response, StreamChunk};
18
19/// Shared HTTP transport for OpenAI-compatible providers.
20///
21/// Holds a `reqwest::Client`, pre-built headers (auth + content-type),
22/// and the target endpoint URL.
23#[derive(Clone)]
24pub struct HttpProvider {
25    client: Client,
26    headers: HeaderMap,
27    endpoint: String,
28}
29
30impl HttpProvider {
31    /// Create a provider with Bearer token authentication.
32    pub fn bearer(client: Client, key: &str, endpoint: &str) -> Result<Self> {
33        let mut headers = HeaderMap::new();
34        headers.insert(
35            header::CONTENT_TYPE,
36            HeaderValue::from_static("application/json"),
37        );
38        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
39        headers.insert(header::AUTHORIZATION, format!("Bearer {key}").parse()?);
40        Ok(Self {
41            client,
42            headers,
43            endpoint: endpoint.to_owned(),
44        })
45    }
46
47    /// Create a provider without authentication (e.g. Ollama).
48    pub fn no_auth(client: Client, endpoint: &str) -> Self {
49        let mut headers = HeaderMap::new();
50        headers.insert(
51            header::CONTENT_TYPE,
52            HeaderValue::from_static("application/json"),
53        );
54        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
55        Self {
56            client,
57            headers,
58            endpoint: endpoint.to_owned(),
59        }
60    }
61
62    /// Create a provider with a custom header for authentication.
63    ///
64    /// Used by providers that don't use Bearer tokens (e.g. Anthropic
65    /// uses `x-api-key`).
66    pub fn custom_header(
67        client: Client,
68        header_name: &str,
69        header_value: &str,
70        endpoint: &str,
71    ) -> Result<Self> {
72        let mut headers = HeaderMap::new();
73        headers.insert(
74            header::CONTENT_TYPE,
75            HeaderValue::from_static("application/json"),
76        );
77        headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
78        headers.insert(
79            header_name.parse::<HeaderName>()?,
80            header_value.parse::<HeaderValue>()?,
81        );
82        Ok(Self {
83            client,
84            headers,
85            endpoint: endpoint.to_owned(),
86        })
87    }
88
89    /// Send a non-streaming request and deserialize the response as JSON.
90    pub async fn send(&self, body: &impl Serialize) -> Result<Response> {
91        tracing::trace!("request: {}", serde_json::to_string(body)?);
92        let text = self
93            .client
94            .request(Method::POST, &self.endpoint)
95            .headers(self.headers.clone())
96            .json(body)
97            .send()
98            .await?
99            .text()
100            .await?;
101
102        serde_json::from_str(&text).map_err(Into::into)
103    }
104
105    /// Stream an SSE response (OpenAI-compatible format).
106    ///
107    /// Parses `data: ` prefixed lines, skips `[DONE]` sentinel, and
108    /// deserializes each chunk as [`StreamChunk`].
109    pub fn stream_sse(
110        &self,
111        body: &impl Serialize,
112    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
113        if let Ok(body) = serde_json::to_string(body) {
114            tracing::trace!("request: {}", body);
115        }
116        let request = self
117            .client
118            .request(Method::POST, &self.endpoint)
119            .headers(self.headers.clone())
120            .json(body);
121
122        try_stream! {
123            let response = request.send().await?;
124            let mut stream = response.bytes_stream();
125            while let Some(next) = stream.next().await {
126                let bytes = next?;
127                let text = String::from_utf8_lossy(&bytes);
128                tracing::trace!("chunk: {}", text);
129                for data in text.split("data: ").skip(1).filter(|s| !s.starts_with("[DONE]")) {
130                    let trimmed = data.trim();
131                    if trimmed.is_empty() {
132                        continue;
133                    }
134                    match serde_json::from_str::<StreamChunk>(trimmed) {
135                        Ok(chunk) => yield chunk,
136                        Err(e) => tracing::warn!("failed to parse chunk: {e}, data: {trimmed}"),
137                    }
138                }
139            }
140        }
141    }
142
143    /// Get the endpoint URL.
144    pub fn endpoint(&self) -> &str {
145        &self.endpoint
146    }
147
148    /// Get a reference to the headers.
149    pub fn headers(&self) -> &HeaderMap {
150        &self.headers
151    }
152}