Skip to main content

walrus_model/
provider.rs

1//! Provider implementation backed by crabtalk-provider.
2//!
3//! Wraps `crabtalk_provider::Provider` behind wcore's `Model` trait with
4//! type conversion and retry logic.
5
6use crate::{config::ProviderDef, convert};
7use anyhow::Result;
8use async_stream::try_stream;
9use crabtalk_provider::Provider as CtProvider;
10use futures_core::Stream;
11use futures_util::StreamExt;
12use rand::Rng;
13use std::time::Duration;
14use wcore::model::{Model, Response, StreamChunk};
15
/// Unified LLM provider wrapping a crabtalk provider instance.
#[derive(Clone)]
pub struct Provider {
    // Underlying crabtalk provider that performs the actual API calls.
    inner: CtProvider,
    // Shared HTTP client, reused across requests.
    client: reqwest::Client,
    // Model identifier this provider was constructed for.
    model: String,
    // Retries after the initial attempt for transient failures.
    max_retries: u32,
    // Per-request deadline; a zero duration disables it (see `send_with_retry`).
    timeout: Duration,
}
25
26impl Provider {
27    /// Get the model name this provider was constructed for.
28    pub fn model_name(&self) -> &String {
29        &self.model
30    }
31}
32
/// Strip known endpoint suffixes so both bare origins and full paths work.
///
/// Trailing slashes are removed first, then the first matching endpoint
/// suffix (if any) is stripped; otherwise the trimmed URL is returned as-is.
fn normalize_base_url(url: &str) -> String {
    const ENDPOINT_SUFFIXES: [&str; 3] = ["/chat/completions", "/messages", "/embeddings"];
    let trimmed = url.trim_end_matches('/');
    ENDPOINT_SUFFIXES
        .iter()
        .find_map(|suffix| trimmed.strip_suffix(suffix))
        .unwrap_or(trimmed)
        .to_string()
}
43
44/// Construct a `Provider` from a provider definition and model name.
45pub fn build_provider(def: &ProviderDef, model: &str, client: reqwest::Client) -> Result<Provider> {
46    let mut config = def.clone();
47    config.kind = config.effective_kind();
48    let mut inner = CtProvider::from(&config);
49
50    // Apply walrus-specific base_url normalization (strip endpoint suffixes).
51    if let CtProvider::OpenAiCompat {
52        ref mut base_url, ..
53    } = inner
54    {
55        *base_url = normalize_base_url(base_url);
56    }
57
58    Ok(Provider {
59        inner,
60        client,
61        model: model.to_owned(),
62        max_retries: def.max_retries.unwrap_or(2),
63        timeout: Duration::from_secs(def.timeout.unwrap_or(30)),
64    })
65}
66
67impl Model for Provider {
68    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
69        let mut ct_req = convert::to_ct_request(request);
70        ct_req.stream = Some(false);
71        send_with_retry(
72            &self.inner,
73            &self.client,
74            &ct_req,
75            self.max_retries,
76            self.timeout,
77        )
78        .await
79    }
80
81    fn stream(
82        &self,
83        request: wcore::model::Request,
84    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
85        let inner = self.inner.clone();
86        let client = self.client.clone();
87        let timeout = self.timeout;
88        try_stream! {
89            let mut ct_req = convert::to_ct_request(&request);
90            ct_req.stream = Some(true);
91
92            let boxed = tokio::time::timeout(timeout, inner.chat_completion_stream(&client, &ct_req))
93                .await
94                .map_err(|_| anyhow::anyhow!("stream connection timed out"))?
95                .map_err(|e| anyhow::anyhow!("{e}"))?;
96
97            let mut stream = std::pin::pin!(boxed);
98            while let Some(chunk) = stream.next().await {
99                let ct_chunk = chunk.map_err(|e| anyhow::anyhow!("{e}"))?;
100                yield convert::from_ct_chunk(ct_chunk);
101            }
102        }
103    }
104
105    fn context_limit(&self, model: &str) -> usize {
106        wcore::model::default_context_limit(model)
107    }
108
109    fn active_model(&self) -> String {
110        self.model.clone()
111    }
112}
113
114/// Send a non-streaming request with exponential backoff retry on transient errors.
115async fn send_with_retry(
116    provider: &CtProvider,
117    client: &reqwest::Client,
118    request: &crabtalk_core::ChatCompletionRequest,
119    max_retries: u32,
120    timeout: Duration,
121) -> Result<Response> {
122    let mut backoff = Duration::from_millis(100);
123    let mut last_err = None;
124
125    for _ in 0..=max_retries {
126        let result = if timeout.is_zero() {
127            provider.chat_completion(client, request).await
128        } else {
129            tokio::time::timeout(timeout, provider.chat_completion(client, request))
130                .await
131                .map_err(|_| crabtalk_core::Error::Timeout)?
132        };
133
134        match result {
135            Ok(resp) => return Ok(convert::from_ct_response(resp)),
136            Err(e) if e.is_transient() => {
137                last_err = Some(e);
138                let jitter = jittered(backoff);
139                tokio::time::sleep(jitter).await;
140                backoff *= 2;
141            }
142            Err(e) => return Err(anyhow::anyhow!("{e}")),
143        }
144    }
145
146    Err(anyhow::anyhow!("{}", last_err.unwrap()))
147}
148
149/// Full jitter: random duration in [backoff/2, backoff].
150fn jittered(backoff: Duration) -> Duration {
151    let lo = backoff.as_millis() as u64 / 2;
152    let hi = backoff.as_millis() as u64;
153    if lo >= hi {
154        return backoff;
155    }
156    Duration::from_millis(rand::rng().random_range(lo..=hi))
157}