smol_workflow_engine/agent_providers/
pi.rs1use super::common::*;
2use super::types::*;
3use anyhow::bail;
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
/// Launch configuration for [`PiAgentProvider`].
///
/// Every field is optional: with `Default`, `run_pi` launches `pi` with no
/// caller-supplied arguments, environment, working directory or timeout.
#[derive(Debug, Clone, Default)]
pub struct PiAgentProviderOptions {
    /// Executable to launch; `run_pi` falls back to `"pi"` when this is `None`.
    pub command: Option<String>,
    /// Arguments placed first on the command line, ahead of `args`.
    pub subcommand: Vec<String>,
    /// Arguments placed after `subcommand` and before the flags `run_pi` adds itself.
    pub args: Vec<String>,
    /// Working directory used when the run context does not provide one.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to `run_command` for the child process.
    pub env: HashMap<String, String>,
    /// Timeout in milliseconds handed to `run_command`, if any.
    pub timeout_ms: Option<u64>,
}
18
19#[derive(Debug, Clone, Default)]
20pub struct PiAgentProvider {
21 options: PiAgentProviderOptions,
22}
23impl PiAgentProvider {
24 pub fn new(options: PiAgentProviderOptions) -> Self {
25 Self { options }
26 }
27}
28
29#[async_trait::async_trait]
30impl AgentProvider for PiAgentProvider {
31 fn name(&self) -> &str {
32 "pi"
33 }
34 fn schema_mode(&self) -> AgentProviderSchemaMode {
35 AgentProviderSchemaMode::Builtin
36 }
37 fn usage_mode(&self) -> AgentProviderUsageMode {
38 AgentProviderUsageMode::Builtin
39 }
40 async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
41 run_pi(input, &self.options).await
42 }
43}
44
45async fn run_pi(
46 input: AgentProviderRunInput,
47 options: &PiAgentProviderOptions,
48) -> anyhow::Result<AgentProviderResult> {
49 let command = options.command.as_deref().unwrap_or("pi");
50 let has_schema = option_schema(&input.options).is_some();
51 let prompt = if has_schema {
52 with_structured_output_tool_instruction(&input.prompt)
53 } else {
54 input.prompt.clone()
55 };
56 let temp = temp_dir("smol-wf-pi-")?;
57 let extension_path = has_schema.then(|| temp.path().join("structured-output-extension.ts"));
58 let prompt_path = temp.path().join("prompt.md");
59
60 if let Some(path) = &extension_path {
61 fs::write(
62 path,
63 build_structured_output_extension(option_schema(&input.options).unwrap()),
64 )?;
65 }
66 fs::write(&prompt_path, &prompt)?;
67
68 let prompt_arg = format!("@{}", prompt_path.to_string_lossy());
69
70 let mut args = Vec::new();
71 args.extend(options.subcommand.clone());
72 args.extend(options.args.clone());
73 if let Some(path) = &extension_path {
74 args.extend(["--extension".into(), path.to_string_lossy().into_owned()]);
75 }
76 args.extend(["--print".into(), "--mode".into(), "json".into()]);
77 if let Some(model) = option_str(&input.options, "model") {
78 args.extend(["--model".into(), model]);
79 }
80 if let Some(thinking) = option_str(&input.options, "thinking") {
81 args.extend(["--thinking".into(), thinking]);
82 }
83 args.push(prompt_arg);
84
85 let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
86 let (stdout, stderr) = run_command(
87 "Pi",
88 command,
89 &args,
90 None,
91 cwd,
92 &options.env,
93 options.timeout_ms,
94 )
95 .await?;
96 let events = parse_json_lines(&stdout);
97 let output = if has_schema {
98 extract_structured_tool_output(&events)?
99 } else {
100 let candidate = extract_output(&events).ok_or_else(|| {
101 let message = extract_error_message(&events)
102 .or_else(|| (!stderr.trim().is_empty()).then(|| stderr.trim().to_string()))
103 .unwrap_or_else(|| "Pi provider did not return assistant output".to_string());
104 anyhow::anyhow!(message)
105 })?;
106 Value::String(candidate.trim_end().to_string())
107 };
108 let session_id = extract_session_id(&events)
109 .ok_or_else(|| anyhow::anyhow!("Pi provider response did not include a session id"))?;
110
111 Ok(AgentProviderResult {
112 output,
113 session_id: Some(session_id),
114 model: extract_model(&Value::Array(events.clone()))
115 .or_else(|| option_model(&input.options)),
116 usage: extract_usage(&events),
117 isolation: None,
118 raw: Some(to_json_value(
119 json!({ "events": events, "stderr": stderr, "extensionPath": extension_path.map(|p| p.to_string_lossy().into_owned()) }),
120 )),
121 })
122}
123
/// Appends, after a blank line, the instructions that tell Pi to finish with
/// exactly one `smol_workflows_structured_output` tool call.
fn with_structured_output_tool_instruction(prompt: &str) -> String {
    let mut text = String::from(prompt);
    for line in [
        "",
        "Use the smol_workflows_structured_output tool as your final action exactly once.",
        "Do not emit a final assistant message after calling smol_workflows_structured_output.",
    ] {
        text.push('\n');
        text.push_str(line);
    }
    text
}
133
134fn build_structured_output_extension(schema: &Value) -> String {
135 let wrapped = !schema
136 .as_object()
137 .is_some_and(|o| o.get("type") == Some(&Value::String("object".into())));
138 let parameters = if wrapped {
139 format!(
140 "Type.Object({{ value: {} }})",
141 json_schema_to_typebox_expression(schema)
142 )
143 } else {
144 json_schema_to_typebox_expression(schema)
145 };
146 let details = if wrapped { "params.value" } else { "params" };
147 format!(
148 r#"import {{ defineTool, type ExtensionAPI }} from "@earendil-works/pi-coding-agent";
149import {{ Type }} from "typebox";
150
151const structuredOutputTool = defineTool({{
152 name: "smol_workflows_structured_output",
153 label: "Structured Output",
154 description: "Submit the final structured response for this agent call.",
155 promptSnippet: "Submit the final structured response with the smol_workflows_structured_output tool.",
156 promptGuidelines: [
157 "Use smol_workflows_structured_output as your final action exactly once.",
158 "The tool parameters are generated from the caller's JSON Schema.",
159 "After calling smol_workflows_structured_output, do not emit another assistant response in the same turn.",
160 ],
161 parameters: {parameters},
162 async execute(_toolCallId, params) {{
163 return {{
164 content: [{{ type: "text", text: "Structured output captured successfully." }}],
165 details: {details},
166 terminate: true,
167 }};
168 }},
169}});
170
171export default function (pi: ExtensionAPI) {{
172 pi.registerTool(structuredOutputTool);
173}}
174"#
175 )
176}
177
178fn json_schema_to_typebox_expression(schema: &Value) -> String {
179 match schema {
180 Value::Bool(true) => "Type.Any()".into(),
181 Value::Bool(false) => "Type.Never()".into(),
182 Value::Object(record) => {
183 if let Some(value) = record.get("const") {
184 return format!("Type.Literal({})", serde_json::to_string(value).unwrap());
185 }
186 if let Some(values) = record.get("enum").and_then(Value::as_array) {
187 if !values.is_empty() {
188 return if values.len() == 1 {
189 format!(
190 "Type.Literal({})",
191 serde_json::to_string(&values[0]).unwrap()
192 )
193 } else {
194 format!(
195 "Type.Union([{}])",
196 values
197 .iter()
198 .map(|v| format!(
199 "Type.Literal({})",
200 serde_json::to_string(v).unwrap()
201 ))
202 .collect::<Vec<_>>()
203 .join(", ")
204 )
205 };
206 }
207 }
208 for key in ["oneOf", "anyOf"] {
209 if let Some(values) = record.get(key).and_then(Value::as_array) {
210 if !values.is_empty() {
211 return format!(
212 "Type.Union([{}])",
213 values
214 .iter()
215 .map(json_schema_to_typebox_expression)
216 .collect::<Vec<_>>()
217 .join(", ")
218 );
219 }
220 }
221 }
222 match first_schema_type(record.get("type")).or_else(|| infer_schema_type(record)) {
223 Some("null") => "Type.Null()".into(),
224 Some("boolean") => format!("Type.Boolean({})", typebox_options(record)),
225 Some("integer") => format!("Type.Integer({})", typebox_options(record)),
226 Some("number") => format!("Type.Number({})", typebox_options(record)),
227 Some("string") => format!("Type.String({})", typebox_options(record)),
228 Some("array") => array_schema_to_typebox_expression(record),
229 Some("object") => object_schema_to_typebox_expression(record),
230 _ => "Type.Any()".into(),
231 }
232 }
233 _ => "Type.Any()".into(),
234 }
235}
236
237fn object_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
238 let properties = schema
239 .get("properties")
240 .and_then(Value::as_object)
241 .cloned()
242 .unwrap_or_default();
243 let required = schema
244 .get("required")
245 .and_then(Value::as_array)
246 .map(|items| items.iter().filter_map(Value::as_str).collect::<Vec<_>>())
247 .unwrap_or_default();
248 let entries = properties
249 .iter()
250 .map(|(key, value)| {
251 let expression = json_schema_to_typebox_expression(value);
252 if required.iter().any(|required| required == key) {
253 format!("{}: {}", serde_json::to_string(key).unwrap(), expression)
254 } else {
255 format!(
256 "{}: Type.Optional({})",
257 serde_json::to_string(key).unwrap(),
258 expression
259 )
260 }
261 })
262 .collect::<Vec<_>>()
263 .join(", ");
264 format!("Type.Object({{ {entries} }}, {})", typebox_options(schema))
265}
266
267fn array_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
268 let options = typebox_options(schema);
269 if let Some(prefix_items) = schema.get("prefixItems").and_then(Value::as_array) {
270 if !prefix_items.is_empty() {
271 return format!(
272 "Type.Tuple([{}], {options})",
273 prefix_items
274 .iter()
275 .map(json_schema_to_typebox_expression)
276 .collect::<Vec<_>>()
277 .join(", ")
278 );
279 }
280 }
281 let item_schema = schema
282 .get("items")
283 .filter(|value| !value.is_array())
284 .unwrap_or(&Value::Bool(true));
285 format!(
286 "Type.Array({}, {options})",
287 json_schema_to_typebox_expression(item_schema)
288 )
289}
290
291fn typebox_options(schema: &serde_json::Map<String, Value>) -> String {
292 let option_keys = [
293 "title",
294 "description",
295 "default",
296 "examples",
297 "minimum",
298 "maximum",
299 "exclusiveMinimum",
300 "exclusiveMaximum",
301 "multipleOf",
302 "minLength",
303 "maxLength",
304 "pattern",
305 "format",
306 "minItems",
307 "maxItems",
308 "uniqueItems",
309 "additionalProperties",
310 ];
311 let mut options = serde_json::Map::new();
312 for key in option_keys {
313 if let Some(value) = schema.get(key) {
314 options.insert(key.to_string(), value.clone());
315 }
316 }
317 serde_json::to_string(&Value::Object(options)).unwrap()
318}
319
320fn first_schema_type(value: Option<&Value>) -> Option<&str> {
321 match value {
322 Some(Value::String(value)) => Some(value),
323 Some(Value::Array(values)) => values.iter().find_map(Value::as_str),
324 _ => None,
325 }
326}
327
328fn infer_schema_type(schema: &serde_json::Map<String, Value>) -> Option<&'static str> {
329 if schema.contains_key("properties")
330 || schema.contains_key("required")
331 || schema.contains_key("additionalProperties")
332 {
333 Some("object")
334 } else if schema.contains_key("items") || schema.contains_key("prefixItems") {
335 Some("array")
336 } else if schema.contains_key("minimum")
337 || schema.contains_key("maximum")
338 || schema.contains_key("multipleOf")
339 {
340 Some("number")
341 } else if schema.contains_key("minLength")
342 || schema.contains_key("maxLength")
343 || schema.contains_key("pattern")
344 || schema.contains_key("format")
345 {
346 Some("string")
347 } else {
348 None
349 }
350}
351
352fn extract_structured_tool_output(events: &[Value]) -> anyhow::Result<Value> {
353 let mut output = None;
354 let mut recovered_output = None;
355 let mut started_args = HashMap::<String, Value>::new();
356 let mut calls = 0;
357 let mut successes = 0;
358 let mut errors = 0;
359
360 for event in events {
361 let Some(record) = event.as_object() else {
362 continue;
363 };
364 if record.get("toolName").and_then(Value::as_str)
365 != Some("smol_workflows_structured_output")
366 {
367 continue;
368 }
369
370 if record.get("type").and_then(Value::as_str) == Some("tool_execution_start") {
371 if let (Some(tool_call_id), Some(args)) = (
372 record.get("toolCallId").and_then(Value::as_str),
373 record.get("args").or_else(|| record.get("parameters")),
374 ) {
375 started_args.insert(tool_call_id.to_string(), args.clone());
376 }
377 continue;
378 }
379
380 if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
381 continue;
382 }
383
384 calls += 1;
385 if record.get("isError").and_then(Value::as_bool) == Some(true) {
386 errors += 1;
387 if recovered_output.is_none() {
388 recovered_output = recover_structured_tool_arguments(event, &started_args);
389 }
390 continue;
391 }
392
393 if let Some(details) = get_path(event, &["result", "details"]) {
394 successes += 1;
395 output = Some(details.clone());
396 }
397 }
398
399 if let Some(output) = output {
400 if errors > 0 {
401 log::debug!(
402 "Pi structured-output tool had {errors} failed attempt(s) before a successful output"
403 );
404 }
405 if successes > 1 {
406 log::debug!("Pi structured-output tool returned {successes} successful outputs; using the last one");
407 }
408 return Ok(output);
409 }
410
411 if let Some(output) = recovered_output {
412 log::debug!(
413 "Pi structured-output tool failed, but attempted tool arguments were recovered from events"
414 );
415 return Ok(output);
416 }
417
418 if calls == 0 {
419 bail!("Pi provider did not call smol_workflows_structured_output for schema output");
420 }
421 if errors > 0 {
422 bail!("Pi smol_workflows_structured_output tool failed");
423 }
424 bail!("Pi smol_workflows_structured_output tool did not return details")
425}
426
427fn recover_structured_tool_arguments(
428 event: &Value,
429 started_args: &HashMap<String, Value>,
430) -> Option<Value> {
431 for path in [
432 &["result", "details"][..],
433 &["result", "input"],
434 &["state", "input"],
435 &["input"],
436 &["args"],
437 &["parameters"],
438 ] {
439 if let Some(value) = get_path(event, path) {
440 return Some(value.clone());
441 }
442 }
443
444 event
445 .get("toolCallId")
446 .and_then(Value::as_str)
447 .and_then(|tool_call_id| started_args.get(tool_call_id))
448 .cloned()
449}
450
451fn extract_output(events: &[Value]) -> Option<String> {
452 let mut output = None;
453 for event in events {
454 if let Some(value) = extract_output_from_event(event) {
455 output = Some(value);
456 }
457 }
458 output
459}
460
461fn extract_output_from_event(event: &Value) -> Option<String> {
462 let record = event.as_object()?;
463 match record.get("type").and_then(Value::as_str) {
464 Some("message_end" | "turn_end") => record
465 .get("message")
466 .and_then(extract_assistant_message_text),
467 Some("agent_end") => record
468 .get("messages")
469 .and_then(Value::as_array)
470 .and_then(|messages| messages.iter().rev().find(|m| is_assistant_message(m)))
471 .and_then(extract_assistant_message_text),
472 Some("message_update") => record
473 .get("message")
474 .and_then(extract_assistant_message_text),
475 _ => None,
476 }
477}
478
479fn is_assistant_message(value: &Value) -> bool {
480 value
481 .as_object()
482 .and_then(|record| record.get("role"))
483 .and_then(Value::as_str)
484 == Some("assistant")
485}
486
487fn extract_assistant_message_text(message: &Value) -> Option<String> {
488 let record = message.as_object()?;
489 if record.get("role").is_some()
490 && record.get("role").and_then(Value::as_str) != Some("assistant")
491 {
492 return None;
493 }
494 record.get("content").and_then(extract_text)
495}
496
497fn extract_text(value: &Value) -> Option<String> {
498 match value {
499 Value::String(text) => Some(text.clone()),
500 Value::Array(items) => {
501 let text = items
502 .iter()
503 .map(|item| extract_text(item).unwrap_or_default())
504 .collect::<Vec<_>>()
505 .join("");
506 (!text.is_empty()).then_some(text)
507 }
508 Value::Object(record) => record
509 .get("text")
510 .or_else(|| record.get("content"))
511 .or_else(|| record.get("message"))
512 .and_then(extract_text),
513 _ => None,
514 }
515}
516
517fn extract_error_message(events: &[Value]) -> Option<String> {
518 events.iter().find_map(find_error_message)
519}
520
521fn find_error_message(value: &Value) -> Option<String> {
522 match value {
523 Value::Array(items) => items.iter().find_map(find_error_message),
524 Value::Object(record) => {
525 if let Some(message) = record.get("errorMessage").and_then(Value::as_str) {
526 return Some(message.to_string());
527 }
528 record.values().find_map(find_error_message)
529 }
530 _ => None,
531 }
532}
533
534fn extract_session_id(events: &[Value]) -> Option<String> {
535 for event in events {
536 if event.get("type").and_then(Value::as_str) == Some("session") {
537 if let Some(id) = event.get("id").and_then(Value::as_str) {
538 return Some(id.to_string());
539 }
540 }
541 if let Some(id) = event
542 .get("session_id")
543 .or_else(|| event.get("sessionId"))
544 .or_else(|| event.get("sessionID"))
545 .and_then(Value::as_str)
546 {
547 return Some(id.to_string());
548 }
549 }
550 None
551}
552
553fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
554 let mut usage = None;
555 for event in events {
556 let mut candidates = Vec::new();
557 find_usage_objects(event, &mut candidates);
558 for candidate in candidates {
559 usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
560 }
561 }
562 usage
563}