Skip to main content

walrus_model/openai/
provider.rs

//! Model trait implementation for the OpenAI-compatible provider.
2
3use super::OpenAI;
4use anyhow::Result;
5use async_stream::try_stream;
6use compact_str::CompactString;
7use futures_core::Stream;
8use futures_util::StreamExt;
9use reqwest::Method;
10use wcore::model::{Model, Response, StreamChunk};
11
12impl Model for OpenAI {
13    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
14        let body = crate::request::Request::from(request.clone());
15        tracing::trace!("request: {}", serde_json::to_string(&body)?);
16        let text = self
17            .client
18            .request(Method::POST, &self.endpoint)
19            .headers(self.headers.clone())
20            .json(&body)
21            .send()
22            .await?
23            .text()
24            .await?;
25
26        serde_json::from_str(&text).map_err(Into::into)
27    }
28
29    fn stream(
30        &self,
31        request: wcore::model::Request,
32    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
33        let usage = request.usage;
34        let body = crate::request::Request::from(request).stream(usage);
35        if let Ok(body) = serde_json::to_string(&body) {
36            tracing::trace!("request: {}", body);
37        }
38        let request = self
39            .client
40            .request(Method::POST, &self.endpoint)
41            .headers(self.headers.clone())
42            .json(&body);
43
44        try_stream! {
45            let response = request.send().await?;
46            let mut stream = response.bytes_stream();
47            while let Some(Ok(bytes)) = stream.next().await {
48                let text = String::from_utf8_lossy(&bytes).into_owned();
49                tracing::trace!("chunk: {}", text);
50                for data in text.split("data: ").skip(1).filter(|s| !s.starts_with("[DONE]")) {
51                    let trimmed = data.trim();
52                    if trimmed.is_empty() {
53                        continue;
54                    }
55                    match serde_json::from_str::<StreamChunk>(trimmed) {
56                        Ok(chunk) => yield chunk,
57                        Err(e) => tracing::warn!("failed to parse chunk: {e}, data: {trimmed}"),
58                    }
59                }
60            }
61        }
62    }
63
64    fn active_model(&self) -> CompactString {
65        CompactString::from("gpt-4o")
66    }
67}