// stateset_rl_core/lib.rs

1//! StateSet RL Core - High-performance Rust implementations for RL operations
2//!
3//! This crate provides optimized implementations of performance-critical
4//! operations for the StateSet RL framework, exposed to Python via PyO3.
5//!
6//! # Features
7//! - SIMD-accelerated advantage computation
8//! - Parallel trajectory processing
9//! - Efficient reward normalization
10//! - Fast GAE computation
11
12use pyo3::prelude::*;
13use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2, ToPyArray as _};
14use ndarray::Array1;
15use rayon::prelude::*;
16use std::collections::HashMap;
17
18mod advantage;
19mod gae;
20mod trajectory;
21mod rewards;
22
23// Re-export core functions for Rust usage
24pub use advantage::*;
25pub use gae::*;
26pub use trajectory::*;
27pub use rewards::{
28    normalize_with_running_stats, batch_normalize, exponential_moving_average,
29    shape_rewards, auto_scale_rewards, clip_rewards, RewardStatistics,
30};
31
32/// Compute group-relative advantages for GRPO training
33///
34/// This is a high-performance implementation that uses SIMD where available
35/// and parallelizes across groups.
36///
37/// Args:
38///     rewards: 2D array of shape (num_groups, group_size) containing rewards
39///     baseline_type: "mean", "median", or "min"
40///     normalize: Whether to normalize advantages
41///
42/// Returns:
43///     2D array of advantages with same shape as input
44#[pyfunction]
45fn compute_group_advantages<'py>(
46    py: Python<'py>,
47    rewards: PyReadonlyArray2<'py, f64>,
48    baseline_type: &str,
49    normalize: bool,
50) -> PyResult<Bound<'py, PyArray1<f64>>> {
51    let rewards = rewards.as_array();
52    let (num_groups, group_size) = rewards.dim();
53
54    let mut all_advantages: Vec<f64> = Vec::with_capacity(num_groups * group_size);
55
56    // Process each group in parallel
57    let group_advantages: Vec<Vec<f64>> = (0..num_groups)
58        .into_par_iter()
59        .map(|g| {
60            let group_rewards: Vec<f64> = (0..group_size)
61                .map(|i| rewards[[g, i]])
62                .collect();
63
64            advantage::compute_advantages_for_group(&group_rewards, baseline_type, normalize)
65        })
66        .collect();
67
68    // Flatten results
69    for group in group_advantages {
70        all_advantages.extend(group);
71    }
72
73    Ok(Array1::from_vec(all_advantages).to_pyarray_bound(py))
74}
75
76/// Compute Generalized Advantage Estimation (GAE) for a trajectory
77///
78/// High-performance GAE computation with configurable gamma and lambda.
79///
80/// Args:
81///     rewards: Array of per-step rewards
82///     values: Array of value estimates (one more than rewards for bootstrap)
83///     gamma: Discount factor (default 0.99)
84///     gae_lambda: GAE lambda parameter (default 0.95)
85///
86/// Returns:
87///     Array of advantage estimates
88#[pyfunction]
89#[pyo3(signature = (rewards, values, gamma=0.99, gae_lambda=0.95))]
90fn compute_gae<'py>(
91    py: Python<'py>,
92    rewards: PyReadonlyArray1<'py, f64>,
93    values: PyReadonlyArray1<'py, f64>,
94    gamma: f64,
95    gae_lambda: f64,
96) -> PyResult<Bound<'py, PyArray1<f64>>> {
97    let rewards = rewards.as_slice()?;
98    let values = values.as_slice()?;
99
100    let advantages = gae::compute_gae_internal(rewards, values, gamma, gae_lambda);
101
102    Ok(Array1::from_vec(advantages).to_pyarray_bound(py))
103}
104
105/// Batch compute GAE for multiple trajectories in parallel
106///
107/// Args:
108///     all_rewards: List of reward arrays
109///     all_values: List of value arrays
110///     gamma: Discount factor
111///     gae_lambda: GAE lambda parameter
112///
113/// Returns:
114///     List of advantage arrays
115#[pyfunction]
116#[pyo3(signature = (all_rewards, all_values, gamma=0.99, gae_lambda=0.95))]
117fn batch_compute_gae<'py>(
118    py: Python<'py>,
119    all_rewards: Vec<PyReadonlyArray1<'py, f64>>,
120    all_values: Vec<PyReadonlyArray1<'py, f64>>,
121    gamma: f64,
122    gae_lambda: f64,
123) -> PyResult<Vec<Bound<'py, PyArray1<f64>>>> {
124    let rewards_vecs: Vec<Vec<f64>> = all_rewards
125        .iter()
126        .map(|r| r.as_slice().map(|s| s.to_vec()).unwrap_or_default())
127        .collect();
128
129    let values_vecs: Vec<Vec<f64>> = all_values
130        .iter()
131        .map(|v| v.as_slice().map(|s| s.to_vec()).unwrap_or_default())
132        .collect();
133
134    // Parallel GAE computation
135    let results: Vec<Vec<f64>> = rewards_vecs
136        .par_iter()
137        .zip(values_vecs.par_iter())
138        .map(|(rewards, values)| {
139            gae::compute_gae_internal(rewards, values, gamma, gae_lambda)
140        })
141        .collect();
142
143    Ok(results
144        .into_iter()
145        .map(|v| Array1::from_vec(v).to_pyarray_bound(py))
146        .collect())
147}
148
149/// Normalize rewards using running statistics
150///
151/// Efficient online normalization with Welford's algorithm.
152///
153/// Args:
154///     rewards: Array of rewards to normalize
155///     running_mean: Current running mean (will be updated)
156///     running_var: Current running variance (will be updated)
157///     count: Current count (will be updated)
158///     epsilon: Small value for numerical stability
159///
160/// Returns:
161///     Tuple of (normalized_rewards, new_mean, new_var, new_count)
162#[pyfunction]
163#[pyo3(signature = (rewards, running_mean=0.0, running_var=1.0, count=0, epsilon=1e-8))]
164fn normalize_rewards<'py>(
165    py: Python<'py>,
166    rewards: PyReadonlyArray1<'py, f64>,
167    running_mean: f64,
168    running_var: f64,
169    count: i64,
170    epsilon: f64,
171) -> PyResult<(Bound<'py, PyArray1<f64>>, f64, f64, i64)> {
172    let rewards = rewards.as_slice()?;
173
174    let (normalized, new_mean, new_var, new_count) =
175        rewards::normalize_with_running_stats(rewards, running_mean, running_var, count, epsilon);
176
177    Ok((Array1::from_vec(normalized).to_pyarray_bound(py), new_mean, new_var, new_count))
178}
179
180/// Clip rewards to a specified range
181#[pyfunction]
182fn clip_rewards_py<'py>(
183    py: Python<'py>,
184    rewards: PyReadonlyArray1<'py, f64>,
185    min_val: f64,
186    max_val: f64,
187) -> PyResult<Bound<'py, PyArray1<f64>>> {
188    let rewards = rewards.as_slice()?;
189
190    let clipped: Vec<f64> = rewards
191        .iter()
192        .map(|&r| r.clamp(min_val, max_val))
193        .collect();
194
195    Ok(Array1::from_vec(clipped).to_pyarray_bound(py))
196}
197
198/// Compute GSPO sequence-level importance ratios
199///
200/// Implements the length-normalized sequence importance ratio from GSPO:
201/// s_i(θ) = (π_θ(y_i|x) / π_θ_old(y_i|x))^(1/|y_i|)
202///
203/// Args:
204///     log_probs_new: Log probabilities under new policy (sum per sequence)
205///     log_probs_old: Log probabilities under old policy (sum per sequence)
206///     sequence_lengths: Length of each sequence
207///
208/// Returns:
209///     Array of sequence-level importance ratios
210#[pyfunction]
211fn compute_gspo_importance_ratios<'py>(
212    py: Python<'py>,
213    log_probs_new: PyReadonlyArray1<'py, f64>,
214    log_probs_old: PyReadonlyArray1<'py, f64>,
215    sequence_lengths: PyReadonlyArray1<'py, i64>,
216) -> PyResult<Bound<'py, PyArray1<f64>>> {
217    let new_probs = log_probs_new.as_slice()?;
218    let old_probs = log_probs_old.as_slice()?;
219    let lengths = sequence_lengths.as_slice()?;
220
221    let ratios: Vec<f64> = new_probs
222        .par_iter()
223        .zip(old_probs.par_iter())
224        .zip(lengths.par_iter())
225        .map(|((&new, &old), &len)| {
226            if len <= 0 {
227                return 1.0;
228            }
229            let log_ratio = new - old;
230            let normalized_log_ratio = log_ratio / (len as f64);
231            normalized_log_ratio.exp()
232        })
233        .collect();
234
235    Ok(Array1::from_vec(ratios).to_pyarray_bound(py))
236}
237
238/// Apply GSPO clipping to importance ratios
239///
240/// Args:
241///     ratios: Importance ratios
242///     advantages: Advantage values
243///     clip_left: Left clipping bound (default 3e-4)
244///     clip_right: Right clipping bound (default 4e-4)
245///
246/// Returns:
247///     Clipped surrogate objectives
248#[pyfunction]
249#[pyo3(signature = (ratios, advantages, clip_left=0.0003, clip_right=0.0004))]
250fn apply_gspo_clipping<'py>(
251    py: Python<'py>,
252    ratios: PyReadonlyArray1<'py, f64>,
253    advantages: PyReadonlyArray1<'py, f64>,
254    clip_left: f64,
255    clip_right: f64,
256) -> PyResult<Bound<'py, PyArray1<f64>>> {
257    let ratios = ratios.as_slice()?;
258    let advantages = advantages.as_slice()?;
259
260    let clipped: Vec<f64> = ratios
261        .par_iter()
262        .zip(advantages.par_iter())
263        .map(|(&ratio, &adv)| {
264            let unclipped = ratio * adv;
265            let clipped_ratio = if adv >= 0.0 {
266                ratio.min(1.0 + clip_right)
267            } else {
268                ratio.max(1.0 - clip_left)
269            };
270            let clipped_obj = clipped_ratio * adv;
271            unclipped.min(clipped_obj)
272        })
273        .collect();
274
275    Ok(Array1::from_vec(clipped).to_pyarray_bound(py))
276}
277
278/// Compute PPO clipped surrogate objective
279#[pyfunction]
280#[pyo3(signature = (ratios, advantages, clip_epsilon=0.2))]
281fn compute_ppo_surrogate<'py>(
282    py: Python<'py>,
283    ratios: PyReadonlyArray1<'py, f64>,
284    advantages: PyReadonlyArray1<'py, f64>,
285    clip_epsilon: f64,
286) -> PyResult<Bound<'py, PyArray1<f64>>> {
287    let ratios = ratios.as_slice()?;
288    let advantages = advantages.as_slice()?;
289
290    let objectives: Vec<f64> = ratios
291        .par_iter()
292        .zip(advantages.par_iter())
293        .map(|(&ratio, &adv)| {
294            let unclipped = ratio * adv;
295            let clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * adv;
296            unclipped.min(clipped)
297        })
298        .collect();
299
300    Ok(Array1::from_vec(objectives).to_pyarray_bound(py))
301}
302
303/// Compute reward statistics for a batch of trajectories
304#[pyfunction]
305fn compute_reward_statistics(rewards: Vec<f64>) -> PyResult<HashMap<String, f64>> {
306    if rewards.is_empty() {
307        return Ok(HashMap::from([
308            ("mean".to_string(), 0.0),
309            ("std".to_string(), 0.0),
310            ("min".to_string(), 0.0),
311            ("max".to_string(), 0.0),
312            ("median".to_string(), 0.0),
313        ]));
314    }
315
316    let n = rewards.len() as f64;
317    let mean = rewards.iter().sum::<f64>() / n;
318    let variance = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / n;
319    let std = variance.sqrt();
320
321    let min = rewards.iter().cloned().fold(f64::INFINITY, f64::min);
322    let max = rewards.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
323
324    let mut sorted = rewards.clone();
325    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
326    let median = if sorted.len() % 2 == 0 {
327        (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
328    } else {
329        sorted[sorted.len() / 2]
330    };
331
332    Ok(HashMap::from([
333        ("mean".to_string(), mean),
334        ("std".to_string(), std),
335        ("min".to_string(), min),
336        ("max".to_string(), max),
337        ("median".to_string(), median),
338        ("count".to_string(), n),
339    ]))
340}
341
/// Python module definition
///
/// Entry point generated by PyO3: registers every `#[pyfunction]` in this
/// crate on the `stateset_rl_core` module so they are importable from Python.
/// The Rust-only helpers re-exported at the top of the file are NOT exposed
/// here — only the wrappers below cross the Python boundary.
#[pymodule]
fn stateset_rl_core(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(compute_group_advantages, m)?)?;
    m.add_function(wrap_pyfunction!(compute_gae, m)?)?;
    m.add_function(wrap_pyfunction!(batch_compute_gae, m)?)?;
    m.add_function(wrap_pyfunction!(normalize_rewards, m)?)?;
    // Registered under the Rust name `clip_rewards_py` (Python sees that name).
    m.add_function(wrap_pyfunction!(clip_rewards_py, m)?)?;
    m.add_function(wrap_pyfunction!(compute_gspo_importance_ratios, m)?)?;
    m.add_function(wrap_pyfunction!(apply_gspo_clipping, m)?)?;
    m.add_function(wrap_pyfunction!(compute_ppo_surrogate, m)?)?;
    m.add_function(wrap_pyfunction!(compute_reward_statistics, m)?)?;

    // Add version info
    m.add("__version__", "0.1.0")?;

    Ok(())
}