Skip to main content

synwire_core/agents/
sampling.rs

1//! Sampling provider trait for tool-internal LLM access.
2//!
3//! Allows tools and middleware to request LLM completions without taking
4//! a hard dependency on a specific model or MCP transport. Zero LLM calls
5//! happen during indexing — sampling is only invoked when explicitly needed
6//! (e.g. community summary generation, hierarchical narrowing ranking).
7
8use crate::BoxFuture;
9
10// ---------------------------------------------------------------------------
11// Request / Response types
12// ---------------------------------------------------------------------------
13
/// A request to the LLM for a text completion.
///
/// Only [`prompt`](Self::prompt) is required; every other field is optional
/// and left as `None` by [`SamplingRequest::new`]. Use the `with_*` builder
/// methods to fill in the optional fields.
///
/// `PartialEq` is derived so whole requests can be compared directly (for
/// example with `assert_eq!`). `Eq` is intentionally not derived because
/// `temperature` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SamplingRequest {
    /// Optional system message.
    pub system: Option<String>,
    /// User message content.
    pub prompt: String,
    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,
    /// Temperature (0.0–1.0).
    ///
    /// The value is stored exactly as supplied; this type performs no range
    /// validation.
    pub temperature: Option<f32>,
}

impl SamplingRequest {
    /// Create a simple prompt-only request.
    ///
    /// `system`, `max_tokens` and `temperature` all start out as `None`.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingRequest;
    ///
    /// let req = SamplingRequest::new("Summarise this code.");
    /// assert_eq!(req.prompt, "Summarise this code.");
    /// assert!(req.system.is_none());
    /// ```
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            system: None,
            prompt: prompt.into(),
            max_tokens: None,
            temperature: None,
        }
    }

    /// Set the system message, replacing any previously set one.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingRequest;
    ///
    /// let req = SamplingRequest::new("Hello").with_system("You are a helpful assistant.");
    /// assert!(req.system.is_some());
    /// ```
    #[must_use]
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Set the maximum number of tokens to generate.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingRequest;
    ///
    /// let req = SamplingRequest::new("Hello").with_max_tokens(256);
    /// assert_eq!(req.max_tokens, Some(256));
    /// ```
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Set the sampling temperature.
    ///
    /// The value is stored as given; no clamping or validation happens here.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
}
79
/// A response from the LLM.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; use [`SamplingResponse::new`] instead.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SamplingResponse {
    /// Generated text.
    pub text: String,
    /// Stop reason (`"end_turn"`, `"max_tokens"`, etc.)
    pub stop_reason: String,
}

impl SamplingResponse {
    /// Create a response from the generated text and its stop reason.
    ///
    /// Both arguments accept anything convertible into a `String`.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingResponse;
    ///
    /// let resp = SamplingResponse::new("fn main() {}", "end_turn");
    /// assert_eq!(resp.text, "fn main() {}");
    /// assert_eq!(resp.stop_reason, "end_turn");
    /// ```
    #[must_use]
    pub fn new(text: impl Into<String>, stop_reason: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            stop_reason: stop_reason.into(),
        }
    }
}
89
90// ---------------------------------------------------------------------------
91// Error
92// ---------------------------------------------------------------------------
93
94/// Error from a sampling call.
95#[derive(Debug, thiserror::Error)]
96#[non_exhaustive]
97pub enum SamplingError {
98    /// Provider not configured.
99    #[error("no sampling provider configured")]
100    NotAvailable,
101    /// The model refused the request.
102    #[error("model refused: {0}")]
103    Refused(String),
104    /// The sampling call timed out.
105    #[error("sampling timed out")]
106    Timeout,
107    /// Any other error.
108    #[error("sampling error: {0}")]
109    Other(String),
110}
111
112// ---------------------------------------------------------------------------
113// Trait
114// ---------------------------------------------------------------------------
115
116/// Provides LLM sampling for tool-internal use.
117///
118/// Implementations include MCP `sampling/createMessage` delegation and
119/// direct model invocation.
120///
121/// The `sample` method returns a [`BoxFuture`] so that the trait remains
122/// object-safe and can be used as `dyn SamplingProvider`.
123pub trait SamplingProvider: Send + Sync {
124    /// Returns `true` if sampling is available (provider configured).
125    fn is_available(&self) -> bool;
126
127    /// Request a completion from the LLM.
128    ///
129    /// Returns a boxed future to preserve object safety.
130    fn sample(
131        &self,
132        request: SamplingRequest,
133    ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>>;
134}
135
136// ---------------------------------------------------------------------------
137// No-op implementation
138// ---------------------------------------------------------------------------
139
/// A sampling provider that is always unavailable.
///
/// Used as a default when no provider is configured. All calls return
/// [`SamplingError::NotAvailable`], enabling graceful degradation in callers.
///
/// The type is a stateless unit struct, so it derives `Debug`, `Clone`,
/// `Copy` and `Default`.
///
/// # Examples
///
/// ```
/// use synwire_core::agents::sampling::{NoopSamplingProvider, SamplingProvider};
///
/// let p = NoopSamplingProvider;
/// assert!(!p.is_available());
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopSamplingProvider;
154
155impl SamplingProvider for NoopSamplingProvider {
156    fn is_available(&self) -> bool {
157        false
158    }
159
160    fn sample(
161        &self,
162        _request: SamplingRequest,
163    ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>> {
164        Box::pin(async { Err(SamplingError::NotAvailable) })
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Tests
170// ---------------------------------------------------------------------------
171
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Every builder method must land its value in the matching field.
    #[test]
    fn sampling_request_builder() {
        let request = SamplingRequest::new("hello")
            .with_system("sys")
            .with_max_tokens(100)
            .with_temperature(0.7);

        assert_eq!(request.prompt, "hello");
        assert_eq!(request.system.as_deref(), Some("sys"));
        assert_eq!(request.max_tokens, Some(100));

        // Compare floats with a tolerance rather than exact equality.
        let temperature = request.temperature.unwrap_or(0.0);
        let delta = (temperature - 0.7).abs();
        assert!(delta < f32::EPSILON);
    }

    /// The no-op provider reports itself unavailable and fails every call.
    #[tokio::test]
    async fn noop_provider_returns_not_available() {
        let provider = NoopSamplingProvider;
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }

    /// The trait must remain usable through a `dyn` reference.
    #[tokio::test]
    async fn noop_provider_is_object_safe() {
        let provider: &dyn SamplingProvider = &NoopSamplingProvider;
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }
}