1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct TimeRecord {
18 #[serde(default)]
20 pub event_time: f64,
21 #[serde(default)]
23 pub message_time: Option<i64>,
24}
25
26impl TimeRecord {
27 pub fn now() -> Self {
29 Self {
30 event_time: Utc::now().timestamp_millis() as f64 / 1000.0,
31 message_time: None,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
38pub struct MessageContent {
39 #[serde(default)]
41 pub text: Option<String>,
42 #[serde(default)]
44 pub json_payload: Option<String>,
45}
46
47impl MessageContent {
48 pub fn from_text(text: impl Into<String>) -> Self {
50 Self {
51 text: Some(text.into()),
52 json_payload: None,
53 }
54 }
55
56 pub fn from_json(value: &Value) -> Self {
58 Self {
59 text: None,
60 json_payload: Some(value.to_string()),
61 }
62 }
63
64 pub fn as_text(&self) -> Option<&str> {
66 self.text.as_deref().or(self.json_payload.as_deref())
67 }
68}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76#[serde(rename_all = "snake_case")]
77pub enum EventType {
78 Cais,
80 Environment,
82 Runtime,
84}
85
86impl std::fmt::Display for EventType {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match self {
89 EventType::Cais => write!(f, "cais"),
90 EventType::Environment => write!(f, "environment"),
91 EventType::Runtime => write!(f, "runtime"),
92 }
93 }
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct BaseEventFields {
99 #[serde(default)]
101 pub system_instance_id: String,
102 #[serde(default)]
104 pub time_record: TimeRecord,
105 #[serde(default)]
107 pub metadata: HashMap<String, Value>,
108 #[serde(default)]
110 pub event_metadata: Option<Vec<Value>>,
111}
112
113impl BaseEventFields {
114 pub fn new(system_instance_id: impl Into<String>) -> Self {
116 Self {
117 system_instance_id: system_instance_id.into(),
118 time_record: TimeRecord::now(),
119 metadata: HashMap::new(),
120 event_metadata: None,
121 }
122 }
123}
124
/// Language-model event variant of [`TracingEvent`] (tagged `"event_type": "cais"`).
///
/// Carries summary model/usage figures plus optional detailed [`LLMCallRecord`]s and
/// optional before/after system-state snapshots. Every field has a serde default, so
/// any subset may be present in the input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LMCAISEvent {
    /// Common event fields; flattened, so their keys sit at this event's top level.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Model name; empty string when absent.
    #[serde(default)]
    pub model_name: String,
    /// Provider name, if known.
    #[serde(default)]
    pub provider: Option<String>,
    /// Input (prompt) token count, if reported.
    #[serde(default)]
    pub input_tokens: Option<i32>,
    /// Output (completion) token count, if reported.
    #[serde(default)]
    pub output_tokens: Option<i32>,
    /// Total token count, if reported (not computed from the two fields above).
    #[serde(default)]
    pub total_tokens: Option<i32>,
    /// Cost of the call; USD per the field name.
    #[serde(default)]
    pub cost_usd: Option<f64>,
    /// Latency; milliseconds per the field name.
    #[serde(default)]
    pub latency_ms: Option<i32>,
    /// Span identifier, if this event is linked to a trace span.
    #[serde(default)]
    pub span_id: Option<String>,
    /// Trace identifier, if this event is linked to a trace.
    #[serde(default)]
    pub trace_id: Option<String>,
    /// Detailed per-call records, if captured.
    #[serde(default)]
    pub call_records: Option<Vec<LLMCallRecord>>,
    /// Snapshot of system state before the event (schema not defined here).
    #[serde(default)]
    pub system_state_before: Option<Value>,
    /// Snapshot of system state after the event (schema not defined here).
    #[serde(default)]
    pub system_state_after: Option<Value>,
}
168
/// Environment event variant of [`TracingEvent`] (tagged `"event_type": "environment"`).
///
/// Records a reward and two end-of-episode style flags, with optional state snapshots.
/// Every field has a serde default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EnvironmentEvent {
    /// Common event fields; flattened, so their keys sit at this event's top level.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Reward value; `0.0` when absent.
    #[serde(default)]
    pub reward: f64,
    /// NOTE(review): presumably Gym-style "episode terminated" flag — confirm.
    #[serde(default)]
    pub terminated: bool,
    /// NOTE(review): presumably Gym-style "episode truncated" flag — confirm.
    #[serde(default)]
    pub truncated: bool,
    /// Snapshot of system state before the event (schema not defined here).
    #[serde(default)]
    pub system_state_before: Option<Value>,
    /// Snapshot of system state after the event (schema not defined here).
    #[serde(default)]
    pub system_state_after: Option<Value>,
}
191
/// Runtime event variant of [`TracingEvent`] (tagged `"event_type": "runtime"`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RuntimeEvent {
    /// Common event fields; flattened, so their keys sit at this event's top level.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Integer actions taken; empty when absent. The encoding of each action is not
    /// defined in this file.
    #[serde(default)]
    pub actions: Vec<i32>,
}
202
/// A single trace event.
///
/// Internally tagged: the JSON object carries an `"event_type"` key whose value is
/// `"cais"`, `"environment"`, or `"runtime"` (the same strings [`EventType`] uses),
/// alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type", rename_all = "snake_case")]
pub enum TracingEvent {
    /// Language-model event.
    Cais(LMCAISEvent),
    /// Environment event.
    Environment(EnvironmentEvent),
    /// Runtime event.
    Runtime(RuntimeEvent),
}
214
215impl TracingEvent {
216 pub fn event_type(&self) -> EventType {
218 match self {
219 TracingEvent::Cais(_) => EventType::Cais,
220 TracingEvent::Environment(_) => EventType::Environment,
221 TracingEvent::Runtime(_) => EventType::Runtime,
222 }
223 }
224
225 pub fn base(&self) -> &BaseEventFields {
227 match self {
228 TracingEvent::Cais(e) => &e.base,
229 TracingEvent::Environment(e) => &e.base,
230 TracingEvent::Runtime(e) => &e.base,
231 }
232 }
233
234 pub fn time_record(&self) -> &TimeRecord {
236 &self.base().time_record
237 }
238
239 pub fn system_instance_id(&self) -> &str {
241 &self.base().system_instance_id
242 }
243}
244
245#[derive(Debug, Clone, Default, Serialize, Deserialize)]
251pub struct LLMUsage {
252 #[serde(default)]
253 pub input_tokens: Option<i32>,
254 #[serde(default)]
255 pub output_tokens: Option<i32>,
256 #[serde(default)]
257 pub total_tokens: Option<i32>,
258 #[serde(default)]
259 pub reasoning_tokens: Option<i32>,
260 #[serde(default)]
261 pub reasoning_input_tokens: Option<i32>,
262 #[serde(default)]
263 pub reasoning_output_tokens: Option<i32>,
264 #[serde(default)]
265 pub cache_read_tokens: Option<i32>,
266 #[serde(default)]
267 pub cache_write_tokens: Option<i32>,
268 #[serde(default)]
269 pub billable_input_tokens: Option<i32>,
270 #[serde(default)]
271 pub billable_output_tokens: Option<i32>,
272 #[serde(default)]
273 pub cost_usd: Option<f64>,
274}
275
276#[derive(Debug, Clone, Default, Serialize, Deserialize)]
278pub struct LLMRequestParams {
279 #[serde(default)]
280 pub temperature: Option<f64>,
281 #[serde(default)]
282 pub top_p: Option<f64>,
283 #[serde(default)]
284 pub max_tokens: Option<i32>,
285 #[serde(default)]
286 pub stop: Option<Vec<String>>,
287 #[serde(default)]
288 pub top_k: Option<i32>,
289 #[serde(default)]
290 pub presence_penalty: Option<f64>,
291 #[serde(default)]
292 pub frequency_penalty: Option<f64>,
293 #[serde(default)]
294 pub repetition_penalty: Option<f64>,
295 #[serde(default)]
296 pub seed: Option<i32>,
297 #[serde(default)]
298 pub n: Option<i32>,
299 #[serde(default)]
300 pub best_of: Option<i32>,
301 #[serde(default)]
302 pub response_format: Option<Value>,
303 #[serde(default)]
304 pub json_mode: Option<bool>,
305 #[serde(default)]
306 pub tool_config: Option<Value>,
307 #[serde(default)]
308 pub raw_params: HashMap<String, Value>,
309}
310
311#[derive(Debug, Clone, Default, Serialize, Deserialize)]
313pub struct LLMContentPart {
314 #[serde(rename = "type", default)]
316 pub content_type: String,
317 #[serde(default)]
319 pub text: Option<String>,
320 #[serde(default)]
322 pub data: Option<Value>,
323 #[serde(default)]
325 pub mime_type: Option<String>,
326 #[serde(default)]
327 pub uri: Option<String>,
328 #[serde(default)]
329 pub base64_data: Option<String>,
330 #[serde(default)]
331 pub size_bytes: Option<i64>,
332 #[serde(default)]
333 pub sha256: Option<String>,
334 #[serde(default)]
335 pub width: Option<i32>,
336 #[serde(default)]
337 pub height: Option<i32>,
338 #[serde(default)]
339 pub duration_ms: Option<i32>,
340 #[serde(default)]
341 pub sample_rate: Option<i32>,
342 #[serde(default)]
343 pub channels: Option<i32>,
344 #[serde(default)]
345 pub language: Option<String>,
346}
347
348impl LLMContentPart {
349 pub fn text(text: impl Into<String>) -> Self {
351 Self {
352 content_type: "text".to_string(),
353 text: Some(text.into()),
354 data: None,
355 mime_type: None,
356 uri: None,
357 base64_data: None,
358 size_bytes: None,
359 sha256: None,
360 width: None,
361 height: None,
362 duration_ms: None,
363 sample_rate: None,
364 channels: None,
365 language: None,
366 }
367 }
368}
369
370#[derive(Debug, Clone, Default, Serialize, Deserialize)]
372pub struct LLMMessage {
373 #[serde(default)]
375 pub role: String,
376 #[serde(default)]
378 pub parts: Vec<LLMContentPart>,
379 #[serde(default)]
381 pub name: Option<String>,
382 #[serde(default)]
384 pub tool_call_id: Option<String>,
385 #[serde(default)]
387 pub metadata: HashMap<String, Value>,
388}
389
390impl LLMMessage {
391 pub fn new(role: impl Into<String>, text: impl Into<String>) -> Self {
393 Self {
394 role: role.into(),
395 parts: vec![LLMContentPart::text(text)],
396 name: None,
397 tool_call_id: None,
398 metadata: HashMap::new(),
399 }
400 }
401
402 pub fn text(&self) -> Option<&str> {
404 self.parts.iter().find_map(|p| p.text.as_deref())
405 }
406}
407
408#[derive(Debug, Clone, Default, Serialize, Deserialize)]
410pub struct ToolCallSpec {
411 #[serde(default)]
413 pub name: String,
414 #[serde(default)]
416 pub arguments_json: String,
417 #[serde(default)]
419 pub arguments: Option<Value>,
420 #[serde(default)]
422 pub call_id: Option<String>,
423 #[serde(default)]
425 pub index: Option<i32>,
426 #[serde(default)]
428 pub parent_call_id: Option<String>,
429 #[serde(default)]
431 pub metadata: HashMap<String, Value>,
432}
433
434#[derive(Debug, Clone, Default, Serialize, Deserialize)]
436pub struct ToolCallResult {
437 #[serde(default)]
439 pub call_id: Option<String>,
440 #[serde(default)]
442 pub output_text: Option<String>,
443 #[serde(default)]
445 pub exit_code: Option<i32>,
446 #[serde(default)]
448 pub status: Option<String>,
449 #[serde(default)]
451 pub error_message: Option<String>,
452 #[serde(default)]
454 pub started_at: Option<DateTime<Utc>>,
455 #[serde(default)]
457 pub completed_at: Option<DateTime<Utc>>,
458 #[serde(default)]
460 pub duration_ms: Option<i32>,
461 #[serde(default)]
463 pub metadata: HashMap<String, Value>,
464}
465
/// One chunk of a model response (presumably a streaming delta — confirm with producers).
///
/// NOTE(review): a missing `received_at` deserializes to `Utc::now()`, but the derived
/// `Default` uses `DateTime<Utc>::default()` (the Unix epoch in chrono 0.4). Confirm
/// this asymmetry is intended.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLMChunk {
    /// Position of this chunk in its sequence.
    #[serde(default)]
    pub sequence_index: i32,
    /// When the chunk was received; defaults to the deserialization time if absent.
    #[serde(default = "Utc::now")]
    pub received_at: DateTime<Utc>,
    /// Provider event type string, if any.
    #[serde(default)]
    pub event_type: Option<String>,
    /// Choice index this chunk belongs to.
    #[serde(default)]
    pub choice_index: Option<i32>,
    /// Raw chunk payload as a JSON string.
    #[serde(default)]
    pub raw_json: Option<String>,
    /// Incremental text carried by this chunk.
    #[serde(default)]
    pub delta_text: Option<String>,
    /// Structured delta payload.
    #[serde(default)]
    pub delta: Option<Value>,
    /// Free-form key/value metadata.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
486
487#[derive(Debug, Clone, Default, Serialize, Deserialize)]
489pub struct LLMCallRecord {
490 #[serde(default)]
492 pub call_id: String,
493 #[serde(default)]
495 pub api_type: String,
496 #[serde(default)]
498 pub provider: Option<String>,
499 #[serde(default)]
501 pub model_name: String,
502 #[serde(default)]
504 pub schema_version: Option<String>,
505 #[serde(default)]
507 pub started_at: Option<DateTime<Utc>>,
508 #[serde(default)]
510 pub completed_at: Option<DateTime<Utc>>,
511 #[serde(default)]
513 pub latency_ms: Option<i32>,
514 #[serde(default)]
516 pub request_params: LLMRequestParams,
517 #[serde(default)]
519 pub input_messages: Vec<LLMMessage>,
520 #[serde(default)]
522 pub input_text: Option<String>,
523 #[serde(default)]
525 pub tool_choice: Option<String>,
526 #[serde(default)]
528 pub output_messages: Vec<LLMMessage>,
529 #[serde(default)]
531 pub outputs: Vec<LLMMessage>,
532 #[serde(default)]
534 pub output_text: Option<String>,
535 #[serde(default)]
537 pub output_tool_calls: Vec<ToolCallSpec>,
538 #[serde(default)]
540 pub tool_results: Vec<ToolCallResult>,
541 #[serde(default)]
543 pub usage: Option<LLMUsage>,
544 #[serde(default)]
546 pub finish_reason: Option<String>,
547 #[serde(default)]
549 pub choice_index: Option<i32>,
550 #[serde(default)]
552 pub chunks: Option<Vec<LLMChunk>>,
553 #[serde(default)]
555 pub request_raw_json: Option<String>,
556 #[serde(default)]
558 pub response_raw_json: Option<String>,
559 #[serde(default)]
561 pub metadata: HashMap<String, Value>,
562 #[serde(default)]
564 pub provider_request_id: Option<String>,
565 #[serde(default)]
567 pub request_server_timing: Option<Value>,
568 #[serde(default)]
570 pub outcome: Option<String>,
571 #[serde(default)]
573 pub error: Option<Value>,
574 #[serde(default)]
576 pub token_traces: Option<Vec<Value>>,
577 #[serde(default)]
579 pub safety: Option<Value>,
580 #[serde(default)]
582 pub refusal: Option<Value>,
583 #[serde(default)]
585 pub redactions: Option<Vec<Value>>,
586}
587
588#[derive(Debug, Clone, Default, Serialize, Deserialize)]
594pub struct MarkovBlanketMessage {
595 #[serde(default)]
597 pub content: MessageContent,
598 #[serde(default)]
600 pub message_type: String,
601 #[serde(default)]
603 pub time_record: TimeRecord,
604 #[serde(default)]
606 pub metadata: HashMap<String, Value>,
607}
608
/// One step of a session, holding the events and messages recorded during it.
///
/// Does not implement `Default` (it contains [`TracingEvent`]s); use
/// [`SessionTimeStep::new`]. When deserializing, a missing `timestamp` becomes the
/// current time and every other missing field takes its type's default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTimeStep {
    /// Step identifier; empty when absent.
    #[serde(default)]
    pub step_id: String,
    /// Position of this step within the session.
    #[serde(default)]
    pub step_index: i32,
    /// When the step was created.
    #[serde(default = "Utc::now")]
    pub timestamp: DateTime<Utc>,
    /// Optional turn number.
    #[serde(default)]
    pub turn_number: Option<i32>,
    /// Events recorded during this step.
    #[serde(default)]
    pub events: Vec<TracingEvent>,
    /// Messages recorded during this step.
    #[serde(default)]
    pub markov_blanket_messages: Vec<MarkovBlanketMessage>,
    /// Free-form key/value metadata for the step.
    #[serde(default)]
    pub step_metadata: HashMap<String, Value>,
    /// When the step was completed; set by [`SessionTimeStep::complete`].
    #[serde(default)]
    pub completed_at: Option<DateTime<Utc>>,
}
637
638impl SessionTimeStep {
639 pub fn new(step_id: impl Into<String>, step_index: i32) -> Self {
641 Self {
642 step_id: step_id.into(),
643 step_index,
644 timestamp: Utc::now(),
645 turn_number: None,
646 events: Vec::new(),
647 markov_blanket_messages: Vec::new(),
648 step_metadata: HashMap::new(),
649 completed_at: None,
650 }
651 }
652
653 pub fn complete(&mut self) {
655 self.completed_at = Some(Utc::now());
656 }
657}
658
/// Full trace of a session: its time steps plus flat event and message histories.
///
/// Does not implement `Default`; use [`SessionTrace::new`]. When deserializing, a
/// missing `created_at` becomes the current time and every other missing field takes
/// its type's default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTrace {
    /// Session identifier; empty when absent.
    #[serde(default)]
    pub session_id: String,
    /// When the session trace was created.
    #[serde(default = "Utc::now")]
    pub created_at: DateTime<Utc>,
    /// Ordered time steps.
    #[serde(default)]
    pub session_time_steps: Vec<SessionTimeStep>,
    /// Session-wide event history.
    /// NOTE(review): nothing in this file keeps this in sync with the per-step
    /// `events`; confirm how callers populate it.
    #[serde(default)]
    pub event_history: Vec<TracingEvent>,
    /// Session-wide message history (same caveat as `event_history`).
    #[serde(default)]
    pub markov_blanket_message_history: Vec<MarkovBlanketMessage>,
    /// Free-form key/value metadata for the session.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
681
682impl SessionTrace {
683 pub fn new(session_id: impl Into<String>) -> Self {
685 Self {
686 session_id: session_id.into(),
687 created_at: Utc::now(),
688 session_time_steps: Vec::new(),
689 event_history: Vec::new(),
690 markov_blanket_message_history: Vec::new(),
691 metadata: HashMap::new(),
692 }
693 }
694
695 pub fn num_timesteps(&self) -> usize {
697 self.session_time_steps.len()
698 }
699
700 pub fn num_events(&self) -> usize {
702 self.event_history.len()
703 }
704
705 pub fn num_messages(&self) -> usize {
707 self.markov_blanket_message_history.len()
708 }
709}
710
711#[derive(Debug, Clone, Default, Serialize, Deserialize)]
717pub struct OutcomeReward {
718 #[serde(default = "default_objective_key")]
720 pub objective_key: String,
721 pub total_reward: f64,
723 #[serde(default)]
725 pub achievements_count: i32,
726 #[serde(default)]
728 pub total_steps: i32,
729 #[serde(default)]
731 pub reward_metadata: HashMap<String, Value>,
732 #[serde(default)]
734 pub annotation: Option<Value>,
735}
736
737fn default_objective_key() -> String {
738 "reward".to_string()
739}
740
741#[derive(Debug, Clone, Default, Serialize, Deserialize)]
743pub struct EventReward {
744 #[serde(default = "default_objective_key")]
746 pub objective_key: String,
747 pub reward_value: f64,
749 #[serde(default)]
751 pub reward_type: Option<String>,
752 #[serde(default)]
754 pub key: Option<String>,
755 #[serde(default)]
757 pub annotation: Option<Value>,
758 #[serde(default)]
760 pub source: Option<String>,
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_record() {
        let tr = TimeRecord::now();
        assert!(tr.event_time > 0.0);
        assert!(tr.message_time.is_none());
    }

    #[test]
    fn test_message_content() {
        let mc = MessageContent::from_text("hello");
        assert_eq!(mc.as_text(), Some("hello"));
        assert!(mc.json_payload.is_none());

        let mc = MessageContent::from_json(&serde_json::json!({"key": "value"}));
        assert!(mc.text.is_none());
        assert_eq!(mc.json_payload.as_deref(), Some(r#"{"key":"value"}"#));
        // Without text, `as_text` falls back to the JSON payload string.
        assert_eq!(mc.as_text(), Some(r#"{"key":"value"}"#));

        assert_eq!(MessageContent::default().as_text(), None);
    }

    #[test]
    fn test_event_type_display_matches_serde() {
        for ty in [EventType::Cais, EventType::Environment, EventType::Runtime] {
            let json = serde_json::to_string(&ty).unwrap();
            assert_eq!(json, format!("\"{}\"", ty));
        }
    }

    #[test]
    fn test_event_serialization() {
        let event = TracingEvent::Cais(LMCAISEvent {
            base: BaseEventFields::new("test-system"),
            model_name: "gpt-4".to_string(),
            provider: Some("openai".to_string()),
            input_tokens: Some(100),
            output_tokens: Some(50),
            ..Default::default()
        });

        let json = serde_json::to_string(&event).unwrap();
        // The variant tag and the flattened base fields sit at the top level.
        assert!(json.contains(r#""event_type":"cais""#));
        assert!(json.contains(r#""system_instance_id":"test-system""#));
        assert!(json.contains("gpt-4"));

        let parsed: TracingEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.event_type(), EventType::Cais);
        assert_eq!(parsed.system_instance_id(), "test-system");
        match parsed {
            TracingEvent::Cais(e) => {
                assert_eq!(e.model_name, "gpt-4");
                assert_eq!(e.provider.as_deref(), Some("openai"));
                assert_eq!(e.input_tokens, Some(100));
                assert_eq!(e.output_tokens, Some(50));
                assert_eq!(e.total_tokens, None);
            }
            other => panic!("expected a Cais event, got {:?}", other.event_type()),
        }
    }

    #[test]
    fn test_minimal_events_use_defaults() {
        let env: TracingEvent =
            serde_json::from_str(r#"{"event_type":"environment","reward":1.5}"#).unwrap();
        match env {
            TracingEvent::Environment(e) => {
                assert_eq!(e.reward, 1.5);
                assert!(!e.terminated);
                assert!(!e.truncated);
                assert_eq!(e.base.system_instance_id, "");
                assert!(e.base.metadata.is_empty());
            }
            other => panic!("expected an Environment event, got {:?}", other.event_type()),
        }

        let rt: TracingEvent =
            serde_json::from_str(r#"{"event_type":"runtime","actions":[1,2]}"#).unwrap();
        match rt {
            TracingEvent::Runtime(e) => assert_eq!(e.actions, vec![1, 2]),
            other => panic!("expected a Runtime event, got {:?}", other.event_type()),
        }
    }

    #[test]
    fn test_session_trace() {
        let mut trace = SessionTrace::new("test-session");
        assert_eq!(trace.num_timesteps(), 0);
        assert_eq!(trace.num_events(), 0);
        assert_eq!(trace.num_messages(), 0);

        let mut step = SessionTimeStep::new("step-1", 0);
        assert!(step.completed_at.is_none());
        step.complete();
        assert!(step.completed_at.is_some());

        trace.session_time_steps.push(step);
        assert_eq!(trace.num_timesteps(), 1);
    }

    #[test]
    fn test_llm_message() {
        let msg = LLMMessage::new("user", "Hello, world!");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.text(), Some("Hello, world!"));

        // `text()` skips parts that carry no text.
        let mixed = LLMMessage {
            role: "assistant".to_string(),
            parts: vec![
                LLMContentPart {
                    content_type: "image".to_string(),
                    ..Default::default()
                },
                LLMContentPart::text("caption"),
            ],
            ..Default::default()
        };
        assert_eq!(mixed.text(), Some("caption"));
        assert_eq!(LLMMessage::default().text(), None);
    }

    #[test]
    fn test_content_part_type_key() {
        let value = serde_json::to_value(LLMContentPart::text("x")).unwrap();
        assert_eq!(value["type"], "text");
        assert_eq!(value["text"], "x");
    }

    #[test]
    fn test_reward_objective_key_defaults_on_deserialize() {
        let outcome: OutcomeReward = serde_json::from_str(r#"{"total_reward":2.0}"#).unwrap();
        assert_eq!(outcome.objective_key, "reward");
        assert_eq!(outcome.total_steps, 0);

        let event: EventReward = serde_json::from_str(r#"{"reward_value":0.5}"#).unwrap();
        assert_eq!(event.objective_key, "reward");
        assert!(event.source.is_none());
    }
}