1use pyo3::prelude::*;
13use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2, ToPyArray as _};
14use ndarray::Array1;
15use rayon::prelude::*;
16use std::collections::HashMap;
17
18mod advantage;
19mod gae;
20mod trajectory;
21mod rewards;
22
23pub use advantage::*;
25pub use gae::*;
26pub use trajectory::*;
27pub use rewards::{
28 normalize_with_running_stats, batch_normalize, exponential_moving_average,
29 shape_rewards, auto_scale_rewards, clip_rewards, RewardStatistics,
30};
31
32#[pyfunction]
45fn compute_group_advantages<'py>(
46 py: Python<'py>,
47 rewards: PyReadonlyArray2<'py, f64>,
48 baseline_type: &str,
49 normalize: bool,
50) -> PyResult<Bound<'py, PyArray1<f64>>> {
51 let rewards = rewards.as_array();
52 let (num_groups, group_size) = rewards.dim();
53
54 let mut all_advantages: Vec<f64> = Vec::with_capacity(num_groups * group_size);
55
56 let group_advantages: Vec<Vec<f64>> = (0..num_groups)
58 .into_par_iter()
59 .map(|g| {
60 let group_rewards: Vec<f64> = (0..group_size)
61 .map(|i| rewards[[g, i]])
62 .collect();
63
64 advantage::compute_advantages_for_group(&group_rewards, baseline_type, normalize)
65 })
66 .collect();
67
68 for group in group_advantages {
70 all_advantages.extend(group);
71 }
72
73 Ok(Array1::from_vec(all_advantages).to_pyarray_bound(py))
74}
75
76#[pyfunction]
89#[pyo3(signature = (rewards, values, gamma=0.99, gae_lambda=0.95))]
90fn compute_gae<'py>(
91 py: Python<'py>,
92 rewards: PyReadonlyArray1<'py, f64>,
93 values: PyReadonlyArray1<'py, f64>,
94 gamma: f64,
95 gae_lambda: f64,
96) -> PyResult<Bound<'py, PyArray1<f64>>> {
97 let rewards = rewards.as_slice()?;
98 let values = values.as_slice()?;
99
100 let advantages = gae::compute_gae_internal(rewards, values, gamma, gae_lambda);
101
102 Ok(Array1::from_vec(advantages).to_pyarray_bound(py))
103}
104
105#[pyfunction]
116#[pyo3(signature = (all_rewards, all_values, gamma=0.99, gae_lambda=0.95))]
117fn batch_compute_gae<'py>(
118 py: Python<'py>,
119 all_rewards: Vec<PyReadonlyArray1<'py, f64>>,
120 all_values: Vec<PyReadonlyArray1<'py, f64>>,
121 gamma: f64,
122 gae_lambda: f64,
123) -> PyResult<Vec<Bound<'py, PyArray1<f64>>>> {
124 let rewards_vecs: Vec<Vec<f64>> = all_rewards
125 .iter()
126 .map(|r| r.as_slice().map(|s| s.to_vec()).unwrap_or_default())
127 .collect();
128
129 let values_vecs: Vec<Vec<f64>> = all_values
130 .iter()
131 .map(|v| v.as_slice().map(|s| s.to_vec()).unwrap_or_default())
132 .collect();
133
134 let results: Vec<Vec<f64>> = rewards_vecs
136 .par_iter()
137 .zip(values_vecs.par_iter())
138 .map(|(rewards, values)| {
139 gae::compute_gae_internal(rewards, values, gamma, gae_lambda)
140 })
141 .collect();
142
143 Ok(results
144 .into_iter()
145 .map(|v| Array1::from_vec(v).to_pyarray_bound(py))
146 .collect())
147}
148
149#[pyfunction]
163#[pyo3(signature = (rewards, running_mean=0.0, running_var=1.0, count=0, epsilon=1e-8))]
164fn normalize_rewards<'py>(
165 py: Python<'py>,
166 rewards: PyReadonlyArray1<'py, f64>,
167 running_mean: f64,
168 running_var: f64,
169 count: i64,
170 epsilon: f64,
171) -> PyResult<(Bound<'py, PyArray1<f64>>, f64, f64, i64)> {
172 let rewards = rewards.as_slice()?;
173
174 let (normalized, new_mean, new_var, new_count) =
175 rewards::normalize_with_running_stats(rewards, running_mean, running_var, count, epsilon);
176
177 Ok((Array1::from_vec(normalized).to_pyarray_bound(py), new_mean, new_var, new_count))
178}
179
180#[pyfunction]
182fn clip_rewards_py<'py>(
183 py: Python<'py>,
184 rewards: PyReadonlyArray1<'py, f64>,
185 min_val: f64,
186 max_val: f64,
187) -> PyResult<Bound<'py, PyArray1<f64>>> {
188 let rewards = rewards.as_slice()?;
189
190 let clipped: Vec<f64> = rewards
191 .iter()
192 .map(|&r| r.clamp(min_val, max_val))
193 .collect();
194
195 Ok(Array1::from_vec(clipped).to_pyarray_bound(py))
196}
197
198#[pyfunction]
211fn compute_gspo_importance_ratios<'py>(
212 py: Python<'py>,
213 log_probs_new: PyReadonlyArray1<'py, f64>,
214 log_probs_old: PyReadonlyArray1<'py, f64>,
215 sequence_lengths: PyReadonlyArray1<'py, i64>,
216) -> PyResult<Bound<'py, PyArray1<f64>>> {
217 let new_probs = log_probs_new.as_slice()?;
218 let old_probs = log_probs_old.as_slice()?;
219 let lengths = sequence_lengths.as_slice()?;
220
221 let ratios: Vec<f64> = new_probs
222 .par_iter()
223 .zip(old_probs.par_iter())
224 .zip(lengths.par_iter())
225 .map(|((&new, &old), &len)| {
226 if len <= 0 {
227 return 1.0;
228 }
229 let log_ratio = new - old;
230 let normalized_log_ratio = log_ratio / (len as f64);
231 normalized_log_ratio.exp()
232 })
233 .collect();
234
235 Ok(Array1::from_vec(ratios).to_pyarray_bound(py))
236}
237
238#[pyfunction]
249#[pyo3(signature = (ratios, advantages, clip_left=0.0003, clip_right=0.0004))]
250fn apply_gspo_clipping<'py>(
251 py: Python<'py>,
252 ratios: PyReadonlyArray1<'py, f64>,
253 advantages: PyReadonlyArray1<'py, f64>,
254 clip_left: f64,
255 clip_right: f64,
256) -> PyResult<Bound<'py, PyArray1<f64>>> {
257 let ratios = ratios.as_slice()?;
258 let advantages = advantages.as_slice()?;
259
260 let clipped: Vec<f64> = ratios
261 .par_iter()
262 .zip(advantages.par_iter())
263 .map(|(&ratio, &adv)| {
264 let unclipped = ratio * adv;
265 let clipped_ratio = if adv >= 0.0 {
266 ratio.min(1.0 + clip_right)
267 } else {
268 ratio.max(1.0 - clip_left)
269 };
270 let clipped_obj = clipped_ratio * adv;
271 unclipped.min(clipped_obj)
272 })
273 .collect();
274
275 Ok(Array1::from_vec(clipped).to_pyarray_bound(py))
276}
277
278#[pyfunction]
280#[pyo3(signature = (ratios, advantages, clip_epsilon=0.2))]
281fn compute_ppo_surrogate<'py>(
282 py: Python<'py>,
283 ratios: PyReadonlyArray1<'py, f64>,
284 advantages: PyReadonlyArray1<'py, f64>,
285 clip_epsilon: f64,
286) -> PyResult<Bound<'py, PyArray1<f64>>> {
287 let ratios = ratios.as_slice()?;
288 let advantages = advantages.as_slice()?;
289
290 let objectives: Vec<f64> = ratios
291 .par_iter()
292 .zip(advantages.par_iter())
293 .map(|(&ratio, &adv)| {
294 let unclipped = ratio * adv;
295 let clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * adv;
296 unclipped.min(clipped)
297 })
298 .collect();
299
300 Ok(Array1::from_vec(objectives).to_pyarray_bound(py))
301}
302
303#[pyfunction]
305fn compute_reward_statistics(rewards: Vec<f64>) -> PyResult<HashMap<String, f64>> {
306 if rewards.is_empty() {
307 return Ok(HashMap::from([
308 ("mean".to_string(), 0.0),
309 ("std".to_string(), 0.0),
310 ("min".to_string(), 0.0),
311 ("max".to_string(), 0.0),
312 ("median".to_string(), 0.0),
313 ]));
314 }
315
316 let n = rewards.len() as f64;
317 let mean = rewards.iter().sum::<f64>() / n;
318 let variance = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / n;
319 let std = variance.sqrt();
320
321 let min = rewards.iter().cloned().fold(f64::INFINITY, f64::min);
322 let max = rewards.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
323
324 let mut sorted = rewards.clone();
325 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
326 let median = if sorted.len() % 2 == 0 {
327 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
328 } else {
329 sorted[sorted.len() / 2]
330 };
331
332 Ok(HashMap::from([
333 ("mean".to_string(), mean),
334 ("std".to_string(), std),
335 ("min".to_string(), min),
336 ("max".to_string(), max),
337 ("median".to_string(), median),
338 ("count".to_string(), n),
339 ]))
340}
341
342#[pymodule]
344fn stateset_rl_core(m: &Bound<'_, PyModule>) -> PyResult<()> {
345 m.add_function(wrap_pyfunction!(compute_group_advantages, m)?)?;
346 m.add_function(wrap_pyfunction!(compute_gae, m)?)?;
347 m.add_function(wrap_pyfunction!(batch_compute_gae, m)?)?;
348 m.add_function(wrap_pyfunction!(normalize_rewards, m)?)?;
349 m.add_function(wrap_pyfunction!(clip_rewards_py, m)?)?;
350 m.add_function(wrap_pyfunction!(compute_gspo_importance_ratios, m)?)?;
351 m.add_function(wrap_pyfunction!(apply_gspo_clipping, m)?)?;
352 m.add_function(wrap_pyfunction!(compute_ppo_surrogate, m)?)?;
353 m.add_function(wrap_pyfunction!(compute_reward_statistics, m)?)?;
354
355 m.add("__version__", "0.1.0")?;
357
358 Ok(())
359}