1use super::common::*;
2use super::types::*;
3use crate::environment::EnvironmentPath;
4use anyhow::bail;
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
/// Configuration for [`PiAgentProvider`]: how the `pi` process is launched.
#[derive(Debug, Clone, Default)]
pub struct PiAgentProviderOptions {
    /// Executable to run; `run_pi` uses `"pi"` when this is `None`.
    pub command: Option<String>,
    /// Leading arguments, placed on the command line before `args`.
    pub subcommand: Vec<String>,
    /// Extra arguments placed after `subcommand` and before the flags this provider adds.
    pub args: Vec<String>,
    /// Working directory used only when the run input's context does not set one.
    pub cwd: Option<PathBuf>,
    /// Environment variables handed to `run_command` (whether they extend or replace the
    /// inherited environment is decided there — TODO confirm).
    pub env: HashMap<String, String>,
    /// Timeout in milliseconds forwarded to `run_command`; presumably `None` means no
    /// timeout — TODO confirm.
    pub timeout_ms: Option<u64>,
}
18
19#[derive(Debug, Clone, Default)]
20pub struct PiAgentProvider {
21 options: PiAgentProviderOptions,
22}
23impl PiAgentProvider {
24 pub fn new(options: PiAgentProviderOptions) -> Self {
25 Self { options }
26 }
27}
28
29#[async_trait::async_trait]
30impl AgentProvider for PiAgentProvider {
31 fn name(&self) -> &str {
32 "pi"
33 }
34 fn schema_mode(&self) -> AgentProviderSchemaMode {
35 AgentProviderSchemaMode::Builtin
36 }
37 fn usage_mode(&self) -> AgentProviderUsageMode {
38 AgentProviderUsageMode::Builtin
39 }
40 async fn run(&self, input: AgentProviderRunInput) -> anyhow::Result<AgentProviderResult> {
41 run_pi(input, &self.options).await
42 }
43}
44
45async fn run_pi(
46 input: AgentProviderRunInput,
47 options: &PiAgentProviderOptions,
48) -> anyhow::Result<AgentProviderResult> {
49 let command = options.command.as_deref().unwrap_or("pi");
50 let has_schema = option_schema(&input.options).is_some();
51 let prompt = if has_schema {
52 with_structured_output_tool_instruction(&input.prompt)
53 } else {
54 input.prompt.clone()
55 };
56 let temp = input.environment.create_temp_dir("smol-wf-pi-").await?;
57 let extension_path =
58 has_schema.then(|| join_environment_path(&temp, "structured-output-extension.ts"));
59 let prompt_path = join_environment_path(&temp, "prompt.md");
60
61 if let Some(path) = &extension_path {
62 input
63 .environment
64 .write_file(
65 path,
66 build_structured_output_extension(option_schema(&input.options).unwrap())
67 .as_bytes(),
68 )
69 .await?;
70 }
71 input
72 .environment
73 .write_file(&prompt_path, prompt.as_bytes())
74 .await?;
75
76 let prompt_arg = format!("@{}", prompt_path.0);
77
78 let mut args = Vec::new();
79 args.extend(options.subcommand.clone());
80 args.extend(options.args.clone());
81 if let Some(path) = &extension_path {
82 args.extend(["--extension".into(), path.0.clone()]);
83 }
84 args.extend(["--print".into(), "--mode".into(), "json".into()]);
85 if let Some(model) = option_str(&input.options, "model") {
86 args.extend(["--model".into(), model]);
87 }
88 if let Some(thinking) = option_str(&input.options, "thinking") {
89 args.extend(["--thinking".into(), thinking]);
90 }
91 args.push(prompt_arg);
92
93 let cwd = input.context.cwd.as_deref().or(options.cwd.as_deref());
94 let (stdout, stderr) = run_command(RunCommandRequest {
95 provider: "Pi",
96 command,
97 args: &args,
98 stdin: None,
99 cwd,
100 env: &options.env,
101 timeout_ms: options.timeout_ms,
102 environment: input.environment.as_ref(),
103 })
104 .await?;
105 let events = parse_json_lines(&stdout);
106 let output = if has_schema {
107 extract_structured_tool_output(&events)?
108 } else {
109 let candidate = extract_output(&events)
110 .or_else(|| extract_last_tool_result_text(&events))
111 .ok_or_else(|| {
112 let message = extract_error_message(&events)
113 .or_else(|| (!stderr.trim().is_empty()).then(|| stderr.trim().to_string()))
114 .unwrap_or_else(|| {
115 let event_types = events
116 .iter()
117 .filter_map(|event| event.get("type").and_then(Value::as_str))
118 .collect::<Vec<_>>()
119 .join(",");
120 let last_event = events
121 .last()
122 .map(|event| truncate(&event.to_string(), 1000))
123 .unwrap_or_else(|| "<none>".to_string());
124 format!(
125 "Pi provider did not return assistant output; event_types=[{event_types}] last_event={last_event}"
126 )
127 });
128 anyhow::anyhow!(message)
129 })?;
130 Value::String(candidate.trim_end().to_string())
131 };
132 let session_id = extract_session_id(&events)
133 .ok_or_else(|| anyhow::anyhow!("Pi provider response did not include a session id"))?;
134
135 Ok(AgentProviderResult {
136 output,
137 session_id: Some(session_id),
138 model: extract_model(&Value::Array(events.clone()))
139 .or_else(|| option_model(&input.options)),
140 usage: extract_usage(&events),
141 isolation: None,
142 raw: Some(to_json_value(
143 json!({ "events": events, "stderr": stderr, "extensionPath": extension_path.map(|p| p.0) }),
144 )),
145 })
146}
147
148fn join_environment_path(base: &EnvironmentPath, child: &str) -> EnvironmentPath {
149 EnvironmentPath(format!("{}/{}", base.as_str().trim_end_matches('/'), child))
150}
151
/// Appends, after a blank line, the two-line instruction telling the model to finish by
/// calling the structured-output tool exactly once.
fn with_structured_output_tool_instruction(prompt: &str) -> String {
    const USE_TOOL: &str =
        "Use the smol_workflows_structured_output tool as your final action exactly once.";
    const NO_TRAILING_MESSAGE: &str =
        "Do not emit a final assistant message after calling smol_workflows_structured_output.";
    format!("{prompt}\n\n{USE_TOOL}\n{NO_TRAILING_MESSAGE}")
}
161
162fn build_structured_output_extension(schema: &Value) -> String {
163 let wrapped = !schema
164 .as_object()
165 .is_some_and(|o| o.get("type") == Some(&Value::String("object".into())));
166 let parameters = if wrapped {
167 format!(
168 "Type.Object({{ value: {} }})",
169 json_schema_to_typebox_expression(schema)
170 )
171 } else {
172 json_schema_to_typebox_expression(schema)
173 };
174 let details = if wrapped { "params.value" } else { "params" };
175 format!(
176 r#"import {{ defineTool, type ExtensionAPI }} from "@earendil-works/pi-coding-agent";
177import {{ Type }} from "typebox";
178
179const structuredOutputTool = defineTool({{
180 name: "smol_workflows_structured_output",
181 label: "Structured Output",
182 description: "Submit the final structured response for this agent call.",
183 promptSnippet: "Submit the final structured response with the smol_workflows_structured_output tool.",
184 promptGuidelines: [
185 "Use smol_workflows_structured_output as your final action exactly once.",
186 "The tool parameters are generated from the caller's JSON Schema.",
187 "After calling smol_workflows_structured_output, do not emit another assistant response in the same turn.",
188 ],
189 parameters: {parameters},
190 async execute(_toolCallId, params) {{
191 return {{
192 content: [{{ type: "text", text: "Structured output captured successfully." }}],
193 details: {details},
194 terminate: true,
195 }};
196 }},
197}});
198
199export default function (pi: ExtensionAPI) {{
200 pi.registerTool(structuredOutputTool);
201}}
202"#
203 )
204}
205
206fn json_schema_to_typebox_expression(schema: &Value) -> String {
207 match schema {
208 Value::Bool(true) => "Type.Any()".into(),
209 Value::Bool(false) => "Type.Never()".into(),
210 Value::Object(record) => {
211 if let Some(value) = record.get("const") {
212 return format!("Type.Literal({})", serde_json::to_string(value).unwrap());
213 }
214 if let Some(values) = record.get("enum").and_then(Value::as_array) {
215 if !values.is_empty() {
216 return if values.len() == 1 {
217 format!(
218 "Type.Literal({})",
219 serde_json::to_string(&values[0]).unwrap()
220 )
221 } else {
222 format!(
223 "Type.Union([{}])",
224 values
225 .iter()
226 .map(|v| format!(
227 "Type.Literal({})",
228 serde_json::to_string(v).unwrap()
229 ))
230 .collect::<Vec<_>>()
231 .join(", ")
232 )
233 };
234 }
235 }
236 for key in ["oneOf", "anyOf"] {
237 if let Some(values) = record.get(key).and_then(Value::as_array) {
238 if !values.is_empty() {
239 return format!(
240 "Type.Union([{}])",
241 values
242 .iter()
243 .map(json_schema_to_typebox_expression)
244 .collect::<Vec<_>>()
245 .join(", ")
246 );
247 }
248 }
249 }
250 match first_schema_type(record.get("type")).or_else(|| infer_schema_type(record)) {
251 Some("null") => "Type.Null()".into(),
252 Some("boolean") => format!("Type.Boolean({})", typebox_options(record)),
253 Some("integer") => format!("Type.Integer({})", typebox_options(record)),
254 Some("number") => format!("Type.Number({})", typebox_options(record)),
255 Some("string") => format!("Type.String({})", typebox_options(record)),
256 Some("array") => array_schema_to_typebox_expression(record),
257 Some("object") => object_schema_to_typebox_expression(record),
258 _ => "Type.Any()".into(),
259 }
260 }
261 _ => "Type.Any()".into(),
262 }
263}
264
265fn object_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
266 let properties = schema
267 .get("properties")
268 .and_then(Value::as_object)
269 .cloned()
270 .unwrap_or_default();
271 let required = schema
272 .get("required")
273 .and_then(Value::as_array)
274 .map(|items| items.iter().filter_map(Value::as_str).collect::<Vec<_>>())
275 .unwrap_or_default();
276 let entries = properties
277 .iter()
278 .map(|(key, value)| {
279 let expression = json_schema_to_typebox_expression(value);
280 if required.iter().any(|required| required == key) {
281 format!("{}: {}", serde_json::to_string(key).unwrap(), expression)
282 } else {
283 format!(
284 "{}: Type.Optional({})",
285 serde_json::to_string(key).unwrap(),
286 expression
287 )
288 }
289 })
290 .collect::<Vec<_>>()
291 .join(", ");
292 format!("Type.Object({{ {entries} }}, {})", typebox_options(schema))
293}
294
295fn array_schema_to_typebox_expression(schema: &serde_json::Map<String, Value>) -> String {
296 let options = typebox_options(schema);
297 if let Some(prefix_items) = schema.get("prefixItems").and_then(Value::as_array) {
298 if !prefix_items.is_empty() {
299 return format!(
300 "Type.Tuple([{}], {options})",
301 prefix_items
302 .iter()
303 .map(json_schema_to_typebox_expression)
304 .collect::<Vec<_>>()
305 .join(", ")
306 );
307 }
308 }
309 let item_schema = schema
310 .get("items")
311 .filter(|value| !value.is_array())
312 .unwrap_or(&Value::Bool(true));
313 format!(
314 "Type.Array({}, {options})",
315 json_schema_to_typebox_expression(item_schema)
316 )
317}
318
319fn typebox_options(schema: &serde_json::Map<String, Value>) -> String {
320 let option_keys = [
321 "title",
322 "description",
323 "default",
324 "examples",
325 "minimum",
326 "maximum",
327 "exclusiveMinimum",
328 "exclusiveMaximum",
329 "multipleOf",
330 "minLength",
331 "maxLength",
332 "pattern",
333 "format",
334 "minItems",
335 "maxItems",
336 "uniqueItems",
337 "additionalProperties",
338 ];
339 let mut options = serde_json::Map::new();
340 for key in option_keys {
341 if let Some(value) = schema.get(key) {
342 options.insert(key.to_string(), value.clone());
343 }
344 }
345 serde_json::to_string(&Value::Object(options)).unwrap()
346}
347
348fn first_schema_type(value: Option<&Value>) -> Option<&str> {
349 match value {
350 Some(Value::String(value)) => Some(value),
351 Some(Value::Array(values)) => values.iter().find_map(Value::as_str),
352 _ => None,
353 }
354}
355
356fn infer_schema_type(schema: &serde_json::Map<String, Value>) -> Option<&'static str> {
357 if schema.contains_key("properties")
358 || schema.contains_key("required")
359 || schema.contains_key("additionalProperties")
360 {
361 Some("object")
362 } else if schema.contains_key("items") || schema.contains_key("prefixItems") {
363 Some("array")
364 } else if schema.contains_key("minimum")
365 || schema.contains_key("maximum")
366 || schema.contains_key("multipleOf")
367 {
368 Some("number")
369 } else if schema.contains_key("minLength")
370 || schema.contains_key("maxLength")
371 || schema.contains_key("pattern")
372 || schema.contains_key("format")
373 {
374 Some("string")
375 } else {
376 None
377 }
378}
379
380fn extract_structured_tool_output(events: &[Value]) -> anyhow::Result<Value> {
381 let mut output = None;
382 let mut recovered_output = None;
383 let mut started_args = HashMap::<String, Value>::new();
384 let mut calls = 0;
385 let mut successes = 0;
386 let mut errors = 0;
387
388 for event in events {
389 let Some(record) = event.as_object() else {
390 continue;
391 };
392 if record.get("toolName").and_then(Value::as_str)
393 != Some("smol_workflows_structured_output")
394 {
395 continue;
396 }
397
398 if record.get("type").and_then(Value::as_str) == Some("tool_execution_start") {
399 if let (Some(tool_call_id), Some(args)) = (
400 record.get("toolCallId").and_then(Value::as_str),
401 record.get("args").or_else(|| record.get("parameters")),
402 ) {
403 started_args.insert(tool_call_id.to_string(), args.clone());
404 }
405 continue;
406 }
407
408 if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
409 continue;
410 }
411
412 calls += 1;
413 if record.get("isError").and_then(Value::as_bool) == Some(true) {
414 errors += 1;
415 if recovered_output.is_none() {
416 recovered_output = recover_structured_tool_arguments(event, &started_args);
417 }
418 continue;
419 }
420
421 if let Some(details) = get_path(event, &["result", "details"]) {
422 successes += 1;
423 output = Some(details.clone());
424 }
425 }
426
427 if let Some(output) = output {
428 if errors > 0 {
429 log::debug!(
430 "Pi structured-output tool had {errors} failed attempt(s) before a successful output"
431 );
432 }
433 if successes > 1 {
434 log::debug!("Pi structured-output tool returned {successes} successful outputs; using the last one");
435 }
436 return Ok(output);
437 }
438
439 if let Some(output) = recovered_output {
440 log::debug!(
441 "Pi structured-output tool failed, but attempted tool arguments were recovered from events"
442 );
443 return Ok(output);
444 }
445
446 if calls == 0 {
447 bail!("Pi provider did not call smol_workflows_structured_output for schema output");
448 }
449 if errors > 0 {
450 bail!("Pi smol_workflows_structured_output tool failed");
451 }
452 bail!("Pi smol_workflows_structured_output tool did not return details")
453}
454
455fn recover_structured_tool_arguments(
456 event: &Value,
457 started_args: &HashMap<String, Value>,
458) -> Option<Value> {
459 for path in [
460 &["result", "details"][..],
461 &["result", "input"],
462 &["state", "input"],
463 &["input"],
464 &["args"],
465 &["parameters"],
466 ] {
467 if let Some(value) = get_path(event, path) {
468 return Some(value.clone());
469 }
470 }
471
472 event
473 .get("toolCallId")
474 .and_then(Value::as_str)
475 .and_then(|tool_call_id| started_args.get(tool_call_id))
476 .cloned()
477}
478
479fn extract_output(events: &[Value]) -> Option<String> {
480 let mut output = None;
481 for event in events {
482 if let Some(value) = extract_output_from_event(event) {
483 output = Some(value);
484 }
485 }
486 output
487}
488
489fn extract_output_from_event(event: &Value) -> Option<String> {
490 let record = event.as_object()?;
491 match record.get("type").and_then(Value::as_str) {
492 Some("message_end" | "turn_end") => record
493 .get("message")
494 .and_then(extract_assistant_message_text),
495 Some("agent_end") => {
496 record
497 .get("messages")
498 .and_then(Value::as_array)
499 .and_then(|messages| {
500 messages
501 .iter()
502 .rev()
503 .find_map(extract_assistant_message_text)
504 })
505 }
506 Some("message_update") => record
507 .get("message")
508 .and_then(extract_assistant_message_text)
509 .or_else(|| {
510 record
511 .get("assistantMessageEvent")
512 .and_then(extract_assistant_message_event_text)
513 }),
514 _ => None,
515 }
516}
517
518fn extract_assistant_message_event_text(event: &Value) -> Option<String> {
519 let record = event.as_object()?;
520 match record.get("type").and_then(Value::as_str) {
521 Some("text_end") => record
522 .get("content")
523 .and_then(Value::as_str)
524 .map(ToString::to_string),
525 Some("text_delta") => record
526 .get("partial")
527 .and_then(extract_assistant_message_text),
528 _ => record
529 .get("partial")
530 .and_then(extract_assistant_message_text),
531 }
532}
533
534fn extract_assistant_message_text(message: &Value) -> Option<String> {
535 let record = message.as_object()?;
536 if record.get("role").is_some()
537 && record.get("role").and_then(Value::as_str) != Some("assistant")
538 {
539 return None;
540 }
541 record.get("content").and_then(extract_text)
542}
543
544fn extract_text(value: &Value) -> Option<String> {
545 match value {
546 Value::String(text) => Some(text.clone()),
547 Value::Array(items) => {
548 let text = items
549 .iter()
550 .map(|item| extract_text(item).unwrap_or_default())
551 .collect::<Vec<_>>()
552 .join("");
553 (!text.is_empty()).then_some(text)
554 }
555 Value::Object(record) => record
556 .get("text")
557 .or_else(|| record.get("content"))
558 .or_else(|| record.get("message"))
559 .and_then(extract_text),
560 _ => None,
561 }
562}
563
564fn extract_last_tool_result_text(events: &[Value]) -> Option<String> {
565 events.iter().rev().find_map(|event| {
566 let record = event.as_object()?;
567 if record.get("type").and_then(Value::as_str) != Some("tool_execution_end") {
568 return None;
569 }
570 record
571 .get("result")
572 .and_then(|result| {
573 result
574 .get("content")
575 .or_else(|| result.get("message"))
576 .or_else(|| result.get("text"))
577 .and_then(extract_text)
578 })
579 .or_else(|| record.get("message").and_then(extract_text))
580 })
581}
582
583fn extract_error_message(events: &[Value]) -> Option<String> {
584 events.iter().find_map(find_error_message)
585}
586
587fn find_error_message(value: &Value) -> Option<String> {
588 match value {
589 Value::Array(items) => items.iter().find_map(find_error_message),
590 Value::Object(record) => {
591 if let Some(message) = record.get("errorMessage").and_then(Value::as_str) {
592 return Some(message.to_string());
593 }
594 record.values().find_map(find_error_message)
595 }
596 _ => None,
597 }
598}
599
600fn extract_session_id(events: &[Value]) -> Option<String> {
601 for event in events {
602 if event.get("type").and_then(Value::as_str) == Some("session") {
603 if let Some(id) = event.get("id").and_then(Value::as_str) {
604 return Some(id.to_string());
605 }
606 }
607 if let Some(id) = event
608 .get("session_id")
609 .or_else(|| event.get("sessionId"))
610 .or_else(|| event.get("sessionID"))
611 .and_then(Value::as_str)
612 {
613 return Some(id.to_string());
614 }
615 }
616 None
617}
618
619fn extract_usage(events: &[Value]) -> Option<AgentUsage> {
620 let mut usage = None;
621 for event in events {
622 let mut candidates = Vec::new();
623 find_usage_objects(event, &mut candidates);
624 for candidate in candidates {
625 usage = Some(merge_usage_right(usage, normalize_usage(&candidate)));
626 }
627 }
628 usage
629}