// stateset_rl_core/trajectory.rs

1//! Trajectory processing module
2//!
3//! Efficient trajectory batching and processing utilities.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Lightweight trajectory representation for Rust processing
///
/// Stores per-step rewards and log-probabilities together with running
/// aggregates (`sequence_length`, `total_reward`) that are kept in sync
/// by `add_step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RustTrajectory {
    /// Unique identifier for this trajectory.
    pub trajectory_id: String,
    /// Per-step rewards, in step order.
    pub rewards: Vec<f64>,
    /// Per-step log-probabilities, parallel to `rewards`.
    pub log_probs: Vec<f64>,
    /// Number of steps recorded (equals `rewards.len()`).
    pub sequence_length: usize,
    /// Running sum of `rewards`.
    pub total_reward: f64,
    /// Free-form string metadata attached to the trajectory.
    pub metadata: HashMap<String, String>,
}
18
19impl RustTrajectory {
20    pub fn new(trajectory_id: String) -> Self {
21        Self {
22            trajectory_id,
23            rewards: Vec::new(),
24            log_probs: Vec::new(),
25            sequence_length: 0,
26            total_reward: 0.0,
27            metadata: HashMap::new(),
28        }
29    }
30
31    pub fn add_step(&mut self, reward: f64, log_prob: f64) {
32        self.rewards.push(reward);
33        self.log_probs.push(log_prob);
34        self.total_reward += reward;
35        self.sequence_length += 1;
36    }
37
38    pub fn average_reward(&self) -> f64 {
39        if self.sequence_length == 0 {
40            0.0
41        } else {
42            self.total_reward / self.sequence_length as f64
43        }
44    }
45}
46
/// Trajectory group for GRPO processing
///
/// Bundles the trajectories sampled for a single scenario so that
/// group-relative advantages can be computed over them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RustTrajectoryGroup {
    /// Identifier of the scenario these trajectories belong to.
    pub scenario_id: String,
    /// Trajectories sampled for this scenario.
    pub trajectories: Vec<RustTrajectory>,
}
53
54impl RustTrajectoryGroup {
55    pub fn new(scenario_id: String) -> Self {
56        Self {
57            scenario_id,
58            trajectories: Vec::new(),
59        }
60    }
61
62    pub fn add_trajectory(&mut self, trajectory: RustTrajectory) {
63        self.trajectories.push(trajectory);
64    }
65
66    pub fn rewards(&self) -> Vec<f64> {
67        self.trajectories.iter().map(|t| t.total_reward).collect()
68    }
69
70    pub fn compute_advantages(&self, baseline_type: &str) -> Vec<f64> {
71        let rewards = self.rewards();
72        crate::advantage::compute_advantages_for_group(&rewards, baseline_type, false)
73    }
74}
75
76/// Batch multiple trajectory groups for efficient processing
77pub fn batch_trajectories(
78    groups: &[RustTrajectoryGroup],
79) -> (Vec<f64>, Vec<usize>) {
80    let mut all_rewards = Vec::new();
81    let mut group_indices = Vec::new();
82
83    for (idx, group) in groups.iter().enumerate() {
84        for traj in &group.trajectories {
85            all_rewards.push(traj.total_reward);
86            group_indices.push(idx);
87        }
88    }
89
90    (all_rewards, group_indices)
91}
92
/// Compute cumulative rewards for a trajectory
///
/// Returns the running prefix sums of `rewards`:
/// `out[i] = rewards[0] + ... + rewards[i]`. An empty input yields an
/// empty vector.
pub fn compute_cumulative_rewards(rewards: &[f64]) -> Vec<f64> {
    rewards
        .iter()
        .scan(0.0, |running, &reward| {
            *running += reward;
            Some(*running)
        })
        .collect()
}
105
/// Compute discounted cumulative rewards
///
/// Returns the reward-to-go `G_t = r_t + gamma * G_{t+1}` for every
/// step, computed in a single backward pass. An empty input yields an
/// empty vector.
pub fn compute_discounted_rewards(rewards: &[f64], gamma: f64) -> Vec<f64> {
    // Accumulate returns while walking the rewards back-to-front, then
    // flip the result back into step order.
    let mut returns: Vec<f64> = rewards
        .iter()
        .rev()
        .scan(0.0, |running, &reward| {
            *running = reward + gamma * *running;
            Some(*running)
        })
        .collect();
    returns.reverse();
    returns
}
123
#[cfg(test)]
mod tests {
    use super::*;

    // add_step must keep sequence_length/total_reward in sync with the
    // pushed steps, and average_reward divides total by step count.
    #[test]
    fn test_trajectory_creation() {
        let mut traj = RustTrajectory::new("test-1".to_string());
        traj.add_step(1.0, -0.5);
        traj.add_step(2.0, -0.3);

        assert_eq!(traj.sequence_length, 2);
        assert!((traj.total_reward - 3.0).abs() < 1e-10);
        assert!((traj.average_reward() - 1.5).abs() < 1e-10);
    }

    // rewards() returns per-trajectory totals in insertion order, and
    // compute_advantages("mean") subtracts the group mean.
    #[test]
    fn test_trajectory_group() {
        let mut group = RustTrajectoryGroup::new("scenario-1".to_string());

        let mut t1 = RustTrajectory::new("t1".to_string());
        t1.add_step(1.0, -0.1);
        group.add_trajectory(t1);

        let mut t2 = RustTrajectory::new("t2".to_string());
        t2.add_step(3.0, -0.2);
        group.add_trajectory(t2);

        let rewards = group.rewards();
        assert_eq!(rewards, vec![1.0, 3.0]);

        let advantages = group.compute_advantages("mean");
        assert_eq!(advantages.len(), 2);
        // Mean is 2.0, so advantages should be [-1.0, 1.0]
        assert!((advantages[0] - (-1.0)).abs() < 1e-10);
        assert!((advantages[1] - 1.0).abs() < 1e-10);
    }

    // Prefix sums: out[i] = sum of rewards[0..=i].
    #[test]
    fn test_cumulative_rewards() {
        let rewards = vec![1.0, 2.0, 3.0];
        let cumulative = compute_cumulative_rewards(&rewards);
        assert_eq!(cumulative, vec![1.0, 3.0, 6.0]);
    }

    // Reward-to-go with gamma = 0.9, checked against hand-computed
    // backward recursion G_t = r_t + gamma * G_{t+1}.
    #[test]
    fn test_discounted_rewards() {
        let rewards = vec![1.0, 1.0, 1.0];
        let discounted = compute_discounted_rewards(&rewards, 0.9);

        // G_2 = 1.0
        // G_1 = 1.0 + 0.9 * 1.0 = 1.9
        // G_0 = 1.0 + 0.9 * 1.9 = 2.71
        assert!((discounted[2] - 1.0).abs() < 1e-10);
        assert!((discounted[1] - 1.9).abs() < 1e-10);
        assert!((discounted[0] - 2.71).abs() < 1e-10);
    }
}