Skip to main content

synth_ai_core/tracing/
models.rs

//! Tracing data models.
//!
//! This module defines the data structures used for trace recording,
//! mirroring Python's `synth_ai.data.traces` and `synth_ai.data.llm_calls`.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11// ============================================================================
12// TIME & CONTENT
13// ============================================================================
14
15/// Time record for events and messages.
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct TimeRecord {
18    /// Unix timestamp for the event
19    pub event_time: f64,
20    /// Optional message-specific timestamp
21    #[serde(default)]
22    pub message_time: Option<i64>,
23}
24
25impl TimeRecord {
26    /// Create a new time record with the current time.
27    pub fn now() -> Self {
28        Self {
29            event_time: Utc::now().timestamp_millis() as f64 / 1000.0,
30            message_time: None,
31        }
32    }
33}
34
35/// Content for messages, supporting text or JSON.
36#[derive(Debug, Clone, Default, Serialize, Deserialize)]
37pub struct MessageContent {
38    /// Plain text content
39    #[serde(default)]
40    pub text: Option<String>,
41    /// JSON-serialized content
42    #[serde(default)]
43    pub json_payload: Option<String>,
44}
45
46impl MessageContent {
47    /// Create from text.
48    pub fn from_text(text: impl Into<String>) -> Self {
49        Self {
50            text: Some(text.into()),
51            json_payload: None,
52        }
53    }
54
55    /// Create from JSON value.
56    pub fn from_json(value: &Value) -> Self {
57        Self {
58            text: None,
59            json_payload: Some(value.to_string()),
60        }
61    }
62
63    /// Get content as text (either raw text or JSON string).
64    pub fn as_text(&self) -> Option<&str> {
65        self.text.as_deref().or(self.json_payload.as_deref())
66    }
67}
68
69// ============================================================================
70// EVENTS
71// ============================================================================
72
73/// Event type discriminator.
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75#[serde(rename_all = "snake_case")]
76pub enum EventType {
77    /// LLM/CAIS call event
78    Cais,
79    /// Environment step event (Gym-style)
80    Environment,
81    /// Runtime/action selection event
82    Runtime,
83}
84
85impl std::fmt::Display for EventType {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            EventType::Cais => write!(f, "cais"),
89            EventType::Environment => write!(f, "environment"),
90            EventType::Runtime => write!(f, "runtime"),
91        }
92    }
93}
94
95/// Base fields common to all events.
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
97pub struct BaseEventFields {
98    /// System/component instance ID
99    pub system_instance_id: String,
100    /// Time record
101    pub time_record: TimeRecord,
102    /// Event metadata (key-value pairs)
103    #[serde(default)]
104    pub metadata: HashMap<String, Value>,
105    /// Structured event metadata
106    #[serde(default)]
107    pub event_metadata: Option<Vec<Value>>,
108}
109
110impl BaseEventFields {
111    /// Create new base fields with a system instance ID.
112    pub fn new(system_instance_id: impl Into<String>) -> Self {
113        Self {
114            system_instance_id: system_instance_id.into(),
115            time_record: TimeRecord::now(),
116            metadata: HashMap::new(),
117            event_metadata: None,
118        }
119    }
120}
121
122/// LLM/CAIS call event.
123#[derive(Debug, Clone, Default, Serialize, Deserialize)]
124pub struct LMCAISEvent {
125    /// Base event fields
126    #[serde(flatten)]
127    pub base: BaseEventFields,
128    /// Model name (e.g., "gpt-4", "claude-3-opus")
129    pub model_name: String,
130    /// Provider (e.g., "openai", "anthropic")
131    #[serde(default)]
132    pub provider: Option<String>,
133    /// Input/prompt tokens
134    #[serde(default)]
135    pub input_tokens: Option<i32>,
136    /// Output/completion tokens
137    #[serde(default)]
138    pub output_tokens: Option<i32>,
139    /// Total tokens
140    #[serde(default)]
141    pub total_tokens: Option<i32>,
142    /// Cost in USD
143    #[serde(default)]
144    pub cost_usd: Option<f64>,
145    /// Latency in milliseconds
146    #[serde(default)]
147    pub latency_ms: Option<i32>,
148    /// OpenTelemetry span ID
149    #[serde(default)]
150    pub span_id: Option<String>,
151    /// OpenTelemetry trace ID
152    #[serde(default)]
153    pub trace_id: Option<String>,
154    /// Detailed call records
155    #[serde(default)]
156    pub call_records: Option<Vec<LLMCallRecord>>,
157    /// System state before the call
158    #[serde(default)]
159    pub system_state_before: Option<Value>,
160    /// System state after the call
161    #[serde(default)]
162    pub system_state_after: Option<Value>,
163}
164
165/// Environment step event (Gymnasium/OpenAI Gym style).
166#[derive(Debug, Clone, Default, Serialize, Deserialize)]
167pub struct EnvironmentEvent {
168    /// Base event fields
169    #[serde(flatten)]
170    pub base: BaseEventFields,
171    /// Reward signal
172    #[serde(default)]
173    pub reward: f64,
174    /// Episode terminated flag
175    #[serde(default)]
176    pub terminated: bool,
177    /// Episode truncated flag
178    #[serde(default)]
179    pub truncated: bool,
180    /// System state before step
181    #[serde(default)]
182    pub system_state_before: Option<Value>,
183    /// System state after step (observations)
184    #[serde(default)]
185    pub system_state_after: Option<Value>,
186}
187
188/// Runtime/action selection event.
189#[derive(Debug, Clone, Default, Serialize, Deserialize)]
190pub struct RuntimeEvent {
191    /// Base event fields
192    #[serde(flatten)]
193    pub base: BaseEventFields,
194    /// Action indices/selections
195    #[serde(default)]
196    pub actions: Vec<i32>,
197}
198
199/// Unified event type using tagged enum.
200#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(tag = "event_type", rename_all = "snake_case")]
202pub enum TracingEvent {
203    /// LLM/CAIS call
204    Cais(LMCAISEvent),
205    /// Environment step
206    Environment(EnvironmentEvent),
207    /// Runtime action
208    Runtime(RuntimeEvent),
209}
210
211impl TracingEvent {
212    /// Get the event type.
213    pub fn event_type(&self) -> EventType {
214        match self {
215            TracingEvent::Cais(_) => EventType::Cais,
216            TracingEvent::Environment(_) => EventType::Environment,
217            TracingEvent::Runtime(_) => EventType::Runtime,
218        }
219    }
220
221    /// Get the base event fields.
222    pub fn base(&self) -> &BaseEventFields {
223        match self {
224            TracingEvent::Cais(e) => &e.base,
225            TracingEvent::Environment(e) => &e.base,
226            TracingEvent::Runtime(e) => &e.base,
227        }
228    }
229
230    /// Get the time record.
231    pub fn time_record(&self) -> &TimeRecord {
232        &self.base().time_record
233    }
234
235    /// Get the system instance ID.
236    pub fn system_instance_id(&self) -> &str {
237        &self.base().system_instance_id
238    }
239}
240
241// ============================================================================
242// LLM CALL RECORDS
243// ============================================================================
244
245/// Token usage statistics.
246#[derive(Debug, Clone, Default, Serialize, Deserialize)]
247pub struct LLMUsage {
248    #[serde(default)]
249    pub input_tokens: Option<i32>,
250    #[serde(default)]
251    pub output_tokens: Option<i32>,
252    #[serde(default)]
253    pub total_tokens: Option<i32>,
254    #[serde(default)]
255    pub reasoning_tokens: Option<i32>,
256    #[serde(default)]
257    pub cache_read_tokens: Option<i32>,
258    #[serde(default)]
259    pub cache_write_tokens: Option<i32>,
260    #[serde(default)]
261    pub cost_usd: Option<f64>,
262}
263
264/// LLM message content part.
265#[derive(Debug, Clone, Default, Serialize, Deserialize)]
266pub struct LLMContentPart {
267    /// Content type (text, image, audio, etc.)
268    #[serde(rename = "type")]
269    pub content_type: String,
270    /// Text content
271    #[serde(default)]
272    pub text: Option<String>,
273    /// Generic data payload
274    #[serde(default)]
275    pub data: Option<Value>,
276    /// MIME type for media
277    #[serde(default)]
278    pub mime_type: Option<String>,
279}
280
281impl LLMContentPart {
282    /// Create a text content part.
283    pub fn text(text: impl Into<String>) -> Self {
284        Self {
285            content_type: "text".to_string(),
286            text: Some(text.into()),
287            data: None,
288            mime_type: None,
289        }
290    }
291}
292
293/// LLM message.
294#[derive(Debug, Clone, Default, Serialize, Deserialize)]
295pub struct LLMMessage {
296    /// Role (system, user, assistant, tool)
297    pub role: String,
298    /// Message content parts
299    #[serde(default)]
300    pub parts: Vec<LLMContentPart>,
301    /// Optional message name
302    #[serde(default)]
303    pub name: Option<String>,
304    /// Tool call ID for tool messages
305    #[serde(default)]
306    pub tool_call_id: Option<String>,
307    /// Additional metadata
308    #[serde(default)]
309    pub metadata: HashMap<String, Value>,
310}
311
312impl LLMMessage {
313    /// Create a simple text message.
314    pub fn new(role: impl Into<String>, text: impl Into<String>) -> Self {
315        Self {
316            role: role.into(),
317            parts: vec![LLMContentPart::text(text)],
318            name: None,
319            tool_call_id: None,
320            metadata: HashMap::new(),
321        }
322    }
323
324    /// Get the text content of the message.
325    pub fn text(&self) -> Option<&str> {
326        self.parts.iter().find_map(|p| p.text.as_deref())
327    }
328}
329
330/// Tool call specification.
331#[derive(Debug, Clone, Default, Serialize, Deserialize)]
332pub struct ToolCallSpec {
333    /// Tool/function name
334    pub name: String,
335    /// Arguments as JSON string
336    pub arguments_json: String,
337    /// Call ID
338    #[serde(default)]
339    pub call_id: Option<String>,
340    /// Index in batch
341    #[serde(default)]
342    pub index: Option<i32>,
343}
344
345/// Tool call result.
346#[derive(Debug, Clone, Default, Serialize, Deserialize)]
347pub struct ToolCallResult {
348    /// Correlates to ToolCallSpec
349    #[serde(default)]
350    pub call_id: Option<String>,
351    /// Execution result text
352    #[serde(default)]
353    pub output_text: Option<String>,
354    /// Exit code
355    #[serde(default)]
356    pub exit_code: Option<i32>,
357    /// Status (ok, error)
358    #[serde(default)]
359    pub status: Option<String>,
360    /// Error message
361    #[serde(default)]
362    pub error_message: Option<String>,
363    /// Duration in milliseconds
364    #[serde(default)]
365    pub duration_ms: Option<i32>,
366}
367
368/// Normalized LLM call record.
369#[derive(Debug, Clone, Default, Serialize, Deserialize)]
370pub struct LLMCallRecord {
371    /// Unique call ID
372    pub call_id: String,
373    /// API type (chat_completions, completions, responses)
374    pub api_type: String,
375    /// Provider (openai, anthropic, etc.)
376    #[serde(default)]
377    pub provider: Option<String>,
378    /// Model name
379    pub model_name: String,
380    /// Call start time
381    #[serde(default)]
382    pub started_at: Option<DateTime<Utc>>,
383    /// Call completion time
384    #[serde(default)]
385    pub completed_at: Option<DateTime<Utc>>,
386    /// Latency in milliseconds
387    #[serde(default)]
388    pub latency_ms: Option<i32>,
389    /// Input messages
390    #[serde(default)]
391    pub input_messages: Vec<LLMMessage>,
392    /// Output messages
393    #[serde(default)]
394    pub output_messages: Vec<LLMMessage>,
395    /// Tool calls in response
396    #[serde(default)]
397    pub output_tool_calls: Vec<ToolCallSpec>,
398    /// Tool execution results
399    #[serde(default)]
400    pub tool_results: Vec<ToolCallResult>,
401    /// Token usage
402    #[serde(default)]
403    pub usage: Option<LLMUsage>,
404    /// Finish reason
405    #[serde(default)]
406    pub finish_reason: Option<String>,
407    /// Additional metadata
408    #[serde(default)]
409    pub metadata: HashMap<String, Value>,
410}
411
412// ============================================================================
413// SESSION STRUCTURE
414// ============================================================================
415
416/// Inter-system message (Markov blanket).
417#[derive(Debug, Clone, Default, Serialize, Deserialize)]
418pub struct MarkovBlanketMessage {
419    /// Message content
420    pub content: MessageContent,
421    /// Message type (user, assistant, system, tool_use, tool_result)
422    pub message_type: String,
423    /// Time record
424    pub time_record: TimeRecord,
425    /// Additional metadata
426    #[serde(default)]
427    pub metadata: HashMap<String, Value>,
428}
429
430/// A timestep within a session.
431#[derive(Debug, Clone, Serialize, Deserialize)]
432pub struct SessionTimeStep {
433    /// Unique step ID
434    pub step_id: String,
435    /// Sequential step index
436    pub step_index: i32,
437    /// Step start time
438    pub timestamp: DateTime<Utc>,
439    /// Conversation turn number
440    #[serde(default)]
441    pub turn_number: Option<i32>,
442    /// Events in this step
443    #[serde(default)]
444    pub events: Vec<TracingEvent>,
445    /// Messages in this step
446    #[serde(default)]
447    pub markov_blanket_messages: Vec<MarkovBlanketMessage>,
448    /// Step-specific metadata
449    #[serde(default)]
450    pub step_metadata: HashMap<String, Value>,
451    /// Step completion time
452    #[serde(default)]
453    pub completed_at: Option<DateTime<Utc>>,
454}
455
456impl SessionTimeStep {
457    /// Create a new timestep.
458    pub fn new(step_id: impl Into<String>, step_index: i32) -> Self {
459        Self {
460            step_id: step_id.into(),
461            step_index,
462            timestamp: Utc::now(),
463            turn_number: None,
464            events: Vec::new(),
465            markov_blanket_messages: Vec::new(),
466            step_metadata: HashMap::new(),
467            completed_at: None,
468        }
469    }
470
471    /// Mark the timestep as complete.
472    pub fn complete(&mut self) {
473        self.completed_at = Some(Utc::now());
474    }
475}
476
477/// A complete session trace.
478#[derive(Debug, Clone, Serialize, Deserialize)]
479pub struct SessionTrace {
480    /// Session ID
481    pub session_id: String,
482    /// Session creation time
483    pub created_at: DateTime<Utc>,
484    /// Ordered timesteps
485    #[serde(default)]
486    pub session_time_steps: Vec<SessionTimeStep>,
487    /// Flattened event history
488    #[serde(default)]
489    pub event_history: Vec<TracingEvent>,
490    /// Flattened message history
491    #[serde(default)]
492    pub markov_blanket_message_history: Vec<MarkovBlanketMessage>,
493    /// Session-level metadata
494    #[serde(default)]
495    pub metadata: HashMap<String, Value>,
496}
497
498impl SessionTrace {
499    /// Create a new session trace.
500    pub fn new(session_id: impl Into<String>) -> Self {
501        Self {
502            session_id: session_id.into(),
503            created_at: Utc::now(),
504            session_time_steps: Vec::new(),
505            event_history: Vec::new(),
506            markov_blanket_message_history: Vec::new(),
507            metadata: HashMap::new(),
508        }
509    }
510
511    /// Get the number of timesteps.
512    pub fn num_timesteps(&self) -> usize {
513        self.session_time_steps.len()
514    }
515
516    /// Get the number of events.
517    pub fn num_events(&self) -> usize {
518        self.event_history.len()
519    }
520
521    /// Get the number of messages.
522    pub fn num_messages(&self) -> usize {
523        self.markov_blanket_message_history.len()
524    }
525}
526
527// ============================================================================
528// REWARDS
529// ============================================================================
530
531/// Session-level outcome reward.
532#[derive(Debug, Clone, Default, Serialize, Deserialize)]
533pub struct OutcomeReward {
534    /// Objective key (default: "reward")
535    #[serde(default = "default_objective_key")]
536    pub objective_key: String,
537    /// Total reward value
538    pub total_reward: f64,
539    /// Number of achievements
540    #[serde(default)]
541    pub achievements_count: i32,
542    /// Total steps in session
543    #[serde(default)]
544    pub total_steps: i32,
545    /// Additional metadata
546    #[serde(default)]
547    pub reward_metadata: HashMap<String, Value>,
548    /// Annotation
549    #[serde(default)]
550    pub annotation: Option<Value>,
551}
552
/// Serde default for `objective_key` fields: the literal `"reward"`.
fn default_objective_key() -> String {
    String::from("reward")
}
556
557/// Event-level reward.
558#[derive(Debug, Clone, Default, Serialize, Deserialize)]
559pub struct EventReward {
560    /// Objective key (default: "reward")
561    #[serde(default = "default_objective_key")]
562    pub objective_key: String,
563    /// Reward value
564    pub reward_value: f64,
565    /// Reward type (shaped, sparse, achievement, penalty, evaluator, human)
566    #[serde(default)]
567    pub reward_type: Option<String>,
568    /// Key (e.g., achievement name)
569    #[serde(default)]
570    pub key: Option<String>,
571    /// Annotation
572    #[serde(default)]
573    pub annotation: Option<Value>,
574    /// Source (environment, runner, evaluator, human)
575    #[serde(default)]
576    pub source: Option<String>,
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_record() {
        let tr = TimeRecord::now();
        assert!(tr.event_time > 0.0);
        assert!(tr.message_time.is_none());
    }

    #[test]
    fn test_message_content() {
        let mc = MessageContent::from_text("hello");
        assert_eq!(mc.as_text(), Some("hello"));
        assert!(mc.json_payload.is_none());

        let mc = MessageContent::from_json(&serde_json::json!({"key": "value"}));
        assert!(mc.text.is_none());
        // With no text set, `as_text` falls back to the serialized JSON.
        assert_eq!(mc.as_text(), Some(r#"{"key":"value"}"#));

        assert_eq!(MessageContent::default().as_text(), None);
    }

    #[test]
    fn test_event_serialization() {
        let event = TracingEvent::Cais(LMCAISEvent {
            base: BaseEventFields::new("test-system"),
            model_name: "gpt-4".to_string(),
            provider: Some("openai".to_string()),
            input_tokens: Some(100),
            output_tokens: Some(50),
            ..Default::default()
        });

        let json = serde_json::to_string(&event).unwrap();

        // Internally tagged on `event_type`, with base fields flattened to the top level.
        let value: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["event_type"], "cais");
        assert_eq!(value["model_name"], "gpt-4");
        assert_eq!(value["system_instance_id"], "test-system");

        let parsed: TracingEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.event_type(), EventType::Cais);
        assert_eq!(parsed.system_instance_id(), "test-system");
        match parsed {
            TracingEvent::Cais(e) => {
                assert_eq!(e.model_name, "gpt-4");
                assert_eq!(e.provider.as_deref(), Some("openai"));
                assert_eq!(e.input_tokens, Some(100));
                assert_eq!(e.output_tokens, Some(50));
            }
            other => panic!("expected a Cais event, got {}", other.event_type()),
        }
    }

    #[test]
    fn test_event_type_display_matches_serde() {
        for ty in [EventType::Cais, EventType::Environment, EventType::Runtime] {
            let serialized = serde_json::to_string(&ty).unwrap();
            assert_eq!(serialized, format!("\"{ty}\""));
        }
    }

    #[test]
    fn test_session_trace() {
        let mut trace = SessionTrace::new("test-session");
        assert_eq!(trace.num_timesteps(), 0);
        assert_eq!(trace.num_events(), 0);
        assert_eq!(trace.num_messages(), 0);

        let mut step = SessionTimeStep::new("step-1", 0);
        assert!(step.completed_at.is_none());
        step.complete();
        assert!(step.completed_at.is_some());

        trace.session_time_steps.push(step);
        assert_eq!(trace.num_timesteps(), 1);
    }

    #[test]
    fn test_llm_message() {
        let msg = LLMMessage::new("user", "Hello, world!");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.text(), Some("Hello, world!"));
        assert_eq!(LLMMessage::default().text(), None);
    }
}