Skip to main content

smol_workflow_engine/agent_providers/
pi.rs

1use super::common::*;
2use super::types::*;
3use anyhow::bail;
4use serde_json::{json, Value};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
/// Settings that control how [`PiAgentProvider`] launches the Pi CLI.
#[derive(Clone, Debug, Default)]
pub struct PiAgentProviderOptions {
    /// Executable to launch; `"pi"` is used when this is `None`.
    pub command: Option<String>,
    /// Arguments placed first on the command line, ahead of `args`.
    pub subcommand: Vec<String>,
    /// Extra arguments placed after `subcommand` and before the provider's own flags.
    pub args: Vec<String>,
    /// Working directory used when the run context does not supply one.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to `run_command` for the child process.
    pub env: HashMap<String, String>,
    /// Timeout forwarded to `run_command`, in milliseconds (per the field name).
    pub timeout_ms: Option<u64>,
}

/// Agent provider backed by the Pi coding-agent CLI.
#[derive(Clone, Debug, Default)]
pub struct PiAgentProvider {
    options: PiAgentProviderOptions,
}

impl PiAgentProvider {
    /// Builds a provider that will run Pi with the supplied `options`.
    pub fn new(options: PiAgentProviderOptions) -> Self {
        PiAgentProvider { options }
    }
}
28
29#[async_trait::async_trait]
30impl AgentProvider for PiAgentProvider {
31    fn name(&self) -> &str {
32        "pi"
33    }
34    fn schema_mode(&self) -> AgentProviderSchemaMode {
35        AgentProviderSchemaMode::Builtin
36    }
37    fn usage_mode(&self) -> AgentProviderUsageMode {
38        AgentProviderUsageMode::Builtin
39    }
40    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
41        run_pi(input, &self.options).await
42    }
43}
44
45async fn run_pi(
46    input: AgentProviderRunInput,
47    options: &PiAgentProviderOptions,
48) -> anyhow::Result<AgentProviderResult> {
49    let command = options.command.as_deref().unwrap_or("pi");
50    let has_schema = option_schema(&input.options).is_some();
51    let prompt = if has_schema {
52        with_structured_output_tool_instruction(&input.prompt)
53    } else {
54        input.prompt.clone()
55    };
56    let temp = temp_dir("smol-wf-pi-")?;
57    let extension_path = has_schema.then(|| temp.path().join("structured-output-extension.ts"));
58    let prompt_path = temp.path().join("prompt.md");
59
60    if let Some(path) = &extension_path {
61        fs::write(
62            path,
63            build_structured_output_extension(option_schema(&input.options).unwrap()),
64        )?;
65    }
66    fs::write(&prompt_path, &prompt)?;
67
68    let prompt_arg = format!("@{}", prompt_path.to_string_lossy());
69
70    let mut args = Vec::new();
71    args.extend(options.subcommand.clone());
72    args.extend(options.args.clone());
73    if let Some(path) = &extension_path {
74        args.extend(["--extension".into(), path.to_string_lossy().into_owned()]);
75    }
76    args.extend(["--print".into(), "--mode".into(), "json".into()]);
77    if let Some(model) = option_str(&input.options, "model") {
78        args.extend(["--model".into(), model]);
79    }
80    if let Some(thinking) = option_str(&input.options, "thinking") {
81        args.extend(["--thinking".into(), thinking]);
82    }
83    args.push(prompt_arg);
84
85    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
86    let (stdout, stderr) = run_command(
87        "Pi",
88        command,
89        &args,
90        None,
91        cwd,
92        &options.env,
93        options.timeout_ms,
94    )
95    .await?;
96    let events = parse_json_lines(&stdout);
97    let output = if has_schema {
98        extract_structured_tool_output(&events)?
99    } else {
100        let candidate = extract_output(&events).ok_or_else(|| {
101            let message = extract_error_message(&events)
102                .or_else(|| (!stderr.trim().is_empty()).then(|| stderr.trim().to_string()))
103                .unwrap_or_else(|| "Pi provider did not return assistant output".to_string());
104            anyhow::anyhow!(message)
105        })?;
106        Value::String(candidate.trim_end().to_string())
107    };
108    let session_id = extract_session_id(&events)
109        .ok_or_else(|| anyhow::anyhow!("Pi provider response did not include a session id"))?;
110
111    Ok(AgentProviderResult {
112        output,
113        session_id: Some(session_id),
114        model: extract_model(&Value::Array(events.clone()))
115            .or_else(|| option_model(&input.options)),
116        usage: extract_usage(&events),
117        isolation: None,
118        raw: Some(to_json_value(
119            json!({ "events": events, "stderr": stderr, "extensionPath": extension_path.map(|p| p.to_string_lossy().into_owned()) }),
120        )),
121    })
122}
123
/// Appends the structured-output instructions to `prompt`.
///
/// The result is the prompt, a blank line, then two instruction lines telling the model
/// to finish with exactly one `smol_workflows_structured_output` call and nothing after it.
fn with_structured_output_tool_instruction(prompt: &str) -> String {
    const INSTRUCTIONS: &str = "Use the smol_workflows_structured_output tool as your final action exactly once.\nDo not emit a final assistant message after calling smol_workflows_structured_output.";
    format!("{prompt}\n\n{INSTRUCTIONS}")
}
133
134fn build_structured_output_extension(schema: &Value) -> String {
135    let wrapped = !schema
136        .as_object()
137        .is_some_and(|o| o.get("type") == Some(&Value::String("object".into())));
138    let parameters = if wrapped {
139        format!(
140            "Type.Object({{ value: {} }})",
141            json_schema_to_typebox_expression(schema)
142        )
143    } else {
144        json_schema_to_typebox_expression(schema)
145    };
146    let details = if wrapped { "params.value" } else { "params" };
147    format!(
148        r#"import {{ defineTool, type ExtensionAPI }} from "@earendil-works/pi-coding-agent";
149import {{ Type }} from "typebox";
150
151const structuredOutputTool = defineTool({{
152  name: "smol_workflows_structured_output",
153  label: "Structured Output",
154  description: "Submit the final structured response for this agent call.",
155  promptSnippet: "Submit the final structured response with the smol_workflows_structured_output tool.",
156  promptGuidelines: [
157    "Use smol_workflows_structured_output as your final action exactly once.",
158    "The tool parameters are generated from the caller's JSON Schema.",
159    "After calling smol_workflows_structured_output, do not emit another assistant response in the same turn.",
160  ],
161  parameters: {parameters},
162  async execute(_toolCallId, params) {{
163    return {{
164      content: [{{ type: "text", text: "Structured output captured successfully." }}],
165      details: {details},
166      terminate: true,
167    }};
168  }},
169}});
170
171export default function (pi: ExtensionAPI) {{
172  pi.registerTool(structuredOutputTool);
173}}
174"#
175    )
176}
177
178fn json_schema_to_typebox_expression(schema: &Value) -> String {
179    match schema {
180        Value::Bool(true) => "Type.Any()".into(),
181        Value::Bool(false) => "Type.Never()".into(),
182        Value::Object(record) => {
183            if let Some(value) = record.get("const") {
184                return format!("Type.Literal({})", serde_json::to_string(value).unwrap());
185            }
186            if let Some(values) = record.get("enum").and_then(Value::as_array) {
187                if !values.is_empty() {
188                    return if values.len() == 1 {
189                        format!(
190                            "Type.Literal({})",
191                            serde_json::to_string(&values[0]).unwrap()
192                        )
193                    } else {
194                        format!(
195                            "Type.Union([{}])",
196                            values
197                                .iter()
198                                .map(|v| format!(
199                                    "Type.Literal({})",
200                                    serde_json::to_string(v).unwrap()
201                                ))
202                                .collect::<Vec<_>>()
203                                .join(", ")
204                        )
205                    };
206                }
207            }
208            for key in ["oneOf", "anyOf"] {
209                if let Some(values) = record.get(key).and_then(Value::as_array) {
210                    if !values.is_empty() {
211                        return format!(
212                            "Type.Union([{}])",
213                            values
214                                .iter()
215                                .map(json_schema_to_typebox_expression)
216                                .collect::<Vec<_>>()
217                                .join(", ")
218                        );
219                    }
220                }
221            }
222            match first_schema_type(record.get("type")).or_else(|| infer_schema_type(record)) {
223                Some("null") => "Type.Null()".into(),
224                Some("boolean") => format!("Type.Boolean({})", typebox_options(record)),
225                Some("integer") => format!("Type.Integer({})", typebox_options(record)),
226                Some("number") => format!("Type.Number({})", typebox_options(record)),
227                Some("string") => format!("Type.String({})", typebox_options(record)),
228                Some("array") => array_schema_to_typebox_expression(record),
229                Some("object") => object_schema_to_typebox_expression(record),
230                _ => "Type.Any()".into(),
231            }
232        }
233        _ => "Type.Any()".into(),
234    }
235}
236
237fn object_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
238    let properties = schema
239        .get("properties")
240        .and_then(Value::as_object)
241        .cloned()
242        .unwrap_or_default();
243    let required = schema
244        .get("required")
245        .and_then(Value::as_array)
246        .map(|items| items.iter().filter_map(Value::as_str).collect::<Vec<_>>())
247        .unwrap_or_default();
248    let entries = properties
249        .iter()
250        .map(|(key, value)| {
251            let expression = json_schema_to_typebox_expression(value);
252            if required.iter().any(|required| required == key) {
253                format!("{}: {}", serde_json::to_string(key).unwrap(), expression)
254            } else {
255                format!(
256                    "{}: Type.Optional({})",
257                    serde_json::to_string(key).unwrap(),
258                    expression
259                )
260            }
261        })
262        .collect::<Vec<_>>()
263        .join(", ");
264    format!("Type.Object({{ {entries} }}, {})", typebox_options(schema))
265}
266
267fn array_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
268    let options = typebox_options(schema);
269    if let Some(prefix_items) = schema.get("prefixItems").and_then(Value::as_array) {
270        if !prefix_items.is_empty() {
271            return format!(
272                "Type.Tuple([{}], {options})",
273                prefix_items
274                    .iter()
275                    .map(json_schema_to_typebox_expression)
276                    .collect::<Vec<_>>()
277                    .join(", ")
278            );
279        }
280    }
281    let item_schema = schema
282        .get("items")
283        .filter(|value| !value.is_array())
284        .unwrap_or(&Value::Bool(true));
285    format!(
286        "Type.Array({}, {options})",
287        json_schema_to_typebox_expression(item_schema)
288    )
289}
290
291fn typebox_options(schema: &serde_json::Map<String, Value>) -> String {
292    let option_keys = [
293        "title",
294        "description",
295        "default",
296        "examples",
297        "minimum",
298        "maximum",
299        "exclusiveMinimum",
300        "exclusiveMaximum",
301        "multipleOf",
302        "minLength",
303        "maxLength",
304        "pattern",
305        "format",
306        "minItems",
307        "maxItems",
308        "uniqueItems",
309        "additionalProperties",
310    ];
311    let mut options = serde_json::Map::new();
312    for key in option_keys {
313        if let Some(value) = schema.get(key) {
314            options.insert(key.to_string(), value.clone());
315        }
316    }
317    serde_json::to_string(&Value::Object(options)).unwrap()
318}
319
320fn first_schema_type(value: Option<&Value>) -> Option<&str> {
321    match value {
322        Some(Value::String(value)) => Some(value),
323        Some(Value::Array(values)) => values.iter().find_map(Value::as_str),
324        _ => None,
325    }
326}
327
328fn infer_schema_type(schema: &serde_json::Map<String, Value>) -> Option<&'static str> {
329    if schema.contains_key("properties")
330        || schema.contains_key("required")
331        || schema.contains_key("additionalProperties")
332    {
333        Some("object")
334    } else if schema.contains_key("items") || schema.contains_key("prefixItems") {
335        Some("array")
336    } else if schema.contains_key("minimum")
337        || schema.contains_key("maximum")
338        || schema.contains_key("multipleOf")
339    {
340        Some("number")
341    } else if schema.contains_key("minLength")
342        || schema.contains_key("maxLength")
343        || schema.contains_key("pattern")
344        || schema.contains_key("format")
345    {
346        Some("string")
347    } else {
348        None
349    }
350}
351
352fn extract_structured_tool_output(events: &[Value]) -> anyhow::Result<Value> {
353    let mut output = None;
354    let mut recovered_output = None;
355    let mut started_args = HashMap::<String, Value>::new();
356    let mut calls = 0;
357    let mut successes = 0;
358    let mut errors = 0;
359
360    for event in events {
361        let Some(record) = event.as_object() else {
362            continue;
363        };
364        if record.get("toolName").and_then(Value::as_str)
365            != Some("smol_workflows_structured_output")
366        {
367            continue;
368        }
369
370        if record.get("type").and_then(Value::as_str) == Some("tool_execution_start") {
371            if let (Some(tool_call_id), Some(args)) = (
372                record.get("toolCallId").and_then(Value::as_str),
373                record.get("args").or_else(|| record.get("parameters")),
374            ) {
375                started_args.insert(tool_call_id.to_string(), args.clone());
376            }
377            continue;
378        }
379
380        if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
381            continue;
382        }
383
384        calls += 1;
385        if record.get("isError").and_then(Value::as_bool) == Some(true) {
386            errors += 1;
387            if recovered_output.is_none() {
388                recovered_output = recover_structured_tool_arguments(event, &started_args);
389            }
390            continue;
391        }
392
393        if let Some(details) = get_path(event, &["result", "details"]) {
394            successes += 1;
395            output = Some(details.clone());
396        }
397    }
398
399    if let Some(output) = output {
400        if errors > 0 {
401            log::debug!(
402                "Pi structured-output tool had {errors} failed attempt(s) before a successful output"
403            );
404        }
405        if successes > 1 {
406            log::debug!("Pi structured-output tool returned {successes} successful outputs; using the last one");
407        }
408        return Ok(output);
409    }
410
411    if let Some(output) = recovered_output {
412        log::debug!(
413            "Pi structured-output tool failed, but attempted tool arguments were recovered from events"
414        );
415        return Ok(output);
416    }
417
418    if calls == 0 {
419        bail!("Pi provider did not call smol_workflows_structured_output for schema output");
420    }
421    if errors > 0 {
422        bail!("Pi smol_workflows_structured_output tool failed");
423    }
424    bail!("Pi smol_workflows_structured_output tool did not return details")
425}
426
427fn recover_structured_tool_arguments(
428    event: &Value,
429    started_args: &HashMap<String, Value>,
430) -> Option<Value> {
431    for path in [
432        &["result", "details"][..],
433        &["result", "input"],
434        &["state", "input"],
435        &["input"],
436        &["args"],
437        &["parameters"],
438    ] {
439        if let Some(value) = get_path(event, path) {
440            return Some(value.clone());
441        }
442    }
443
444    event
445        .get("toolCallId")
446        .and_then(Value::as_str)
447        .and_then(|tool_call_id| started_args.get(tool_call_id))
448        .cloned()
449}
450
451fn extract_output(events: &[Value]) -> Option<String> {
452    let mut output = None;
453    for event in events {
454        if let Some(value) = extract_output_from_event(event) {
455            output = Some(value);
456        }
457    }
458    output
459}
460
461fn extract_output_from_event(event: &Value) -> Option<String> {
462    let record = event.as_object()?;
463    match record.get("type").and_then(Value::as_str) {
464        Some("message_end" | "turn_end") => record
465            .get("message")
466            .and_then(extract_assistant_message_text),
467        Some("agent_end") => record
468            .get("messages")
469            .and_then(Value::as_array)
470            .and_then(|messages| messages.iter().rev().find(|m| is_assistant_message(m)))
471            .and_then(extract_assistant_message_text),
472        Some("message_update") => record
473            .get("message")
474            .and_then(extract_assistant_message_text),
475        _ => None,
476    }
477}
478
479fn is_assistant_message(value: &Value) -> bool {
480    value
481        .as_object()
482        .and_then(|record| record.get("role"))
483        .and_then(Value::as_str)
484        == Some("assistant")
485}
486
487fn extract_assistant_message_text(message: &Value) -> Option<String> {
488    let record = message.as_object()?;
489    if record.get("role").is_some()
490        && record.get("role").and_then(Value::as_str) != Some("assistant")
491    {
492        return None;
493    }
494    record.get("content").and_then(extract_text)
495}
496
497fn extract_text(value: &Value) -> Option<String> {
498    match value {
499        Value::String(text) => Some(text.clone()),
500        Value::Array(items) => {
501            let text = items
502                .iter()
503                .map(|item| extract_text(item).unwrap_or_default())
504                .collect::<Vec<_>>()
505                .join("");
506            (!text.is_empty()).then_some(text)
507        }
508        Value::Object(record) => record
509            .get("text")
510            .or_else(|| record.get("content"))
511            .or_else(|| record.get("message"))
512            .and_then(extract_text),
513        _ => None,
514    }
515}
516
517fn extract_error_message(events: &[Value]) -> Option<String> {
518    events.iter().find_map(find_error_message)
519}
520
521fn find_error_message(value: &Value) -> Option<String> {
522    match value {
523        Value::Array(items) => items.iter().find_map(find_error_message),
524        Value::Object(record) => {
525            if let Some(message) = record.get("errorMessage").and_then(Value::as_str) {
526                return Some(message.to_string());
527            }
528            record.values().find_map(find_error_message)
529        }
530        _ => None,
531    }
532}
533
534fn extract_session_id(events: &[Value]) -> Option<String> {
535    for event in events {
536        if event.get("type").and_then(Value::as_str) == Some("session") {
537            if let Some(id) = event.get("id").and_then(Value::as_str) {
538                return Some(id.to_string());
539            }
540        }
541        if let Some(id) = event
542            .get("session_id")
543            .or_else(|| event.get("sessionId"))
544            .or_else(|| event.get("sessionID"))
545            .and_then(Value::as_str)
546        {
547            return Some(id.to_string());
548        }
549    }
550    None
551}
552
553fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
554    let mut usage = None;
555    for event in events {
556        let mut candidates = Vec::new();
557        find_usage_objects(event, &mut candidates);
558        for candidate in candidates {
559            usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
560        }
561    }
562    usage
563}