synwire_core/agents/sampling.rs
1//! Sampling provider trait for tool-internal LLM access.
2//!
3//! Allows tools and middleware to request LLM completions without taking
4//! a hard dependency on a specific model or MCP transport. Zero LLM calls
5//! happen during indexing — sampling is only invoked when explicitly needed
6//! (e.g. community summary generation, hierarchical narrowing ranking).
7
8use crate::BoxFuture;
9
10// ---------------------------------------------------------------------------
11// Request / Response types
12// ---------------------------------------------------------------------------
13
/// Parameters for a single text-completion call to the LLM.
///
/// Only [`prompt`](Self::prompt) is mandatory. The remaining fields are
/// optional and are all `None` on a request built with
/// [`SamplingRequest::new`]; use the `with_*` builders to fill them in.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SamplingRequest {
    /// System message, if one was supplied.
    pub system: Option<String>,
    /// Content of the user message.
    pub prompt: String,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature (0.0–1.0).
    pub temperature: Option<f32>,
}

impl SamplingRequest {
    /// Build a request that carries nothing but a prompt.
    ///
    /// `system`, `max_tokens` and `temperature` are left as `None`.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingRequest;
    ///
    /// let req = SamplingRequest::new("Summarise this code.");
    /// assert_eq!(req.prompt, "Summarise this code.");
    /// assert!(req.system.is_none());
    /// ```
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = prompt.into();
        Self {
            prompt,
            system: None,
            max_tokens: None,
            temperature: None,
        }
    }

    /// Attach a system message, replacing any previous one.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingRequest;
    ///
    /// let req = SamplingRequest::new("Hello").with_system("You are a helpful assistant.");
    /// assert!(req.system.is_some());
    /// ```
    #[must_use]
    pub fn with_system(self, system: impl Into<String>) -> Self {
        // Every field other than `system` is carried over unchanged.
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Cap the number of tokens the model may generate.
    #[must_use]
    pub const fn with_max_tokens(self, max_tokens: u32) -> Self {
        let mut request = self;
        request.max_tokens = Some(max_tokens);
        request
    }

    /// Choose the sampling temperature.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub const fn with_temperature(self, temperature: f32) -> Self {
        let mut request = self;
        request.temperature = Some(temperature);
        request
    }
}
79
/// A response from the LLM.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot
/// create it with a struct literal. Use [`SamplingResponse::new`] instead.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SamplingResponse {
    /// Generated text.
    pub text: String,
    /// Stop reason (`"end_turn"`, `"max_tokens"`, etc.)
    pub stop_reason: String,
}

impl SamplingResponse {
    /// Create a response from the generated text and its stop reason.
    ///
    /// Both arguments are stored as given.
    ///
    /// # Examples
    ///
    /// ```
    /// use synwire_core::agents::sampling::SamplingResponse;
    ///
    /// let resp = SamplingResponse::new("fn main() {}", "end_turn");
    /// assert_eq!(resp.text, "fn main() {}");
    /// assert_eq!(resp.stop_reason, "end_turn");
    /// ```
    #[must_use]
    pub fn new(text: impl Into<String>, stop_reason: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            stop_reason: stop_reason.into(),
        }
    }
}
89
90// ---------------------------------------------------------------------------
91// Error
92// ---------------------------------------------------------------------------
93
/// Error from a sampling call.
///
/// This is the error half of the `Result` produced by
/// [`SamplingProvider::sample`]. The enum is `#[non_exhaustive]`, so a
/// `match` written outside this crate needs a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SamplingError {
    /// Provider not configured.
    ///
    /// [`NoopSamplingProvider`] returns this from every `sample` call.
    #[error("no sampling provider configured")]
    NotAvailable,
    /// The model refused the request.
    ///
    /// The payload is rendered after `model refused: ` in the display
    /// message.
    #[error("model refused: {0}")]
    Refused(String),
    /// The sampling call timed out.
    #[error("sampling timed out")]
    Timeout,
    /// Any other error.
    ///
    /// The payload is rendered after `sampling error: ` in the display
    /// message.
    #[error("sampling error: {0}")]
    Other(String),
}
111
112// ---------------------------------------------------------------------------
113// Trait
114// ---------------------------------------------------------------------------
115
/// Provides LLM sampling for tool-internal use.
///
/// Implementations include MCP `sampling/createMessage` delegation and
/// direct model invocation.
///
/// The `sample` method returns a [`BoxFuture`] so that the trait remains
/// object-safe and can be used as `dyn SamplingProvider`.
///
/// The `Send + Sync` supertraits mean every implementor must itself be
/// `Send` and `Sync`.
pub trait SamplingProvider: Send + Sync {
    /// Returns `true` if sampling is available (provider configured).
    ///
    /// NOTE(review): presumably callers check this before calling
    /// [`sample`](Self::sample) so they can skip optional LLM work — confirm
    /// against the call sites.
    fn is_available(&self) -> bool;

    /// Request a completion from the LLM.
    ///
    /// Returns a boxed future to preserve object safety. The request is
    /// taken by value, and the returned future may borrow from `self`
    /// (its lifetime is tied to `&self`).
    ///
    /// # Errors
    ///
    /// The future resolves to a [`SamplingError`] when the call fails.
    fn sample(
        &self,
        request: SamplingRequest,
    ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>>;
}
135
136// ---------------------------------------------------------------------------
137// No-op implementation
138// ---------------------------------------------------------------------------
139
/// A sampling provider that is always unavailable.
///
/// Used as a default when no provider is configured. All calls return
/// [`SamplingError::NotAvailable`], enabling graceful degradation in callers.
///
/// The type is a field-less unit struct, so it derives the cheap standard
/// traits: `Debug` for diagnostics, `Clone`/`Copy` for free duplication,
/// and `Default` so it can be used wherever a default provider is needed.
///
/// # Examples
///
/// ```
/// use synwire_core::agents::sampling::{NoopSamplingProvider, SamplingProvider};
///
/// let p = NoopSamplingProvider;
/// assert!(!p.is_available());
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopSamplingProvider;
154
155impl SamplingProvider for NoopSamplingProvider {
156 fn is_available(&self) -> bool {
157 false
158 }
159
160 fn sample(
161 &self,
162 _request: SamplingRequest,
163 ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>> {
164 Box::pin(async { Err(SamplingError::NotAvailable) })
165 }
166}
167
168// ---------------------------------------------------------------------------
169// Tests
170// ---------------------------------------------------------------------------
171
#[cfg(test)]
// Tests may unwrap/expect/panic freely: a failure should abort the test.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// The no-op provider reports itself unavailable and fails every call.
    #[tokio::test]
    async fn noop_provider_returns_not_available() {
        let provider = NoopSamplingProvider;
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }

    /// Each builder method fills in exactly the field it is named after.
    #[test]
    fn sampling_request_builder() {
        let built = SamplingRequest::new("hello")
            .with_system("sys")
            .with_max_tokens(100)
            .with_temperature(0.7);

        assert_eq!(built.prompt, "hello");
        assert_eq!(built.system, Some(String::from("sys")));
        assert_eq!(built.max_tokens, Some(100));

        let temperature = built
            .temperature
            .expect("temperature was set by the builder");
        assert!((temperature - 0.7).abs() < f32::EPSILON);
    }

    /// The trait can be used through a `dyn SamplingProvider` trait object.
    #[tokio::test]
    async fn noop_provider_is_object_safe() {
        let provider: Box<dyn SamplingProvider> = Box::new(NoopSamplingProvider);
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }
}