// smol_workflow_engine/agent_providers/pi.rs

1use super::common::*;
2use super::types::*;
3use crate::environment::EnvironmentPath;
4use anyhow::bail;
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
/// Launch configuration for [`PiAgentProvider`].
///
/// Every field is optional. `Default` yields "run `pi` with no extra arguments,
/// no extra environment, no fixed working directory and no timeout".
#[derive(Debug, Clone, Default)]
pub struct PiAgentProviderOptions {
    /// Executable to run; `run_pi` uses `"pi"` when this is `None`.
    pub command: Option<String>,
    /// Arguments placed first on the command line, ahead of `args`.
    pub subcommand: Vec<String>,
    /// Arguments placed after `subcommand` and before the flags generated by `run_pi`.
    pub args: Vec<String>,
    /// Working directory used only when the run context does not provide one.
    pub cwd: Option<PathBuf>,
    /// Environment variables forwarded to `run_command` for the child process.
    pub env: HashMap<String, String>,
    /// Timeout in milliseconds forwarded to `run_command` unchanged.
    pub timeout_ms: Option<u64>,
}
18
19#[derive(Debug, Clone, Default)]
20pub struct PiAgentProvider {
21    options: PiAgentProviderOptions,
22}
23impl PiAgentProvider {
24    pub fn new(options: PiAgentProviderOptions) -> Self {
25        Self { options }
26    }
27}
28
29#[async_trait::async_trait]
30impl AgentProvider for PiAgentProvider {
31    fn name(&self) -> &str {
32        "pi"
33    }
34    fn schema_mode(&self) -> AgentProviderSchemaMode {
35        AgentProviderSchemaMode::Builtin
36    }
37    fn usage_mode(&self) -> AgentProviderUsageMode {
38        AgentProviderUsageMode::Builtin
39    }
40    async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
41        run_pi(input, &self.options).await
42    }
43}
44
45async fn run_pi(
46    input: AgentProviderRunInput,
47    options: &PiAgentProviderOptions,
48) -> anyhow::Result<AgentProviderResult> {
49    let command = options.command.as_deref().unwrap_or("pi");
50    let has_schema = option_schema(&input.options).is_some();
51    let prompt = if has_schema {
52        with_structured_output_tool_instruction(&input.prompt)
53    } else {
54        input.prompt.clone()
55    };
56    let temp = input.environment.create_temp_dir("smol-wf-pi-").await?;
57    let extension_path =
58        has_schema.then(|| join_environment_path(&temp, "structured-output-extension.ts"));
59    let prompt_path = join_environment_path(&temp, "prompt.md");
60
61    if let Some(path) = &extension_path {
62        input
63            .environment
64            .write_file(
65                path,
66                build_structured_output_extension(option_schema(&input.options).unwrap())
67                    .as_bytes(),
68            )
69            .await?;
70    }
71    input
72        .environment
73        .write_file(&prompt_path, prompt.as_bytes())
74        .await?;
75
76    let prompt_arg = format!("@{}", prompt_path.0);
77
78    let mut args = Vec::new();
79    args.extend(options.subcommand.clone());
80    args.extend(options.args.clone());
81    if let Some(path) = &extension_path {
82        args.extend(["--extension".into(), path.0.clone()]);
83    }
84    args.extend(["--print".into(), "--mode".into(), "json".into()]);
85    if let Some(model) = option_str(&input.options, "model") {
86        args.extend(["--model".into(), model]);
87    }
88    if let Some(thinking) = option_str(&input.options, "thinking") {
89        args.extend(["--thinking".into(), thinking]);
90    }
91    args.push(prompt_arg);
92
93    let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
94    let (stdout, stderr) = run_command(RunCommandRequest {
95        provider: "Pi",
96        command,
97        args: &args,
98        stdin: None,
99        cwd,
100        env: &options.env,
101        timeout_ms: options.timeout_ms,
102        environment: input.environment.as_ref(),
103    })
104    .await?;
105    let events = parse_json_lines(&stdout);
106    let output = if has_schema {
107        extract_structured_tool_output(&events)?
108    } else {
109        let candidate = extract_output(&events)
110            .or_else(|| extract_last_tool_result_text(&events))
111            .ok_or_else(|| {
112                let message = extract_error_message(&events)
113                    .or_else(|| (!stderr.trim().is_empty()).then(|| stderr.trim().to_string()))
114                    .unwrap_or_else(|| {
115                        let event_types = events
116                            .iter()
117                            .filter_map(|event| event.get("type").and_then(Value::as_str))
118                            .collect::<Vec<_>>()
119                            .join(",");
120                        let last_event = events
121                            .last()
122                            .map(|event| truncate(&event.to_string(), 1000))
123                            .unwrap_or_else(|| "<none>".to_string());
124                        format!(
125                            "Pi provider did not return assistant output; event_types=[{event_types}] last_event={last_event}"
126                        )
127                    });
128                anyhow::anyhow!(message)
129            })?;
130        Value::String(candidate.trim_end().to_string())
131    };
132    let session_id = extract_session_id(&events)
133        .ok_or_else(|| anyhow::anyhow!("Pi provider response did not include a session id"))?;
134
135    Ok(AgentProviderResult {
136        output,
137        session_id: Some(session_id),
138        model: extract_model(&Value::Array(events.clone()))
139            .or_else(|| option_model(&input.options)),
140        usage: extract_usage(&events),
141        isolation: None,
142        raw: Some(to_json_value(
143            json!({ "events": events, "stderr": stderr, "extensionPath": extension_path.map(|p| p.0) }),
144        )),
145    })
146}
147
148fn join_environment_path(base: &EnvironmentPath, child: &str) -> EnvironmentPath {
149    EnvironmentPath(format!("{}/{}", base.as_str().trim_end_matches('/'), child))
150}
151
/// Appends the structured-output instructions to `prompt`, separated by one blank line.
fn with_structured_output_tool_instruction(prompt: &str) -> String {
    format!(
        "{prompt}\n\n\
         Use the smol_workflows_structured_output tool as your final action exactly once.\n\
         Do not emit a final assistant message after calling smol_workflows_structured_output."
    )
}
161
162fn build_structured_output_extension(schema: &Value) -> String {
163    let wrapped = !schema
164        .as_object()
165        .is_some_and(|o| o.get("type") == Some(&Value::String("object".into())));
166    let parameters = if wrapped {
167        format!(
168            "Type.Object({{ value: {} }})",
169            json_schema_to_typebox_expression(schema)
170        )
171    } else {
172        json_schema_to_typebox_expression(schema)
173    };
174    let details = if wrapped { "params.value" } else { "params" };
175    format!(
176        r#"import {{ defineTool, type ExtensionAPI }} from "@earendil-works/pi-coding-agent";
177import {{ Type }} from "typebox";
178
179const structuredOutputTool = defineTool({{
180  name: "smol_workflows_structured_output",
181  label: "Structured Output",
182  description: "Submit the final structured response for this agent call.",
183  promptSnippet: "Submit the final structured response with the smol_workflows_structured_output tool.",
184  promptGuidelines: [
185    "Use smol_workflows_structured_output as your final action exactly once.",
186    "The tool parameters are generated from the caller's JSON Schema.",
187    "After calling smol_workflows_structured_output, do not emit another assistant response in the same turn.",
188  ],
189  parameters: {parameters},
190  async execute(_toolCallId, params) {{
191    return {{
192      content: [{{ type: "text", text: "Structured output captured successfully." }}],
193      details: {details},
194      terminate: true,
195    }};
196  }},
197}});
198
199export default function (pi: ExtensionAPI) {{
200  pi.registerTool(structuredOutputTool);
201}}
202"#
203    )
204}
205
206fn json_schema_to_typebox_expression(schema: &Value) -> String {
207    match schema {
208        Value::Bool(true) => "Type.Any()".into(),
209        Value::Bool(false) => "Type.Never()".into(),
210        Value::Object(record) => {
211            if let Some(value) = record.get("const") {
212                return format!("Type.Literal({})", serde_json::to_string(value).unwrap());
213            }
214            if let Some(values) = record.get("enum").and_then(Value::as_array) {
215                if !values.is_empty() {
216                    return if values.len() == 1 {
217                        format!(
218                            "Type.Literal({})",
219                            serde_json::to_string(&values[0]).unwrap()
220                        )
221                    } else {
222                        format!(
223                            "Type.Union([{}])",
224                            values
225                                .iter()
226                                .map(|v| format!(
227                                    "Type.Literal({})",
228                                    serde_json::to_string(v).unwrap()
229                                ))
230                                .collect::<Vec<_>>()
231                                .join(", ")
232                        )
233                    };
234                }
235            }
236            for key in ["oneOf", "anyOf"] {
237                if let Some(values) = record.get(key).and_then(Value::as_array) {
238                    if !values.is_empty() {
239                        return format!(
240                            "Type.Union([{}])",
241                            values
242                                .iter()
243                                .map(json_schema_to_typebox_expression)
244                                .collect::<Vec<_>>()
245                                .join(", ")
246                        );
247                    }
248                }
249            }
250            match first_schema_type(record.get("type")).or_else(|| infer_schema_type(record)) {
251                Some("null") => "Type.Null()".into(),
252                Some("boolean") => format!("Type.Boolean({})", typebox_options(record)),
253                Some("integer") => format!("Type.Integer({})", typebox_options(record)),
254                Some("number") => format!("Type.Number({})", typebox_options(record)),
255                Some("string") => format!("Type.String({})", typebox_options(record)),
256                Some("array") => array_schema_to_typebox_expression(record),
257                Some("object") => object_schema_to_typebox_expression(record),
258                _ => "Type.Any()".into(),
259            }
260        }
261        _ => "Type.Any()".into(),
262    }
263}
264
265fn object_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
266    let properties = schema
267        .get("properties")
268        .and_then(Value::as_object)
269        .cloned()
270        .unwrap_or_default();
271    let required = schema
272        .get("required")
273        .and_then(Value::as_array)
274        .map(|items| items.iter().filter_map(Value::as_str).collect::<Vec<_>>())
275        .unwrap_or_default();
276    let entries = properties
277        .iter()
278        .map(|(key, value)| {
279            let expression = json_schema_to_typebox_expression(value);
280            if required.iter().any(|required| required == key) {
281                format!("{}: {}", serde_json::to_string(key).unwrap(), expression)
282            } else {
283                format!(
284                    "{}: Type.Optional({})",
285                    serde_json::to_string(key).unwrap(),
286                    expression
287                )
288            }
289        })
290        .collect::<Vec<_>>()
291        .join(", ");
292    format!("Type.Object({{ {entries} }}, {})", typebox_options(schema))
293}
294
295fn array_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
296    let options = typebox_options(schema);
297    if let Some(prefix_items) = schema.get("prefixItems").and_then(Value::as_array) {
298        if !prefix_items.is_empty() {
299            return format!(
300                "Type.Tuple([{}], {options})",
301                prefix_items
302                    .iter()
303                    .map(json_schema_to_typebox_expression)
304                    .collect::<Vec<_>>()
305                    .join(", ")
306            );
307        }
308    }
309    let item_schema = schema
310        .get("items")
311        .filter(|value| !value.is_array())
312        .unwrap_or(&Value::Bool(true));
313    format!(
314        "Type.Array({}, {options})",
315        json_schema_to_typebox_expression(item_schema)
316    )
317}
318
319fn typebox_options(schema: &serde_json::Map<String, Value>) -> String {
320    let option_keys = [
321        "title",
322        "description",
323        "default",
324        "examples",
325        "minimum",
326        "maximum",
327        "exclusiveMinimum",
328        "exclusiveMaximum",
329        "multipleOf",
330        "minLength",
331        "maxLength",
332        "pattern",
333        "format",
334        "minItems",
335        "maxItems",
336        "uniqueItems",
337        "additionalProperties",
338    ];
339    let mut options = serde_json::Map::new();
340    for key in option_keys {
341        if let Some(value) = schema.get(key) {
342            options.insert(key.to_string(), value.clone());
343        }
344    }
345    serde_json::to_string(&Value::Object(options)).unwrap()
346}
347
348fn first_schema_type(value: Option<&Value>) -> Option<&str> {
349    match value {
350        Some(Value::String(value)) => Some(value),
351        Some(Value::Array(values)) => values.iter().find_map(Value::as_str),
352        _ => None,
353    }
354}
355
356fn infer_schema_type(schema: &serde_json::Map<String, Value>) -> Option<&'static str> {
357    if schema.contains_key("properties")
358        || schema.contains_key("required")
359        || schema.contains_key("additionalProperties")
360    {
361        Some("object")
362    } else if schema.contains_key("items") || schema.contains_key("prefixItems") {
363        Some("array")
364    } else if schema.contains_key("minimum")
365        || schema.contains_key("maximum")
366        || schema.contains_key("multipleOf")
367    {
368        Some("number")
369    } else if schema.contains_key("minLength")
370        || schema.contains_key("maxLength")
371        || schema.contains_key("pattern")
372        || schema.contains_key("format")
373    {
374        Some("string")
375    } else {
376        None
377    }
378}
379
380fn extract_structured_tool_output(events: &[Value]) -> anyhow::Result<Value> {
381    let mut output = None;
382    let mut recovered_output = None;
383    let mut started_args = HashMap::<String, Value>::new();
384    let mut calls = 0;
385    let mut successes = 0;
386    let mut errors = 0;
387
388    for event in events {
389        let Some(record) = event.as_object() else {
390            continue;
391        };
392        if record.get("toolName").and_then(Value::as_str)
393            != Some("smol_workflows_structured_output")
394        {
395            continue;
396        }
397
398        if record.get("type").and_then(Value::as_str) == Some("tool_execution_start") {
399            if let (Some(tool_call_id), Some(args)) = (
400                record.get("toolCallId").and_then(Value::as_str),
401                record.get("args").or_else(|| record.get("parameters")),
402            ) {
403                started_args.insert(tool_call_id.to_string(), args.clone());
404            }
405            continue;
406        }
407
408        if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
409            continue;
410        }
411
412        calls += 1;
413        if record.get("isError").and_then(Value::as_bool) == Some(true) {
414            errors += 1;
415            if recovered_output.is_none() {
416                recovered_output = recover_structured_tool_arguments(event, &started_args);
417            }
418            continue;
419        }
420
421        if let Some(details) = get_path(event, &["result", "details"]) {
422            successes += 1;
423            output = Some(details.clone());
424        }
425    }
426
427    if let Some(output) = output {
428        if errors > 0 {
429            log::debug!(
430                "Pi structured-output tool had {errors} failed attempt(s) before a successful output"
431            );
432        }
433        if successes > 1 {
434            log::debug!("Pi structured-output tool returned {successes} successful outputs; using the last one");
435        }
436        return Ok(output);
437    }
438
439    if let Some(output) = recovered_output {
440        log::debug!(
441            "Pi structured-output tool failed, but attempted tool arguments were recovered from events"
442        );
443        return Ok(output);
444    }
445
446    if calls == 0 {
447        bail!("Pi provider did not call smol_workflows_structured_output for schema output");
448    }
449    if errors > 0 {
450        bail!("Pi smol_workflows_structured_output tool failed");
451    }
452    bail!("Pi smol_workflows_structured_output tool did not return details")
453}
454
455fn recover_structured_tool_arguments(
456    event: &Value,
457    started_args: &HashMap<String, Value>,
458) -> Option<Value> {
459    for path in [
460        &["result", "details"][..],
461        &["result", "input"],
462        &["state", "input"],
463        &["input"],
464        &["args"],
465        &["parameters"],
466    ] {
467        if let Some(value) = get_path(event, path) {
468            return Some(value.clone());
469        }
470    }
471
472    event
473        .get("toolCallId")
474        .and_then(Value::as_str)
475        .and_then(|tool_call_id| started_args.get(tool_call_id))
476        .cloned()
477}
478
479fn extract_output(events: &[Value]) -> Option<String> {
480    let mut output = None;
481    for event in events {
482        if let Some(value) = extract_output_from_event(event) {
483            output = Some(value);
484        }
485    }
486    output
487}
488
489fn extract_output_from_event(event: &Value) -> Option<String> {
490    let record = event.as_object()?;
491    match record.get("type").and_then(Value::as_str) {
492        Some("message_end" | "turn_end") => record
493            .get("message")
494            .and_then(extract_assistant_message_text),
495        Some("agent_end") => {
496            record
497                .get("messages")
498                .and_then(Value::as_array)
499                .and_then(|messages| {
500                    messages
501                        .iter()
502                        .rev()
503                        .find_map(extract_assistant_message_text)
504                })
505        }
506        Some("message_update") => record
507            .get("message")
508            .and_then(extract_assistant_message_text)
509            .or_else(|| {
510                record
511                    .get("assistantMessageEvent")
512                    .and_then(extract_assistant_message_event_text)
513            }),
514        _ => None,
515    }
516}
517
518fn extract_assistant_message_event_text(event: &Value) -> Option<String> {
519    let record = event.as_object()?;
520    match record.get("type").and_then(Value::as_str) {
521        Some("text_end") => record
522            .get("content")
523            .and_then(Value::as_str)
524            .map(ToString::to_string),
525        Some("text_delta") => record
526            .get("partial")
527            .and_then(extract_assistant_message_text),
528        _ => record
529            .get("partial")
530            .and_then(extract_assistant_message_text),
531    }
532}
533
534fn extract_assistant_message_text(message: &Value) -> Option<String> {
535    let record = message.as_object()?;
536    if record.get("role").is_some()
537        && record.get("role").and_then(Value::as_str) != Some("assistant")
538    {
539        return None;
540    }
541    record.get("content").and_then(extract_text)
542}
543
544fn extract_text(value: &Value) -> Option<String> {
545    match value {
546        Value::String(text) => Some(text.clone()),
547        Value::Array(items) => {
548            let text = items
549                .iter()
550                .map(|item| extract_text(item).unwrap_or_default())
551                .collect::<Vec<_>>()
552                .join("");
553            (!text.is_empty()).then_some(text)
554        }
555        Value::Object(record) => record
556            .get("text")
557            .or_else(|| record.get("content"))
558            .or_else(|| record.get("message"))
559            .and_then(extract_text),
560        _ => None,
561    }
562}
563
564fn extract_last_tool_result_text(events: &[Value]) -> Option<String> {
565    events.iter().rev().find_map(|event| {
566        let record = event.as_object()?;
567        if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
568            return None;
569        }
570        record
571            .get("result")
572            .and_then(|result| {
573                result
574                    .get("content")
575                    .or_else(|| result.get("message"))
576                    .or_else(|| result.get("text"))
577                    .and_then(extract_text)
578            })
579            .or_else(|| record.get("message").and_then(extract_text))
580    })
581}
582
583fn extract_error_message(events: &[Value]) -> Option<String> {
584    events.iter().find_map(find_error_message)
585}
586
587fn find_error_message(value: &Value) -> Option<String> {
588    match value {
589        Value::Array(items) => items.iter().find_map(find_error_message),
590        Value::Object(record) => {
591            if let Some(message) = record.get("errorMessage").and_then(Value::as_str) {
592                return Some(message.to_string());
593            }
594            record.values().find_map(find_error_message)
595        }
596        _ => None,
597    }
598}
599
600fn extract_session_id(events: &[Value]) -> Option<String> {
601    for event in events {
602        if event.get("type").and_then(Value::as_str) == Some("session") {
603            if let Some(id) = event.get("id").and_then(Value::as_str) {
604                return Some(id.to_string());
605            }
606        }
607        if let Some(id) = event
608            .get("session_id")
609            .or_else(|| event.get("sessionId"))
610            .or_else(|| event.get("sessionID"))
611            .and_then(Value::as_str)
612        {
613            return Some(id.to_string());
614        }
615    }
616    None
617}
618
619fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
620    let mut usage = None;
621    for event in events {
622        let mut candidates = Vec::new();
623        find_usage_objects(event, &mut candidates);
624        for candidate in candidates {
625            usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
626        }
627    }
628    usage
629}