1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
/// Timestamps carried by events ([`BaseEventFields`]) and messages
/// ([`MarkovBlanketMessage`]).
///
/// [`TimeRecord::now`] fills `event_time` with the current UTC time expressed
/// as fractional seconds since the Unix epoch.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TimeRecord {
    /// Event timestamp. Required when deserializing (no serde default).
    pub event_time: f64,
    /// Optional message timestamp; `None` when the key is absent from the input.
    // NOTE(review): the unit of this integer is not established anywhere in
    // this file — confirm against the producers before relying on it.
    #[serde(default)]
    pub message_time: Option<i64>,
}
24
25impl TimeRecord {
26 pub fn now() -> Self {
28 Self {
29 event_time: Utc::now().timestamp_millis() as f64 / 1000.0,
30 message_time: None,
31 }
32 }
33}
34
/// Body of a message: plain text, a JSON document rendered to a string, or
/// (since both fields are independent options) any combination of the two.
///
/// The constructors [`MessageContent::from_text`] and
/// [`MessageContent::from_json`] each populate exactly one field.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MessageContent {
    /// Plain-text body; `None` when absent from the input.
    #[serde(default)]
    pub text: Option<String>,
    /// JSON body stored in its string form; `None` when absent from the input.
    #[serde(default)]
    pub json_payload: Option<String>,
}
45
46impl MessageContent {
47 pub fn from_text(text: impl Into<String>) -> Self {
49 Self {
50 text: Some(text.into()),
51 json_payload: None,
52 }
53 }
54
55 pub fn from_json(value: &Value) -> Self {
57 Self {
58 text: None,
59 json_payload: Some(value.to_string()),
60 }
61 }
62
63 pub fn as_text(&self) -> Option<&str> {
65 self.text.as_deref().or(self.json_payload.as_deref())
66 }
67}
68
/// Kind of a [`TracingEvent`], one value per enum variant there.
///
/// Serialized in snake_case, i.e. `"cais"`, `"environment"` and `"runtime"`,
/// which matches the strings produced by the `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EventType {
    /// Corresponds to [`TracingEvent::Cais`].
    Cais,
    /// Corresponds to [`TracingEvent::Environment`].
    Environment,
    /// Corresponds to [`TracingEvent::Runtime`].
    Runtime,
}
84
85impl std::fmt::Display for EventType {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 EventType::Cais => write!(f, "cais"),
89 EventType::Environment => write!(f, "environment"),
90 EventType::Runtime => write!(f, "runtime"),
91 }
92 }
93}
94
/// Fields shared by every event struct in this file.
///
/// Each event embeds this struct with `#[serde(flatten)]`, so on the wire
/// these keys sit next to the event's own keys rather than under a nested
/// `base` object.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BaseEventFields {
    /// Identifier of the system instance the event belongs to; required.
    pub system_instance_id: String,
    /// Timestamps of the event; required.
    pub time_record: TimeRecord,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
    /// Optional list of free-form JSON values; `None` when absent from the input.
    #[serde(default)]
    pub event_metadata: Option<Vec<Value>>,
}
109
110impl BaseEventFields {
111 pub fn new(system_instance_id: impl Into<String>) -> Self {
113 Self {
114 system_instance_id: system_instance_id.into(),
115 time_record: TimeRecord::now(),
116 metadata: HashMap::new(),
117 event_metadata: None,
118 }
119 }
120}
121
/// Payload of the [`TracingEvent::Cais`] variant.
///
/// The shared [`BaseEventFields`] are flattened into the same JSON object as
/// the fields declared here. Besides the required base fields only
/// `model_name` must be present when deserializing; every other field falls
/// back to `None`.
// NOTE(review): judging by the field names this describes language-model
// calls with token and cost accounting — confirm with the event producers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LMCAISEvent {
    /// Fields common to all events, inlined via `serde(flatten)`.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Name of the model; required.
    pub model_name: String,
    /// Optional provider name.
    #[serde(default)]
    pub provider: Option<String>,
    /// Optional input token count.
    #[serde(default)]
    pub input_tokens: Option<i32>,
    /// Optional output token count.
    #[serde(default)]
    pub output_tokens: Option<i32>,
    /// Optional total token count. Nothing in this file computes it from the
    /// input/output counts; it is stored as given.
    #[serde(default)]
    pub total_tokens: Option<i32>,
    /// Optional cost (presumably US dollars, per the name — confirm).
    #[serde(default)]
    pub cost_usd: Option<f64>,
    /// Optional latency (presumably milliseconds, per the name — confirm).
    #[serde(default)]
    pub latency_ms: Option<i32>,
    /// Optional span identifier.
    #[serde(default)]
    pub span_id: Option<String>,
    /// Optional trace identifier.
    #[serde(default)]
    pub trace_id: Option<String>,
    /// Optional detailed per-call records.
    #[serde(default)]
    pub call_records: Option<Vec<LLMCallRecord>>,
    /// Optional free-form JSON snapshot (state before the event, per the name).
    #[serde(default)]
    pub system_state_before: Option<Value>,
    /// Optional free-form JSON snapshot (state after the event, per the name).
    #[serde(default)]
    pub system_state_after: Option<Value>,
}
164
/// Payload of the [`TracingEvent::Environment`] variant.
///
/// The shared [`BaseEventFields`] are flattened into the same JSON object.
/// All fields declared here are optional on input and fall back to their
/// type's default (`0.0`, `false`, `None`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EnvironmentEvent {
    /// Fields common to all events, inlined via `serde(flatten)`.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Reward value; `0.0` when absent from the input.
    #[serde(default)]
    pub reward: f64,
    /// Termination flag; `false` when absent from the input.
    #[serde(default)]
    pub terminated: bool,
    /// Truncation flag; `false` when absent from the input.
    #[serde(default)]
    pub truncated: bool,
    /// Optional free-form JSON snapshot (state before the event, per the name).
    #[serde(default)]
    pub system_state_before: Option<Value>,
    /// Optional free-form JSON snapshot (state after the event, per the name).
    #[serde(default)]
    pub system_state_after: Option<Value>,
}
187
/// Payload of the [`TracingEvent::Runtime`] variant.
///
/// The shared [`BaseEventFields`] are flattened into the same JSON object.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RuntimeEvent {
    /// Fields common to all events, inlined via `serde(flatten)`.
    #[serde(flatten)]
    pub base: BaseEventFields,
    /// Integer action values; empty when absent from the input.
    // NOTE(review): what the integers encode is not visible in this file.
    #[serde(default)]
    pub actions: Vec<i32>,
}
198
/// A traced event of one of three kinds.
///
/// Uses serde's internally tagged representation: the JSON object carries an
/// `event_type` key whose value is the snake_case variant name (`"cais"`,
/// `"environment"`, `"runtime"`), alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type", rename_all = "snake_case")]
pub enum TracingEvent {
    /// Tagged `"cais"`.
    Cais(LMCAISEvent),
    /// Tagged `"environment"`.
    Environment(EnvironmentEvent),
    /// Tagged `"runtime"`.
    Runtime(RuntimeEvent),
}
210
211impl TracingEvent {
212 pub fn event_type(&self) -> EventType {
214 match self {
215 TracingEvent::Cais(_) => EventType::Cais,
216 TracingEvent::Environment(_) => EventType::Environment,
217 TracingEvent::Runtime(_) => EventType::Runtime,
218 }
219 }
220
221 pub fn base(&self) -> &BaseEventFields {
223 match self {
224 TracingEvent::Cais(e) => &e.base,
225 TracingEvent::Environment(e) => &e.base,
226 TracingEvent::Runtime(e) => &e.base,
227 }
228 }
229
230 pub fn time_record(&self) -> &TimeRecord {
232 &self.base().time_record
233 }
234
235 pub fn system_instance_id(&self) -> &str {
237 &self.base().system_instance_id
238 }
239}
240
/// Optional usage counters attached to an [`LLMCallRecord`].
///
/// Every field is optional and becomes `None` when absent from the input.
/// Nothing in this file derives one counter from another.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLMUsage {
    /// Input token count.
    #[serde(default)]
    pub input_tokens: Option<i32>,
    /// Output token count.
    #[serde(default)]
    pub output_tokens: Option<i32>,
    /// Total token count.
    #[serde(default)]
    pub total_tokens: Option<i32>,
    /// Reasoning token count.
    #[serde(default)]
    pub reasoning_tokens: Option<i32>,
    /// Cache-read token count.
    #[serde(default)]
    pub cache_read_tokens: Option<i32>,
    /// Cache-write token count.
    #[serde(default)]
    pub cache_write_tokens: Option<i32>,
    /// Cost (presumably US dollars, per the name — confirm).
    #[serde(default)]
    pub cost_usd: Option<f64>,
}
263
/// One part of an [`LLMMessage`].
///
/// [`LLMContentPart::text`] builds a part whose type is `"text"`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLMContentPart {
    /// Kind of the part; appears under the JSON key `type`. Required.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Optional textual content.
    #[serde(default)]
    pub text: Option<String>,
    /// Optional free-form JSON content.
    #[serde(default)]
    pub data: Option<Value>,
    /// Optional MIME type string.
    #[serde(default)]
    pub mime_type: Option<String>,
}
280
281impl LLMContentPart {
282 pub fn text(text: impl Into<String>) -> Self {
284 Self {
285 content_type: "text".to_string(),
286 text: Some(text.into()),
287 data: None,
288 mime_type: None,
289 }
290 }
291}
292
/// A message made of a role and a list of content parts.
///
/// Only `role` is required when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLMMessage {
    /// Role string; required. The set of accepted values is not constrained
    /// here.
    pub role: String,
    /// Content parts in order; empty when absent from the input.
    #[serde(default)]
    pub parts: Vec<LLMContentPart>,
    /// Optional name.
    #[serde(default)]
    pub name: Option<String>,
    /// Optional tool-call identifier.
    #[serde(default)]
    pub tool_call_id: Option<String>,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
311
312impl LLMMessage {
313 pub fn new(role: impl Into<String>, text: impl Into<String>) -> Self {
315 Self {
316 role: role.into(),
317 parts: vec![LLMContentPart::text(text)],
318 name: None,
319 tool_call_id: None,
320 metadata: HashMap::new(),
321 }
322 }
323
324 pub fn text(&self) -> Option<&str> {
326 self.parts.iter().find_map(|p| p.text.as_deref())
327 }
328}
329
/// Description of a tool call, as stored in
/// [`LLMCallRecord::output_tool_calls`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallSpec {
    /// Tool name; required.
    pub name: String,
    /// Call arguments as a string; required.
    // NOTE(review): presumably a JSON-encoded document given the name — this
    // file never parses or validates it.
    pub arguments_json: String,
    /// Optional call identifier.
    #[serde(default)]
    pub call_id: Option<String>,
    /// Optional index of the call.
    #[serde(default)]
    pub index: Option<i32>,
}
344
/// Result of a tool call, as stored in [`LLMCallRecord::tool_results`].
///
/// Every field is optional and becomes `None` when absent from the input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallResult {
    /// Call identifier (presumably matching [`ToolCallSpec::call_id`] —
    /// nothing in this file enforces the link).
    #[serde(default)]
    pub call_id: Option<String>,
    /// Textual output.
    #[serde(default)]
    pub output_text: Option<String>,
    /// Exit code.
    #[serde(default)]
    pub exit_code: Option<i32>,
    /// Status string; accepted values are not constrained here.
    #[serde(default)]
    pub status: Option<String>,
    /// Error message.
    #[serde(default)]
    pub error_message: Option<String>,
    /// Duration (presumably milliseconds, per the name — confirm).
    #[serde(default)]
    pub duration_ms: Option<i32>,
}
367
/// Detailed record of a single call, as stored in
/// [`LMCAISEvent::call_records`].
///
/// `call_id`, `api_type` and `model_name` are required when deserializing;
/// everything else falls back to `None` or an empty collection.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LLMCallRecord {
    /// Call identifier; required.
    pub call_id: String,
    /// API type string; required. Accepted values are not constrained here.
    pub api_type: String,
    /// Optional provider name.
    #[serde(default)]
    pub provider: Option<String>,
    /// Model name; required.
    pub model_name: String,
    /// Optional UTC start time.
    #[serde(default)]
    pub started_at: Option<DateTime<Utc>>,
    /// Optional UTC completion time.
    #[serde(default)]
    pub completed_at: Option<DateTime<Utc>>,
    /// Optional latency (presumably milliseconds, per the name — confirm).
    /// Stored as given; not computed from the two timestamps in this file.
    #[serde(default)]
    pub latency_ms: Option<i32>,
    /// Input messages; empty when absent from the input.
    #[serde(default)]
    pub input_messages: Vec<LLMMessage>,
    /// Output messages; empty when absent from the input.
    #[serde(default)]
    pub output_messages: Vec<LLMMessage>,
    /// Tool calls in the output; empty when absent from the input.
    #[serde(default)]
    pub output_tool_calls: Vec<ToolCallSpec>,
    /// Tool results; empty when absent from the input.
    #[serde(default)]
    pub tool_results: Vec<ToolCallResult>,
    /// Optional usage counters.
    #[serde(default)]
    pub usage: Option<LLMUsage>,
    /// Optional finish reason string.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
411
/// A message recorded in a session, stored per step
/// ([`SessionTimeStep::markov_blanket_messages`]) and in the session-wide
/// history ([`SessionTrace::markov_blanket_message_history`]).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MarkovBlanketMessage {
    /// Message body; required.
    pub content: MessageContent,
    /// Message type string; required. Accepted values are not constrained here.
    pub message_type: String,
    /// Timestamps of the message; required.
    pub time_record: TimeRecord,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
429
/// One step of a session, holding the events and messages recorded for it.
///
/// `step_id`, `step_index` and `timestamp` are required when deserializing.
/// This type does not derive `Default`; use [`SessionTimeStep::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTimeStep {
    /// Step identifier; required.
    pub step_id: String,
    /// Index of the step; required.
    pub step_index: i32,
    /// UTC timestamp; [`SessionTimeStep::new`] sets it to the creation time.
    pub timestamp: DateTime<Utc>,
    /// Optional turn number.
    #[serde(default)]
    pub turn_number: Option<i32>,
    /// Events of this step; empty when absent from the input.
    #[serde(default)]
    pub events: Vec<TracingEvent>,
    /// Messages of this step; empty when absent from the input.
    #[serde(default)]
    pub markov_blanket_messages: Vec<MarkovBlanketMessage>,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub step_metadata: HashMap<String, Value>,
    /// UTC completion time; set by [`SessionTimeStep::complete`], `None`
    /// until then.
    #[serde(default)]
    pub completed_at: Option<DateTime<Utc>>,
}
455
456impl SessionTimeStep {
457 pub fn new(step_id: impl Into<String>, step_index: i32) -> Self {
459 Self {
460 step_id: step_id.into(),
461 step_index,
462 timestamp: Utc::now(),
463 turn_number: None,
464 events: Vec::new(),
465 markov_blanket_messages: Vec::new(),
466 step_metadata: HashMap::new(),
467 completed_at: None,
468 }
469 }
470
471 pub fn complete(&mut self) {
473 self.completed_at = Some(Utc::now());
474 }
475}
476
/// A whole session: its steps plus session-wide event and message histories.
///
/// `session_id` and `created_at` are required when deserializing. This type
/// does not derive `Default`; use [`SessionTrace::new`].
// NOTE(review): nothing in this file keeps the per-step collections and the
// session-wide histories in sync — callers maintain both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionTrace {
    /// Session identifier; required.
    pub session_id: String,
    /// UTC creation time; [`SessionTrace::new`] sets it to the current time.
    pub created_at: DateTime<Utc>,
    /// Steps of the session; empty when absent from the input.
    #[serde(default)]
    pub session_time_steps: Vec<SessionTimeStep>,
    /// Session-wide list of events; empty when absent from the input.
    #[serde(default)]
    pub event_history: Vec<TracingEvent>,
    /// Session-wide list of messages; empty when absent from the input.
    #[serde(default)]
    pub markov_blanket_message_history: Vec<MarkovBlanketMessage>,
    /// Free-form JSON values keyed by name; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
497
498impl SessionTrace {
499 pub fn new(session_id: impl Into<String>) -> Self {
501 Self {
502 session_id: session_id.into(),
503 created_at: Utc::now(),
504 session_time_steps: Vec::new(),
505 event_history: Vec::new(),
506 markov_blanket_message_history: Vec::new(),
507 metadata: HashMap::new(),
508 }
509 }
510
511 pub fn num_timesteps(&self) -> usize {
513 self.session_time_steps.len()
514 }
515
516 pub fn num_events(&self) -> usize {
518 self.event_history.len()
519 }
520
521 pub fn num_messages(&self) -> usize {
523 self.markov_blanket_message_history.len()
524 }
525}
526
527#[derive(Debug, Clone, Default, Serialize, Deserialize)]
533pub struct OutcomeReward {
534 #[serde(default = "default_objective_key")]
536 pub objective_key: String,
537 pub total_reward: f64,
539 #[serde(default)]
541 pub achievements_count: i32,
542 #[serde(default)]
544 pub total_steps: i32,
545 #[serde(default)]
547 pub reward_metadata: HashMap<String, Value>,
548 #[serde(default)]
550 pub annotation: Option<Value>,
551}
552
/// Serde default for the `objective_key` fields: the string `"reward"`.
fn default_objective_key() -> String {
    String::from("reward")
}
556
557#[derive(Debug, Clone, Default, Serialize, Deserialize)]
559pub struct EventReward {
560 #[serde(default = "default_objective_key")]
562 pub objective_key: String,
563 pub reward_value: f64,
565 #[serde(default)]
567 pub reward_type: Option<String>,
568 #[serde(default)]
570 pub key: Option<String>,
571 #[serde(default)]
573 pub annotation: Option<Value>,
574 #[serde(default)]
576 pub source: Option<String>,
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// `TimeRecord::now` yields a positive epoch time and no message time.
    #[test]
    fn test_time_record() {
        let tr = TimeRecord::now();
        assert!(tr.event_time > 0.0);
        assert!(tr.message_time.is_none());
    }

    /// Each constructor fills exactly one field, and `as_text` prefers the
    /// plain text but falls back to the JSON payload.
    #[test]
    fn test_message_content() {
        let mc = MessageContent::from_text("hello");
        assert_eq!(mc.as_text(), Some("hello"));
        assert!(mc.json_payload.is_none());

        let mc = MessageContent::from_json(&serde_json::json!({"key": "value"}));
        assert!(mc.text.is_none());
        assert_eq!(mc.json_payload.as_deref(), Some(r#"{"key":"value"}"#));
        // No plain text is set, so the JSON string is what `as_text` returns.
        assert_eq!(mc.as_text(), mc.json_payload.as_deref());

        assert_eq!(MessageContent::default().as_text(), None);
    }

    /// A `Cais` event serializes with the `event_type` tag and survives a
    /// JSON round trip with its field values intact.
    #[test]
    fn test_event_serialization() {
        let event = TracingEvent::Cais(LMCAISEvent {
            base: BaseEventFields::new("test-system"),
            model_name: "gpt-4".to_string(),
            provider: Some("openai".to_string()),
            input_tokens: Some(100),
            output_tokens: Some(50),
            ..Default::default()
        });

        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains(r#""event_type":"cais""#));
        assert!(json.contains("gpt-4"));

        let parsed: TracingEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.event_type(), EventType::Cais);
        assert_eq!(parsed.system_instance_id(), "test-system");

        match parsed {
            TracingEvent::Cais(e) => {
                assert_eq!(e.model_name, "gpt-4");
                assert_eq!(e.provider.as_deref(), Some("openai"));
                assert_eq!(e.input_tokens, Some(100));
                assert_eq!(e.output_tokens, Some(50));
                assert_eq!(e.total_tokens, None);
            }
            other => panic!("expected a cais event, got {}", other.event_type()),
        }
    }

    /// A new trace is empty and its counters follow the underlying vectors.
    #[test]
    fn test_session_trace() {
        let mut trace = SessionTrace::new("test-session");
        assert_eq!(trace.num_timesteps(), 0);
        assert_eq!(trace.num_events(), 0);
        assert_eq!(trace.num_messages(), 0);

        let mut step = SessionTimeStep::new("step-1", 0);
        assert!(step.completed_at.is_none());
        step.complete();
        assert!(step.completed_at.is_some());

        trace.session_time_steps.push(step);
        assert_eq!(trace.num_timesteps(), 1);
    }

    /// `LLMMessage::new` wraps the text in a single `"text"` part.
    #[test]
    fn test_llm_message() {
        let msg = LLMMessage::new("user", "Hello, world!");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.parts.len(), 1);
        assert_eq!(msg.parts[0].content_type, "text");
        assert_eq!(msg.text(), Some("Hello, world!"));

        // A message without parts has no text.
        assert_eq!(LLMMessage::default().text(), None);
    }

    /// Deserializing a reward without `objective_key` yields `"reward"`.
    #[test]
    fn test_reward_objective_key_serde_default() {
        let reward: EventReward = serde_json::from_str(r#"{"reward_value": 1.5}"#).unwrap();
        assert_eq!(reward.objective_key, "reward");
        assert_eq!(reward.reward_value, 1.5);

        let outcome: OutcomeReward = serde_json::from_str(r#"{"total_reward": 2.0}"#).unwrap();
        assert_eq!(outcome.objective_key, "reward");
        assert_eq!(outcome.total_steps, 0);
    }
}