skg_op_single_shot/
lib.rs1#![deny(missing_docs)]
2use async_trait::async_trait;
10use layer0::content::Content;
11use layer0::context::{Message, Role};
12use layer0::duration::DurationMs;
13use layer0::error::OperatorError;
14use layer0::operator::{ExitReason, Operator, OperatorInput, OperatorMetadata, OperatorOutput};
15use skg_turn::infer::InferRequest;
16use skg_turn::provider::Provider;
17use rust_decimal::Decimal;
18use std::time::Instant;
19
/// Static configuration for a [`SingleShotOperator`].
///
/// Per-invocation settings carried on the operator input (a model override
/// and a system-prompt addendum) take precedence over, or extend, these values.
#[derive(Debug, Clone)]
pub struct SingleShotConfig {
    /// Base system prompt for every request.
    ///
    /// When this is empty and the invocation supplies no addendum, no system
    /// prompt is attached to the request.
    pub system_prompt: String,
    /// Model used when the invocation does not name one.
    ///
    /// An empty string means "no default"; the request is then sent without
    /// an explicit model.
    pub default_model: String,
    /// Maximum number of output tokens requested from the provider.
    pub default_max_tokens: u32,
}
29
30impl Default for SingleShotConfig {
31 fn default() -> Self {
32 Self {
33 system_prompt: String::new(),
34 default_model: String::new(),
35 default_max_tokens: 4096,
36 }
37 }
38}
39
40pub struct SingleShotOperator<P: Provider> {
46 provider: P,
47 config: SingleShotConfig,
48}
49
50impl<P: Provider> SingleShotOperator<P> {
51 pub fn new(provider: P, config: SingleShotConfig) -> Self {
53 Self { provider, config }
54 }
55
56 fn resolve_model(&self, input: &OperatorInput) -> Option<String> {
58 input
59 .config
60 .as_ref()
61 .and_then(|c| c.model.clone())
62 .or_else(|| {
63 if self.config.default_model.is_empty() {
64 None
65 } else {
66 Some(self.config.default_model.clone())
67 }
68 })
69 }
70
71 fn resolve_system(&self, input: &OperatorInput) -> String {
73 match input
74 .config
75 .as_ref()
76 .and_then(|c| c.system_addendum.as_ref())
77 {
78 Some(addendum) => format!("{}\n{}", self.config.system_prompt, addendum),
79 None => self.config.system_prompt.clone(),
80 }
81 }
82}
83
84#[async_trait]
85impl<P: Provider + 'static> Operator for SingleShotOperator<P> {
86 #[tracing::instrument(skip_all, fields(trigger = ?input.trigger))]
87 async fn execute(&self, input: OperatorInput) -> Result<OperatorOutput, OperatorError> {
88 let start = Instant::now();
89 tracing::info!("single-shot executing");
90
91 let model = self.resolve_model(&input);
92 let system = self.resolve_system(&input);
93 let max_tokens = self.config.default_max_tokens;
94
95 let user_msg = Message::new(Role::User, input.message.clone());
97
98 let mut request = InferRequest::new(vec![user_msg]);
100 if let Some(m) = model {
101 request = request.with_model(m);
102 }
103 if !system.is_empty() {
104 request = request.with_system(system);
105 }
106 request = request
107 .with_max_tokens(max_tokens)
108 .with_extra(input.metadata.clone());
109
110 let response = self.provider.infer(request).await.map_err(|e| {
112 if e.is_retryable() {
113 OperatorError::Retryable(e.to_string())
114 } else {
115 OperatorError::Model(e.to_string())
116 }
117 })?;
118
119 let duration = DurationMs::from(start.elapsed());
120
121 let mut metadata = OperatorMetadata::default();
123 metadata.tokens_in = response.usage.input_tokens;
124 metadata.tokens_out = response.usage.output_tokens;
125 metadata.cost = response.cost.unwrap_or(Decimal::ZERO);
126 metadata.turns_used = 1;
127 metadata.sub_dispatches = vec![];
128 metadata.duration = duration;
129
130 let message: Content = response.content;
132
133 let mut output = OperatorOutput::new(message, ExitReason::Complete);
135 output.metadata = metadata;
136 output.effects = vec![];
137
138 Ok(output)
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use skg_turn::infer::InferResponse;
    use skg_turn::test_utils::{TestProvider, error_provider_rate_limited, make_text_response};
    use skg_turn::types::{StopReason, TokenUsage};
    use std::sync::Arc;

    /// A user-triggered input whose message is `text`.
    fn simple_input(text: &str) -> OperatorInput {
        let trigger = layer0::operator::TriggerType::User;
        OperatorInput::new(Content::text(text), trigger)
    }

    /// An operator over `provider` using the default configuration.
    fn make_op(provider: TestProvider) -> SingleShotOperator<TestProvider> {
        let config = SingleShotConfig::default();
        SingleShotOperator::new(provider, config)
    }

    #[tokio::test]
    async fn single_shot_returns_completion() {
        let op = make_op(TestProvider::with_responses(vec![make_text_response("Hello!")]));

        let output = op.execute(simple_input("Hi")).await.unwrap();

        assert_eq!(output.exit_reason, ExitReason::Complete);
        assert_eq!(output.message.as_text().unwrap(), "Hello!");
    }

    #[tokio::test]
    async fn single_shot_always_one_turn() {
        let op = make_op(TestProvider::with_responses(vec![make_text_response("Response")]));

        let output = op.execute(simple_input("Query")).await.unwrap();

        assert_eq!(output.metadata.turns_used, 1);
    }

    #[tokio::test]
    async fn single_shot_no_tools_in_request() {
        let op = make_op(TestProvider::with_responses(vec![make_text_response("Done")]));

        op.execute(simple_input("Test")).await.unwrap();

        let sent = op.provider.requests();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].tools.is_empty(), "single-shot must send no tools");
    }

    #[tokio::test]
    async fn single_shot_rate_limit_maps_to_retryable() {
        let op = SingleShotOperator::new(error_provider_rate_limited(), SingleShotConfig::default());

        let outcome = op.execute(simple_input("test")).await;

        assert!(matches!(outcome, Err(OperatorError::Retryable(_))));
    }

    #[tokio::test]
    async fn single_shot_cost_passed_through() {
        let cost = Decimal::new(42, 4);
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let scripted = InferResponse {
            content: Content::text("result"),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage,
            model: "mock".into(),
            cost: Some(cost),
            truncated: None,
        };
        let op = make_op(TestProvider::with_responses(vec![scripted]));

        let output = op.execute(simple_input("test")).await.unwrap();

        let meta = &output.metadata;
        assert_eq!(meta.cost, cost);
        assert_eq!(meta.tokens_in, 100);
        assert_eq!(meta.tokens_out, 50);
    }

    #[tokio::test]
    async fn single_shot_as_arc_dyn_operator() {
        let concrete = SingleShotOperator::new(
            TestProvider::with_responses(vec![make_text_response("Hello!")]),
            SingleShotConfig::default(),
        );
        let op: Arc<dyn Operator> = Arc::new(concrete);

        let output = Operator::execute(op.as_ref(), simple_input("Hi"))
            .await
            .unwrap();

        assert_eq!(output.exit_reason, ExitReason::Complete);
    }
}