ruqu_qarlp/
lib.rs

//! # ruqu-qarlp: Quantum-Assisted Reinforcement Learning Policy
//!
//! A Rust implementation of quantum-assisted reinforcement learning using
//! variational quantum circuits as policy networks. This crate provides
//! a complete framework for training quantum RL agents.
//!
//! ## Overview
//!
//! This crate implements the QARLP (Quantum-Assisted Reinforcement Learning Policy)
//! algorithm, which uses variational quantum circuits (VQCs) to represent policies
//! in reinforcement learning. The key components are:
//!
//! - **Quantum Policy Network**: A variational quantum circuit that maps states to
//!   action probabilities through parameterized rotation gates.
//!
//! - **Policy Gradient**: REINFORCE algorithm with baseline subtraction, using the
//!   parameter-shift rule for exact gradient computation on quantum circuits
//!   (see the formulas after this list).
//!
//! - **Environment Interface**: Generic trait for RL environments, with included
//!   implementations of GridWorld and CartPole for testing.
//!
//! - **Training Loop**: Complete training infrastructure with checkpointing,
//!   logging, and metrics.
//!
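//! For reference, the gradient machinery combines two standard pieces: the
//! REINFORCE estimator with a baseline, and the parameter-shift rule for
//! rotation gates with Pauli generators. The exact form used by
//! `PolicyGradient` may differ in details such as advantage normalization.
//!
//! ```text
//! ∇_θ J(θ) ≈ Σ_t ∇_θ log π_θ(a_t | s_t) · (G_t − b(s_t))
//!
//! ∂⟨O⟩/∂θ_i = ( ⟨O⟩(θ_i + π/2) − ⟨O⟩(θ_i − π/2) ) / 2
//! ```
//!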
//! ## Architecture
//!
//! The quantum policy network consists of:
//!
//! 1. **State Encoding**: Classical state vectors are encoded as rotation angles
//!    on qubits using RX gates.
//!
//! 2. **Variational Layers**: Parameterized RY and RZ rotations with CNOT
//!    entanglement gates form the trainable part of the circuit.
//!
//! 3. **Measurement**: Computational basis measurement probabilities are mapped
//!    to action probabilities via softmax (a standalone sketch of this mapping
//!    follows below).
//!
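//! The mapping in step 3 can be illustrated with a small, self-contained sketch.
//! `action_probs` below is a hypothetical helper, not part of this crate's API,
//! and the grouping of basis states per action is an assumption for illustration:
//!
//! ```
//! // Hypothetical helper: aggregate basis-state probabilities into per-action
//! // logits, then apply a numerically stable softmax.
//! fn action_probs(basis_probs: &[f64], num_actions: usize) -> Vec<f64> {
//!     // Assumed grouping: basis state i contributes to action i % num_actions.
//!     let mut logits = vec![0.0; num_actions];
//!     for (i, p) in basis_probs.iter().enumerate() {
//!         logits[i % num_actions] += p;
//!     }
//!     let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!     let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
//!     let sum: f64 = exps.iter().sum();
//!     exps.into_iter().map(|e| e / sum).collect()
//! }
//!
//! // With a uniform distribution over 16 basis states and 4 actions,
//! // the result is a valid probability distribution.
//! let probs = action_probs(&[1.0 / 16.0; 16], 4);
//! assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
//! ```
//!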
//! ## Example
//!
//! ```
//! use ruqu_qarlp::prelude::*;
//!
//! // Create a quantum policy
//! let policy_config = PolicyConfig {
//!     num_qubits: 4,
//!     num_layers: 2,
//!     num_actions: 4,
//!     ..Default::default()
//! };
//! let policy = QuantumPolicy::new(policy_config).unwrap();
//!
//! // Create an environment
//! let env_config = GridWorldConfig::default();
//! let env = GridWorld::new(env_config).unwrap();
//!
//! // Create trainer and train
//! let trainer_config = TrainerConfig {
//!     episodes_per_update: 10,
//!     max_steps_per_episode: 100,
//!     ..Default::default()
//! };
//! let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();
//!
//! // Train for 100 iterations
//! let result = trainer.train(100).unwrap();
//! println!("Final reward: {}", result.final_average_reward);
//! ```
//!
//! ## Tier 3 Capability (Exploratory)
//!
//! This crate represents a Tier 3 (Score 69) exploratory quantum RL implementation.
//! The two-week test criteria are:
//!
//! - Policy gradient update works correctly
//! - Simple environment shows learning signal
//!
//! ## Features
//!
//! - `parallel`: Enable parallel gradient computation using `rayon` (not yet implemented)
//!
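//! Once implemented, enabling the feature would follow the usual Cargo pattern;
//! the crate name below mirrors the title above and the version is a placeholder:
//!
//! ```toml
//! [dependencies]
//! ruqu-qarlp = { version = "*", features = ["parallel"] }
//! ```
//!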
//! ## References
//!
//! - Schuld, M., & Petruccione, F. (2018). Supervised Learning with Quantum Computers
//! - Mitarai, K., et al. (2018). Quantum circuit learning
//! - Jerbi, S., et al. (2021). Parametrized quantum policies for reinforcement learning

#![warn(missing_docs)]
#![warn(clippy::all)]
#![deny(unsafe_code)]

pub mod buffer;
pub mod environment;
pub mod error;
pub mod gradient;
pub mod policy;
pub mod training;

/// Prelude module for convenient imports.
pub mod prelude {
    pub use crate::buffer::{BufferConfig, ReplayBuffer, SampleBatch, TrajectoryBuffer};
    pub use crate::environment::{
        BinaryChoice, CartPole, CartPoleConfig, Environment, GridWorld, GridWorldConfig, StepResult,
    };
    pub use crate::error::{
        BufferError, EnvironmentError, GradientError, PolicyError, QarlpError, Result,
        TrainingError,
    };
    pub use crate::gradient::{
        compute_gae, normalize_advantages, Experience, GradientConfig, PolicyGradient, Trajectory,
    };
    pub use crate::policy::{PolicyConfig, QuantumPolicy};
    pub use crate::training::{
        Checkpoint, EvaluationResult, IterationMetrics, LoggingCallback, NoOpCallback, Trainer,
        TrainerConfig, TrainingCallback, TrainingOutcome,
    };
}

#[cfg(test)]
mod tests {
    use super::prelude::*;

    #[test]
    fn test_end_to_end_gridworld() {
        // Create policy
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 4,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        // Create environment
        let env_config = GridWorldConfig {
            width: 3,
            height: 3,
            goal: (2, 2),
            start: Some((0, 0)),
            max_steps: 20,
            ..Default::default()
        };
        let env = GridWorld::new(env_config).unwrap();

        // Create trainer
        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 20,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        // Train
        let result = trainer.train(5).unwrap();

        // Verify training completed
        assert_eq!(result.total_iterations, 5);
        assert!(result.final_average_reward.is_finite());
    }

    #[test]
    fn test_end_to_end_cartpole() {
        // Create policy
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 2,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        // Create environment
        let env_config = CartPoleConfig {
            max_steps: 50,
            seed: Some(42),
            ..Default::default()
        };
        let env = CartPole::new(env_config).unwrap();

        // Create trainer
        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 50,
            verbose: false,
            seed: Some(42),
            gradient_config: GradientConfig {
                learning_rate: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        // Train
        let result = trainer.train(3).unwrap();

        // Verify training completed
        assert_eq!(result.total_iterations, 3);
        assert!(result.total_episodes > 0);
    }

    #[test]
    fn test_end_to_end_binary_choice() {
        // Create policy
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        // Create simple environment
        let env = BinaryChoice::new(10).unwrap();

        // Create trainer
        let trainer_config = TrainerConfig {
            episodes_per_update: 10,
            max_steps_per_episode: 10,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        // Train
        let result = trainer.train(5).unwrap();

        // Verify training completed without errors
        assert_eq!(result.total_iterations, 5);
    }

    #[test]
    fn test_policy_forward_pass() {
        let config = PolicyConfig {
            num_qubits: 4,
            num_layers: 2,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(config).unwrap();

        let state = vec![0.5, -0.3, 0.1, 0.8];
        let probs = policy.forward(&state).unwrap();

        // Probabilities should sum to 1
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_computation() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        // Create trajectory
        let mut trajectory = Trajectory::new();
        for _ in 0..5 {
            trajectory.push(Experience {
                state: vec![0.1, 0.2, 0.3, 0.4],
                action: 0,
                reward: 1.0,
                next_state: vec![0.2, 0.3, 0.4, 0.5],
                done: false,
                log_prob: -0.5,
            });
        }

        let config = GradientConfig::default();
        let mut pg = PolicyGradient::new(config).unwrap();

        let gradients = pg.compute_gradient(&policy, &trajectory).unwrap();

        // Gradients should be finite
        assert!(gradients.iter().all(|g| g.is_finite()));
        assert_eq!(gradients.len(), policy.num_parameters());
    }

    #[test]
    fn test_replay_buffer() {
        let config = BufferConfig {
            capacity: 100,
            ..Default::default()
        };
        let mut buffer = ReplayBuffer::new(config).unwrap();

        // Add experiences
        for i in 0..50 {
            buffer.push(
                Experience {
                    state: vec![i as f64; 4],
                    action: i % 2,
                    reward: 1.0,
                    next_state: vec![(i + 1) as f64; 4],
                    done: false,
                    log_prob: -0.5,
                },
                None,
            );
        }

        // Sample
        let batch = buffer.sample(10).unwrap();
        assert_eq!(batch.experiences.len(), 10);
    }

    #[test]
    fn test_environment_interface() {
        let config = GridWorldConfig::default();
        let mut env = GridWorld::new(config).unwrap();

        // Reset
        let state = env.reset().unwrap();
        assert_eq!(state.len(), env.state_dim());

        // Step
        let result = env.step(0).unwrap();
        assert_eq!(result.state.len(), env.state_dim());
    }

    #[test]
    fn test_learning_signal() {
        // Test that policy parameters change during training
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 4,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();
        let initial_params = policy.get_parameters_flat();

        let env = GridWorld::new(GridWorldConfig::default()).unwrap();

        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 20,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        trainer.train(3).unwrap();

        let final_params = trainer.policy().get_parameters_flat();

        // Parameters should have changed
        let param_changed = initial_params
            .iter()
            .zip(final_params.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);

        assert!(param_changed, "Policy parameters should change during training");
    }
}