Skip to main content

skg_op_single_shot/
lib.rs

1#![deny(missing_docs)]
2//! Single-shot operator — one model call, no tools, return immediately.
3//!
4//! Implements `layer0::Operator` for the simplest case: send a single
5//! prompt to a model and return the result. No tool use, no ReAct loop,
6//! no hooks, no state reader. Used for classification, summarization,
7//! extraction, and other single-inference tasks.
8
9use async_trait::async_trait;
10use layer0::content::Content;
11use layer0::context::{Message, Role};
12use layer0::duration::DurationMs;
13use layer0::error::OperatorError;
14use layer0::operator::{ExitReason, Operator, OperatorInput, OperatorMetadata, OperatorOutput};
15use skg_turn::infer::InferRequest;
16use skg_turn::provider::Provider;
17use rust_decimal::Decimal;
18use std::time::Instant;
19
/// Static configuration for a `SingleShotOperator` instance.
///
/// These are the operator-wide defaults; a request may override the model
/// and extend the system prompt through its own per-request config.
#[derive(Debug, Clone)]
pub struct SingleShotConfig {
    /// Base system prompt. When this resolves to an empty string, no system
    /// prompt is attached to the inference request.
    pub system_prompt: String,
    /// Default model identifier. An empty string means "no default": the
    /// request carries a model only if the caller supplies one.
    pub default_model: String,
    /// Default max tokens per response.
    pub default_max_tokens: u32,
}
29
30impl Default for SingleShotConfig {
31    fn default() -> Self {
32        Self {
33            system_prompt: String::new(),
34            default_model: String::new(),
35            default_max_tokens: 4096,
36        }
37    }
38}
39
40/// A single-shot Operator: one model call, no tools, return immediately.
41///
42/// Generic over `P: Provider` (not object-safe). The object-safe boundary
43/// is `layer0::Operator`, which `SingleShotOperator<P>` implements via
44/// `#[async_trait]`.
45pub struct SingleShotOperator<P: Provider> {
46    provider: P,
47    config: SingleShotConfig,
48}
49
50impl<P: Provider> SingleShotOperator<P> {
51    /// Create a new SingleShotOperator with a provider and configuration.
52    pub fn new(provider: P, config: SingleShotConfig) -> Self {
53        Self { provider, config }
54    }
55
56    /// Resolve model and max_tokens from per-request overrides or defaults.
57    fn resolve_model(&self, input: &OperatorInput) -> Option<String> {
58        input
59            .config
60            .as_ref()
61            .and_then(|c| c.model.clone())
62            .or_else(|| {
63                if self.config.default_model.is_empty() {
64                    None
65                } else {
66                    Some(self.config.default_model.clone())
67                }
68            })
69    }
70
71    /// Resolve the system prompt, appending any per-request addendum.
72    fn resolve_system(&self, input: &OperatorInput) -> String {
73        match input
74            .config
75            .as_ref()
76            .and_then(|c| c.system_addendum.as_ref())
77        {
78            Some(addendum) => format!("{}\n{}", self.config.system_prompt, addendum),
79            None => self.config.system_prompt.clone(),
80        }
81    }
82}
83
84#[async_trait]
85impl<P: Provider + 'static> Operator for SingleShotOperator<P> {
86    #[tracing::instrument(skip_all, fields(trigger = ?input.trigger))]
87    async fn execute(&self, input: OperatorInput) -> Result<OperatorOutput, OperatorError> {
88        let start = Instant::now();
89        tracing::info!("single-shot executing");
90
91        let model = self.resolve_model(&input);
92        let system = self.resolve_system(&input);
93        let max_tokens = self.config.default_max_tokens;
94
95        // Build single user message from trigger content
96        let user_msg = Message::new(Role::User, input.message.clone());
97
98        // Build inference request
99        let mut request = InferRequest::new(vec![user_msg]);
100        if let Some(m) = model {
101            request = request.with_model(m);
102        }
103        if !system.is_empty() {
104            request = request.with_system(system);
105        }
106        request = request
107            .with_max_tokens(max_tokens)
108            .with_extra(input.metadata.clone());
109
110        // Single model call
111        let response = self.provider.infer(request).await.map_err(|e| {
112            if e.is_retryable() {
113                OperatorError::Retryable(e.to_string())
114            } else {
115                OperatorError::Model(e.to_string())
116            }
117        })?;
118
119        let duration = DurationMs::from(start.elapsed());
120
121        // Build metadata
122        let mut metadata = OperatorMetadata::default();
123        metadata.tokens_in = response.usage.input_tokens;
124        metadata.tokens_out = response.usage.output_tokens;
125        metadata.cost = response.cost.unwrap_or(Decimal::ZERO);
126        metadata.turns_used = 1;
127        metadata.sub_dispatches = vec![];
128        metadata.duration = duration;
129
130        // Response content is already layer0 Content
131        let message: Content = response.content;
132
133        // Always ExitReason::Complete for single-shot
134        let mut output = OperatorOutput::new(message, ExitReason::Complete);
135        output.metadata = metadata;
136        output.effects = vec![];
137
138        Ok(output)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use skg_turn::infer::InferResponse;
    use skg_turn::test_utils::{TestProvider, error_provider_rate_limited, make_text_response};
    use skg_turn::types::{StopReason, TokenUsage};
    use std::sync::Arc;

    // -- Helpers --

    /// Wrap plain text in a user-triggered operator input.
    fn simple_input(text: &str) -> OperatorInput {
        let content = Content::text(text);
        OperatorInput::new(content, layer0::operator::TriggerType::User)
    }

    /// Build an operator with the default configuration over a test provider.
    fn make_op(provider: TestProvider) -> SingleShotOperator<TestProvider> {
        let config = SingleShotConfig::default();
        SingleShotOperator::new(provider, config)
    }

    // -- Tests --

    /// The provider's text reply comes back as a completed output.
    #[tokio::test]
    async fn single_shot_returns_completion() {
        let operator = make_op(TestProvider::with_responses(vec![make_text_response(
            "Hello!",
        )]));

        let result = operator.execute(simple_input("Hi")).await.unwrap();

        assert_eq!(result.message.as_text().unwrap(), "Hello!");
        assert_eq!(result.exit_reason, ExitReason::Complete);
    }

    /// Metadata reports exactly one turn.
    #[tokio::test]
    async fn single_shot_always_one_turn() {
        let scripted = vec![make_text_response("Response")];
        let operator = make_op(TestProvider::with_responses(scripted));

        let result = operator.execute(simple_input("Query")).await.unwrap();

        assert_eq!(result.metadata.turns_used, 1);
    }

    /// Exactly one request reaches the provider, and it carries no tools.
    #[tokio::test]
    async fn single_shot_no_tools_in_request() {
        let scripted = vec![make_text_response("Done")];
        let operator = make_op(TestProvider::with_responses(scripted));

        operator.execute(simple_input("Test")).await.unwrap();

        let recorded = operator.provider.requests();
        assert_eq!(recorded.len(), 1);
        assert!(
            recorded[0].tools.is_empty(),
            "single-shot must send no tools"
        );
    }

    /// A rate-limited provider error maps to `OperatorError::Retryable`.
    #[tokio::test]
    async fn single_shot_rate_limit_maps_to_retryable() {
        let operator =
            SingleShotOperator::new(error_provider_rate_limited(), SingleShotConfig::default());

        let outcome = operator.execute(simple_input("test")).await;

        assert!(matches!(outcome, Err(OperatorError::Retryable(_))));
    }

    /// Cost and token counts from the response are copied into the metadata.
    #[tokio::test]
    async fn single_shot_cost_passed_through() {
        let expected_cost = Decimal::new(42, 4); // $0.0042
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let scripted = InferResponse {
            content: Content::text("result"),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage,
            model: "mock".into(),
            cost: Some(expected_cost),
            truncated: None,
        };
        let operator = make_op(TestProvider::with_responses(vec![scripted]));

        let result = operator.execute(simple_input("test")).await.unwrap();

        assert_eq!(result.metadata.cost, expected_cost);
        assert_eq!(result.metadata.tokens_in, 100);
        assert_eq!(result.metadata.tokens_out, 50);
    }

    /// The operator can be used behind `Arc<dyn Operator>`.
    #[tokio::test]
    async fn single_shot_as_arc_dyn_operator() {
        let scripted = vec![make_text_response("Hello!")];
        let concrete =
            SingleShotOperator::new(TestProvider::with_responses(scripted), SingleShotConfig::default());
        let erased: Arc<dyn Operator> = Arc::new(concrete);

        let result = Operator::execute(&*erased, simple_input("Hi"))
            .await
            .unwrap();

        assert_eq!(result.exit_reason, ExitReason::Complete);
    }
}