1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5
/// Cooperative cancellation signal.
///
/// The flag lives behind an `Arc`, so every clone of a token (and every
/// `Arc` handed out by [`CancelToken::flag`]) observes the same state:
/// cancelling through any handle is visible through all of them.
#[derive(Clone, Debug, Default)]
pub struct CancelToken {
    flag: Arc<AtomicBool>,
}

impl CancelToken {
    /// Creates a token that has not been cancelled.
    pub fn new() -> Self {
        Self {
            flag: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Wraps an existing shared flag; whoever else holds that `Arc` can
    /// trigger or observe cancellation of this token.
    pub fn from_flag(flag: Arc<AtomicBool>) -> Self {
        Self { flag }
    }

    /// Returns another handle to the underlying shared flag.
    pub fn flag(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.flag)
    }

    /// Reports whether cancellation has been requested.
    ///
    /// Uses `Relaxed` ordering: only the flag value itself is communicated;
    /// no other memory accesses are ordered by this load.
    pub fn is_cancelled(&self) -> bool {
        self.flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation (sets the shared flag to `true`).
    pub fn cancel(&self) {
        self.flag.store(true, Ordering::Relaxed);
    }
}
54
/// A plain-text content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextContent {
    /// The text itself.
    pub text: String,
    /// Optional opaque signature for this text; omitted from JSON when `None`.
    // NOTE(review): presumably a provider-issued signature to echo back on later requests — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text_signature: Option<String>,
}
65
/// A reasoning ("thinking") content part of an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThinkingContent {
    /// The thinking text.
    pub thinking: String,
    /// Optional opaque signature for this block; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_signature: Option<String>,
    /// Marks the block as redacted. Omitted from JSON when `false` and
    /// defaults to `false` when absent, so older payloads still deserialize.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub redacted: bool,
}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ImageContent {
77 pub data: String, pub mime_type: String,
79}
80
/// A tool invocation requested by the assistant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier for this call. `ToolResultMessage::tool_call_id` carries a
    /// matching id for its result.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
87
/// A content part of a user message.
///
/// Internally tagged: a `"type"` field holds `"text"` or `"image"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    Text(TextContent),
    Image(ImageContent),
}
94
/// A content part of an assistant message.
///
/// Internally tagged: a `"type"` field holds `"text"`, `"thinking"` or `"tool_call"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AssistantContent {
    Text(TextContent),
    Thinking(ThinkingContent),
    ToolCall(ToolCall),
}
102
/// A content part of a tool result.
///
/// Internally tagged: a `"type"` field holds `"text"` or `"image"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    Text(TextContent),
    Image(ImageContent),
}
109
110impl ToolResultContent {
111 pub fn text(&self) -> &str {
113 match self {
114 Self::Text(t) => &t.text,
115 Self::Image(_) => "",
116 }
117 }
118}
119
/// Monetary cost of a response, broken down by token category.
///
/// Populated by `Model::calculate_cost`. Each field is the model's
/// per-million-token price times the matching `Usage` count. `total` is the
/// sum of the four categories. The currency is whatever `ModelCost` is quoted in.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Cost {
    pub input: f64,
    pub output: f64,
    pub cache_read: f64,
    pub cache_write: f64,
    pub total: f64,
}
132
/// Token counts for a response, plus the derived cost.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Input tokens.
    pub input: u64,
    /// Output tokens.
    pub output: u64,
    /// Tokens read from the prompt cache.
    pub cache_read: u64,
    /// Tokens written to the prompt cache.
    pub cache_write: u64,
    /// Sum of the four counts above. It is not kept in sync automatically;
    /// call `recompute_total` after changing them.
    pub total_tokens: u64,
    /// Cost derived from these counts (see `Model::calculate_cost`).
    pub cost: Cost,
}
142
143impl Usage {
144 pub fn recompute_total(&mut self) {
146 self.total_tokens = self.input + self.output + self.cache_read + self.cache_write;
147 }
148}
149
/// Why an assistant response ended.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"tool_use"`, `"error"`, `"aborted"`.
// NOTE(review): meanings below are inferred from the variant names — confirm with the providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Normal completion.
    Stop,
    /// Presumably cut off by an output-length limit.
    Length,
    /// The response ended to let tools run.
    ToolUse,
    /// The request failed.
    Error,
    /// The request was cancelled.
    Aborted,
}
163
/// A message authored by the user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMessage {
    /// Ordered content parts (text and/or images).
    pub content: Vec<UserContent>,
    /// Creation time in milliseconds since the Unix epoch (see `timestamp_ms`).
    pub timestamp: u64,
}
173
174impl UserMessage {
175 pub fn text(text: impl Into<String>) -> Self {
176 Self {
177 content: vec![UserContent::Text(TextContent {
178 text: text.into(),
179 text_signature: None,
180 })],
181 timestamp: timestamp_ms(),
182 }
183 }
184}
185
/// A (possibly partial) response produced by a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantMessage {
    /// Ordered content parts: text, thinking and tool calls.
    pub content: Vec<AssistantContent>,
    /// API identifier the response came from (compare `Model::api`).
    pub api: String,
    /// Provider identifier (compare `Model::provider`).
    pub provider: String,
    /// Model identifier.
    pub model: String,
    /// Provider-assigned response id, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_id: Option<String>,
    /// Token usage and cost for this response.
    pub usage: Usage,
    /// Why the response ended.
    pub stop_reason: StopReason,
    /// Error description; omitted from JSON when `None`.
    // NOTE(review): presumably set alongside `StopReason::Error` — confirm with producers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_message: Option<String>,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
200
201impl AssistantMessage {
202 pub fn empty(api: &str, provider: &str, model: &str) -> Self {
203 Self {
204 content: Vec::new(),
205 api: api.to_string(),
206 provider: provider.to_string(),
207 model: model.to_string(),
208 response_id: None,
209 usage: Usage::default(),
210 stop_reason: StopReason::Stop,
211 error_message: None,
212 timestamp: timestamp_ms(),
213 }
214 }
215
216 pub fn text(&self) -> String {
218 self.content
219 .iter()
220 .filter_map(|c| match c {
221 AssistantContent::Text(t) => Some(t.text.as_str()),
222 _ => None,
223 })
224 .collect::<Vec<_>>()
225 .join("")
226 }
227}
228
/// The outcome of executing one tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the `ToolCall` this result answers.
    pub tool_call_id: String,
    /// Name of the tool that ran.
    pub tool_name: String,
    /// Result content parts (text and/or images).
    pub content: Vec<ToolResultContent>,
    /// Optional structured extra data; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
    /// Whether the tool reported a failure.
    pub is_error: bool,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Execution time in milliseconds. Omitted when `None`; defaults to `None`
    /// when absent so older payloads still deserialize.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
    /// Short one-line description of the result. Omitted when `None`; defaults
    /// to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary: Option<String>,
    /// In-memory only: never written to JSON, and always empty after
    /// deserialization.
    // NOTE(review): presumably drained by whatever persists this message — confirm.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub post_persist_actions: Vec<PostPersistAction>,
}
253
/// A follow-up action attached to a `ToolResultMessage`
/// (`post_persist_actions`).
///
/// Internally tagged with `"type"`: `"emit_info_message"` or `"stop_agent_loop"`.
// NOTE(review): semantics inferred from names — presumably executed after the
// tool result is persisted; confirm with the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PostPersistAction {
    /// Emit an info message with `text` into the session `target_session_id`.
    EmitInfoMessage {
        target_session_id: String,
        text: String,
    },
    /// Stop the agent loop, with a human-readable `reason`.
    StopAgentLoop {
        reason: String,
    },
}
280
/// An action to run once the agent becomes idle.
///
/// Internally tagged with `"type"`; the only variant is `"archive_task_sessions"`.
// NOTE(review): not referenced in this file; semantics inferred from names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PostIdleAction {
    /// Archive the sessions belonging to task `task_id`.
    ArchiveTaskSessions { task_id: i64 },
}
292
293impl ToolResultMessage {
294 pub fn success(
295 id: impl Into<String>,
296 name: impl Into<String>,
297 text: impl Into<String>,
298 ) -> Self {
299 Self {
300 tool_call_id: id.into(),
301 tool_name: name.into(),
302 content: vec![ToolResultContent::Text(TextContent {
303 text: text.into(),
304 text_signature: None,
305 })],
306 details: None,
307 is_error: false,
308 timestamp: timestamp_ms(),
309 duration_ms: None,
310 summary: None,
311 post_persist_actions: Vec::new(),
312 }
313 }
314
315 pub fn error(id: impl Into<String>, name: impl Into<String>, text: impl Into<String>) -> Self {
316 Self {
317 tool_call_id: id.into(),
318 tool_name: name.into(),
319 content: vec![ToolResultContent::Text(TextContent {
320 text: text.into(),
321 text_signature: None,
322 })],
323 details: None,
324 is_error: true,
325 timestamp: timestamp_ms(),
326 duration_ms: None,
327 summary: None,
328 post_persist_actions: Vec::new(),
329 }
330 }
331}
332
/// A summary that stands in for compacted conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionSummaryMessage {
    /// The summary text.
    pub summary: String,
    /// Token count before compaction.
    // NOTE(review): presumably the context size that was compacted away — confirm.
    pub tokens_before: u64,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
340
/// An informational note in the message history (serialized with role `"info"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InfoMessage {
    /// The note text.
    pub text: String,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
}
346
347impl InfoMessage {
348 pub fn new(text: impl Into<String>) -> Self {
349 Self {
350 text: text.into(),
351 timestamp: timestamp_ms(),
352 }
353 }
354}
355
/// Any entry in a conversation history.
///
/// Internally tagged by a `"role"` field: `"user"`, `"assistant"`,
/// `"tool_result"`, `"compaction_summary"` or `"info"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
    User(UserMessage),
    Assistant(AssistantMessage),
    ToolResult(ToolResultMessage),
    CompactionSummary(CompactionSummaryMessage),
    Info(InfoMessage),
}
365
366#[derive(Debug, Clone, Serialize, Deserialize)]
371pub struct ModelCost {
372 pub input: f64, pub output: f64, pub cache_read: f64,
375 pub cache_write: f64,
376}
377
378impl Default for ModelCost {
379 fn default() -> Self {
380 Self {
381 input: 0.0,
382 output: 0.0,
383 cache_read: 0.0,
384 cache_write: 0.0,
385 }
386 }
387}
388
/// Which request/response convention a model uses for reasoning ("thinking").
///
/// Serialized in snake_case: `"none"`, `"anthropic"`, `"open_ai"`, `"qwen"`.
/// Defaults to `None`.
// NOTE(review): per-variant meanings inferred from names — confirm against provider adapters.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingStyle {
    /// No thinking support.
    #[default]
    None,
    Anthropic,
    /// Serialized as `"open_ai"`; `"openai"` is also accepted when deserializing.
    #[serde(alias = "openai")]
    OpenAi,
    Qwen,
}
404
/// Description of a model endpoint and its pricing and limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    /// Model identifier.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// API identifier (compare `AssistantMessage::api`).
    pub api: String,
    /// Provider identifier.
    pub provider: String,
    /// Base URL requests are sent to.
    pub base_url: String,
    /// Thinking convention; defaults to `ThinkingStyle::None` when absent.
    #[serde(default)]
    pub thinking: ThinkingStyle,
    /// Prices per million tokens (see `calculate_cost`).
    pub cost: ModelCost,
    /// Context window size.
    // NOTE(review): presumably measured in tokens — confirm.
    pub context_window: u64,
    /// Maximum output size.
    // NOTE(review): presumably output tokens per request — confirm.
    pub max_tokens: u64,
    /// Extra HTTP-style headers for this model. Omitted when empty; defaults to
    /// empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
}
420
421impl Model {
422 pub fn calculate_cost(&self, usage: &mut Usage) {
423 usage.cost.input = (self.cost.input / 1_000_000.0) * usage.input as f64;
424 usage.cost.output = (self.cost.output / 1_000_000.0) * usage.output as f64;
425 usage.cost.cache_read = (self.cost.cache_read / 1_000_000.0) * usage.cache_read as f64;
426 usage.cost.cache_write = (self.cost.cache_write / 1_000_000.0) * usage.cache_write as f64;
427 usage.cost.total =
428 usage.cost.input + usage.cost.output + usage.cost.cache_read + usage.cost.cache_write;
429 }
430}
431
/// A tool definition advertised to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool name; `ToolCall::name` refers to it.
    pub name: String,
    /// Description shown to the model.
    pub description: String,
    /// Parameter specification as a JSON value.
    // NOTE(review): presumably a JSON Schema object — confirm with provider adapters.
    pub parameters: serde_json::Value,
}
443
/// Everything sent to a model for one request: system prompt, history and tools.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Conversation history, oldest first.
    // NOTE(review): ordering assumed chronological — confirm with callers.
    pub messages: Vec<Message>,
    /// Available tools. Omitted when empty; defaults to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
}
456
/// Requested reasoning effort level.
///
/// Serialized in snake_case: `"low"`, `"medium"`, `"high"`, `"x_high"`, `"max"`.
/// Note that serde's snake_case rule turns `XHigh` into `"x_high"`, not `"xhigh"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingEffort {
    Low,
    Medium,
    High,
    XHigh,
    Max,
}
476
/// How thinking content should be returned.
///
/// Serialized in snake_case: `"summarized"` or `"omitted"`.
// NOTE(review): meanings inferred from names — confirm with the provider adapter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingDisplay {
    /// Thinking is returned in summarized form.
    Summarized,
    /// Thinking is not returned.
    Omitted,
}
491
492#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
499#[serde(rename_all = "lowercase")]
500pub enum CacheRetention {
501 None,
502 Short,
503 Long,
504}
505
506impl Default for CacheRetention {
507 fn default() -> Self {
508 Self::Short
509 }
510}
511
512impl CacheRetention {
513 pub fn resolve(opt: Option<Self>) -> Self {
519 opt.unwrap_or_default()
520 }
521
522 pub fn resolve_with_env(opt: Option<Self>) -> Self {
528 if let Some(v) = opt {
529 return v;
530 }
531 match std::env::var("PI_CACHE_RETENTION").ok().as_deref() {
532 Some("long") => Self::Long,
533 _ => Self::Short,
534 }
535 }
536}
537
/// Per-request options for a streaming call. Every field is optional.
///
/// `Option` fields are omitted from JSON when `None`. `headers` is omitted
/// when empty and defaults to empty when absent.
// NOTE(review): `api_key` appears in the derived `Debug` output and in serialized
// JSON when set — make sure values of this type are not logged or persisted as-is.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Sampling temperature override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Output-size limit override (compare `Model::max_tokens`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// API key override (secret; see note above).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Extra headers for this request.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// Thinking budget.
    // NOTE(review): presumably a token budget — confirm with provider adapters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u64>,
    /// Explicitly enable or disable thinking; `None` leaves it to the provider adapter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_enabled: Option<bool>,
    /// Requested reasoning effort.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_effort: Option<ThinkingEffort>,
    /// Requested thinking display mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_display: Option<ThinkingDisplay>,
    /// Session identifier to associate with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Cache retention preference.
    // NOTE(review): presumably resolved via `CacheRetention::resolve`/`resolve_with_env` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_retention: Option<CacheRetention>,
}
580
/// Coarse activity phase of the agent, for status display (see `label`).
///
/// This enum has no `rename_all`, so variants serialize by their Rust names
/// (`"Idle"`, `"ToolExec"`, ...), unlike the snake_case enums elsewhere in this
/// module. Changing that would change the wire format. Defaults to `Idle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum AgentPhase {
    /// Nothing in progress.
    #[default]
    Idle,
    Waiting,
    Preparing,
    /// Request being sent.
    Connecting,
    Thinking,
    Responding,
    /// Tools running.
    ToolExec,
    Compacting,
    RateLimited,
}
610
611impl AgentPhase {
612 pub fn label(&self) -> &'static str {
614 match self {
615 Self::Idle => "idle",
616 Self::Waiting => "waiting...",
617 Self::Preparing => "preparing...",
618 Self::Connecting => "sending request...",
619 Self::Thinking => "thinking...",
620 Self::Responding => "working...",
621 Self::ToolExec => "running tools...",
622 Self::Compacting => "compacting...",
623 Self::RateLimited => "rate limited...",
624 }
625 }
626}
627
/// Events describing an assistant response as it streams, plus tool-output,
/// tool-result, steering, phase and status notifications.
///
/// Internally tagged: a `"type"` field holds the snake_case variant name
/// (e.g. `"text_delta"`, `"toolcall_end"`, `"tool_output_delta"`).
/// Events that carry `partial` include the assistant message accumulated so far.
// NOTE(review): per-variant semantics below are inferred from names — confirm with the producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// The response started.
    Start {
        partial: AssistantMessage,
    },
    /// A text part began at `content_index`.
    TextStart {
        content_index: usize,
        partial: AssistantMessage,
    },
    /// New text (`delta`) for the text part at `content_index`.
    TextDelta {
        content_index: usize,
        delta: String,
        partial: AssistantMessage,
    },
    /// The text part at `content_index` finished; `content` is its full text.
    TextEnd {
        content_index: usize,
        content: String,
        partial: AssistantMessage,
    },
    /// A thinking part began at `content_index`.
    ThinkingStart {
        content_index: usize,
        partial: AssistantMessage,
    },
    /// New thinking text (`delta`) for the part at `content_index`.
    ThinkingDelta {
        content_index: usize,
        delta: String,
        partial: AssistantMessage,
    },
    /// The thinking part at `content_index` finished; `content` is its full text.
    ThinkingEnd {
        content_index: usize,
        content: String,
        partial: AssistantMessage,
    },
    /// A tool-call part began at `content_index`.
    ToolcallStart {
        content_index: usize,
        partial: AssistantMessage,
    },
    /// A fragment (`delta`) of the tool call at `content_index`.
    // NOTE(review): presumably raw argument JSON text — confirm.
    ToolcallDelta {
        content_index: usize,
        delta: String,
        partial: AssistantMessage,
    },
    /// The tool call at `content_index` is complete.
    ToolcallEnd {
        content_index: usize,
        tool_call: ToolCall,
        partial: AssistantMessage,
    },
    /// Incremental output from a running tool.
    ToolOutputDelta {
        tool_call_id: String,
        delta: String,
    },
    /// A tool finished. Here `content` is a flattened string, not parts.
    ToolResult {
        tool_call_id: String,
        tool_name: String,
        is_error: bool,
        content: String,
        /// Optional one-line summary. With no serde attributes it serializes as
        /// `null` when `None`.
        summary: Option<String>,
    },
    /// The response completed normally.
    Done {
        reason: StopReason,
        message: AssistantMessage,
    },
    /// The response ended with an error or abort; `error` is the final message.
    Error {
        reason: StopReason,
        error: AssistantMessage,
    },
    /// A user message injected while the agent is working.
    SteerMessage {
        message: UserMessage,
    },
    /// The agent changed phase.
    Phase {
        phase: AgentPhase,
        /// When the current turn started. Omitted when `None`; defaults to
        /// `None` when absent.
        // NOTE(review): presumably epoch milliseconds as produced by `timestamp_ms` — confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        turn_started_at_ms: Option<u64>,
        /// When this phase started. Omitted when `None`; defaults to `None`
        /// when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        phase_started_at_ms: Option<u64>,
    },
    /// A free-form status message.
    Status {
        message: String,
    },
}
740
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch. The `u128`
/// millisecond count saturates at `u64::MAX` instead of being silently
/// truncated by an `as` cast.
pub fn timestamp_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
751
// Unit tests for the helpers and serde wire format of the types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usage_recompute_total_sums_fields() {
        let mut u = Usage {
            input: 10,
            output: 20,
            cache_read: 3,
            cache_write: 4,
            total_tokens: 0,
            cost: Cost::default(),
        };
        u.recompute_total();
        // 10 + 20 + 3 + 4
        assert_eq!(u.total_tokens, 37);

        // Recomputing again is idempotent.
        u.recompute_total();
        assert_eq!(u.total_tokens, 37);

        // A stale total is overwritten, not accumulated into.
        u.total_tokens = 999;
        u.recompute_total();
        assert_eq!(u.total_tokens, 37);
    }

    // `Message` is tagged by "role"; InfoMessage must use the "info" tag and round-trip.
    #[test]
    fn info_message_serde_roundtrip() {
        let msg = Message::Info(InfoMessage {
            text: "task state changed".into(),
            timestamp: 12345,
        });
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"info""#));
        assert!(json.contains(r#""text":"task state changed""#));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(deserialized, Message::Info(i) if i.text == "task state changed" && i.timestamp == 12345)
        );
    }

    #[test]
    fn tool_result_message_duration_ms_roundtrip() {
        let msg = ToolResultMessage {
            tool_call_id: "tc1".into(),
            tool_name: "bash".into(),
            content: vec![ToolResultContent::Text(TextContent {
                text: "ok".into(),
                text_signature: None,
            })],
            details: None,
            is_error: false,
            timestamp: 1000,
            duration_ms: Some(1234),
            summary: None,
            post_persist_actions: Vec::new(),
        };
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("\"duration_ms\":1234"));
        let deserialized: ToolResultMessage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.duration_ms, Some(1234));
    }

    // JSON written before `duration_ms` existed must still deserialize.
    #[test]
    fn tool_result_message_duration_ms_backward_compat() {
        let json = r#"{"tool_call_id":"tc1","tool_name":"bash","content":[{"type":"text","text":"ok"}],"is_error":false,"timestamp":1000}"#;
        let msg: ToolResultMessage = serde_json::from_str(json).expect("deserialize");
        assert_eq!(msg.duration_ms, None);
    }

    #[test]
    fn tool_result_message_duration_ms_none_not_serialized() {
        let msg = ToolResultMessage::success("tc1", "bash", "ok");
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(
            !json.contains("duration_ms"),
            "duration_ms: None should not appear in JSON"
        );
    }

    #[test]
    fn tool_result_message_summary_roundtrip() {
        let mut msg = ToolResultMessage::success("tc1", "read", "file contents...");
        msg.summary = Some("read: src/main.rs (42 lines)".into());
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("\"summary\":\"read: src/main.rs (42 lines)\""));
        let deserialized: ToolResultMessage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(
            deserialized.summary,
            Some("read: src/main.rs (42 lines)".into())
        );
    }

    // JSON written before `summary` existed must still deserialize.
    #[test]
    fn tool_result_message_summary_backward_compat() {
        let json = r#"{"tool_call_id":"tc1","tool_name":"bash","content":[{"type":"text","text":"ok"}],"is_error":false,"timestamp":1000}"#;
        let msg: ToolResultMessage = serde_json::from_str(json).expect("deserialize");
        assert_eq!(msg.summary, None);
    }

    #[test]
    fn cache_retention_resolve_defaults_to_short() {
        assert_eq!(CacheRetention::resolve(None), CacheRetention::Short);
        assert_eq!(
            CacheRetention::resolve(Some(CacheRetention::Long)),
            CacheRetention::Long
        );
        assert_eq!(
            CacheRetention::resolve(Some(CacheRetention::None)),
            CacheRetention::None
        );
    }

    // Only explicit values are covered here. The `None` path reads the
    // process-global PI_CACHE_RETENTION variable and is not exercised.
    #[test]
    fn cache_retention_resolve_with_env_explicit_wins() {
        assert_eq!(
            CacheRetention::resolve_with_env(Some(CacheRetention::None)),
            CacheRetention::None
        );
        assert_eq!(
            CacheRetention::resolve_with_env(Some(CacheRetention::Short)),
            CacheRetention::Short
        );
        assert_eq!(
            CacheRetention::resolve_with_env(Some(CacheRetention::Long)),
            CacheRetention::Long
        );
    }

    #[test]

    fn tool_result_message_summary_none_not_serialized() {
        let msg = ToolResultMessage::success("tc1", "bash", "ok");
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(
            !json.contains("summary"),
            "summary: None should not appear in JSON"
        );
    }
}