stateset_rl_core/
trajectory.rs1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct RustTrajectory {
11 pub trajectory_id: String,
12 pub rewards: Vec<f64>,
13 pub log_probs: Vec<f64>,
14 pub sequence_length: usize,
15 pub total_reward: f64,
16 pub metadata: HashMap<String, String>,
17}
18
19impl RustTrajectory {
20 pub fn new(trajectory_id: String) -> Self {
21 Self {
22 trajectory_id,
23 rewards: Vec::new(),
24 log_probs: Vec::new(),
25 sequence_length: 0,
26 total_reward: 0.0,
27 metadata: HashMap::new(),
28 }
29 }
30
31 pub fn add_step(&mut self, reward: f64, log_prob: f64) {
32 self.rewards.push(reward);
33 self.log_probs.push(log_prob);
34 self.total_reward += reward;
35 self.sequence_length += 1;
36 }
37
38 pub fn average_reward(&self) -> f64 {
39 if self.sequence_length == 0 {
40 0.0
41 } else {
42 self.total_reward / self.sequence_length as f64
43 }
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct RustTrajectoryGroup {
50 pub scenario_id: String,
51 pub trajectories: Vec<RustTrajectory>,
52}
53
54impl RustTrajectoryGroup {
55 pub fn new(scenario_id: String) -> Self {
56 Self {
57 scenario_id,
58 trajectories: Vec::new(),
59 }
60 }
61
62 pub fn add_trajectory(&mut self, trajectory: RustTrajectory) {
63 self.trajectories.push(trajectory);
64 }
65
66 pub fn rewards(&self) -> Vec<f64> {
67 self.trajectories.iter().map(|t| t.total_reward).collect()
68 }
69
70 pub fn compute_advantages(&self, baseline_type: &str) -> Vec<f64> {
71 let rewards = self.rewards();
72 crate::advantage::compute_advantages_for_group(&rewards, baseline_type, false)
73 }
74}
75
76pub fn batch_trajectories(
78 groups: &[RustTrajectoryGroup],
79) -> (Vec<f64>, Vec<usize>) {
80 let mut all_rewards = Vec::new();
81 let mut group_indices = Vec::new();
82
83 for (idx, group) in groups.iter().enumerate() {
84 for traj in &group.trajectories {
85 all_rewards.push(traj.total_reward);
86 group_indices.push(idx);
87 }
88 }
89
90 (all_rewards, group_indices)
91}
92
/// Prefix sums of `rewards`: element `t` is `rewards[0] + ... + rewards[t]`.
///
/// An empty slice yields an empty vector.
pub fn compute_cumulative_rewards(rewards: &[f64]) -> Vec<f64> {
    rewards
        .iter()
        .scan(0.0_f64, |running_sum, &reward| {
            *running_sum += reward;
            Some(*running_sum)
        })
        .collect()
}
105
/// Discounted reward-to-go for every step.
///
/// Element `t` satisfies `out[t] = rewards[t] + gamma * out[t + 1]`, with the
/// value past the last step taken as `0.0`. An empty slice yields an empty vector.
pub fn compute_discounted_rewards(rewards: &[f64], gamma: f64) -> Vec<f64> {
    // Walk backwards so each step can fold in the already-discounted future.
    let mut returns: Vec<f64> = rewards
        .iter()
        .rev()
        .scan(0.0_f64, |future, &reward| {
            *future = reward + gamma * *future;
            Some(*future)
        })
        .collect();
    returns.reverse();
    returns
}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trajectory whose steps carry `rewards`; log-probs are fixed at 0.0.
    fn trajectory_with_rewards(id: &str, rewards: &[f64]) -> RustTrajectory {
        let mut traj = RustTrajectory::new(id.to_string());
        for &reward in rewards {
            traj.add_step(reward, 0.0);
        }
        traj
    }

    #[test]
    fn test_trajectory_creation() {
        let mut traj = RustTrajectory::new("test-1".to_string());
        traj.add_step(1.0, -0.5);
        traj.add_step(2.0, -0.3);

        assert_eq!(traj.sequence_length, 2);
        assert_eq!(traj.rewards, vec![1.0, 2.0]);
        assert_eq!(traj.log_probs, vec![-0.5, -0.3]);
        assert!((traj.total_reward - 3.0).abs() < 1e-10);
        assert!((traj.average_reward() - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_empty_trajectory_average_is_zero() {
        let traj = RustTrajectory::new("empty".to_string());
        assert_eq!(traj.sequence_length, 0);
        assert_eq!(traj.total_reward, 0.0);
        assert_eq!(traj.average_reward(), 0.0);
        assert!(traj.rewards.is_empty());
        assert!(traj.log_probs.is_empty());
    }

    #[test]
    fn test_trajectory_group() {
        let mut group = RustTrajectoryGroup::new("scenario-1".to_string());
        group.add_trajectory(trajectory_with_rewards("t1", &[1.0]));
        group.add_trajectory(trajectory_with_rewards("t2", &[3.0]));

        let rewards = group.rewards();
        assert_eq!(rewards, vec![1.0, 3.0]);

        // Mean baseline: advantage = reward - mean([1, 3]) = reward - 2.
        let advantages = group.compute_advantages("mean");
        assert_eq!(advantages.len(), 2);
        assert!((advantages[0] - (-1.0)).abs() < 1e-10);
        assert!((advantages[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_trajectories_flattens_groups_in_order() {
        let mut g0 = RustTrajectoryGroup::new("s0".to_string());
        g0.add_trajectory(trajectory_with_rewards("a", &[1.0]));
        g0.add_trajectory(trajectory_with_rewards("b", &[2.0]));
        // An empty group contributes no entries but still occupies index 1.
        let g1 = RustTrajectoryGroup::new("s1".to_string());
        let mut g2 = RustTrajectoryGroup::new("s2".to_string());
        g2.add_trajectory(trajectory_with_rewards("c", &[2.0, 3.0]));

        let (rewards, indices) = batch_trajectories(&[g0, g1, g2]);
        assert_eq!(rewards, vec![1.0, 2.0, 5.0]);
        assert_eq!(indices, vec![0, 0, 2]);
    }

    #[test]
    fn test_batch_trajectories_empty() {
        let (rewards, indices) = batch_trajectories(&[]);
        assert!(rewards.is_empty());
        assert!(indices.is_empty());
    }

    #[test]
    fn test_cumulative_rewards() {
        let rewards = vec![1.0, 2.0, 3.0];
        let cumulative = compute_cumulative_rewards(&rewards);
        assert_eq!(cumulative, vec![1.0, 3.0, 6.0]);
        assert!(compute_cumulative_rewards(&[]).is_empty());
    }

    #[test]
    fn test_discounted_rewards() {
        let rewards = vec![1.0, 1.0, 1.0];
        let discounted = compute_discounted_rewards(&rewards, 0.9);

        // Reward-to-go with gamma = 0.9: [1 + 0.9 * 1.9, 1 + 0.9 * 1, 1].
        assert!((discounted[2] - 1.0).abs() < 1e-10);
        assert!((discounted[1] - 1.9).abs() < 1e-10);
        assert!((discounted[0] - 2.71).abs() < 1e-10);
    }

    #[test]
    fn test_discounted_rewards_boundary_gammas() {
        let rewards = [1.0, 2.0, 3.0];
        // gamma = 0 ignores the future entirely.
        assert_eq!(compute_discounted_rewards(&rewards, 0.0), vec![1.0, 2.0, 3.0]);
        // gamma = 1 gives plain suffix sums.
        assert_eq!(compute_discounted_rewards(&rewards, 1.0), vec![6.0, 5.0, 3.0]);
        assert!(compute_discounted_rewards(&[], 0.9).is_empty());
    }
}