//! Reinforcement learning with parameterized quantum policies: environments,
//! experience buffers, policy-gradient estimation, and a training loop.

#![warn(missing_docs)]
#![warn(clippy::all)]
#![deny(unsafe_code)]

/// Experience replay and trajectory buffers.
pub mod buffer;
/// Environments: grid world, cart-pole, and a binary-choice task.
pub mod environment;
/// Error types and the crate-wide `Result` alias.
pub mod error;
/// Policy-gradient estimation, advantage computation, and trajectories.
pub mod gradient;
/// The parameterized quantum policy.
pub mod policy;
/// The training loop, callbacks, checkpoints, and metrics.
pub mod training;

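/// Convenient re-exports of the crate's public types.
///
/// A minimal sketch of the intended workflow, mirroring the integration
/// tests below. The crate name `qarlp` is assumed here (after `QarlpError`),
/// and `Default` configs are used purely for brevity; real experiments will
/// want explicit `PolicyConfig` and `TrainerConfig` values:
///
/// ```no_run
/// use qarlp::prelude::*;
///
/// // Build a policy, an environment, and a trainer, then run ten updates.
/// let policy = QuantumPolicy::new(PolicyConfig::default()).unwrap();
/// let env = GridWorld::new(GridWorldConfig::default()).unwrap();
/// let mut trainer = Trainer::new(TrainerConfig::default(), policy, env).unwrap();
/// let outcome = trainer.train(10).unwrap();
/// println!("final average reward: {}", outcome.final_average_reward);
/// ```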
pub mod prelude {
    pub use crate::buffer::{BufferConfig, ReplayBuffer, SampleBatch, TrajectoryBuffer};
    pub use crate::environment::{
        BinaryChoice, CartPole, CartPoleConfig, Environment, GridWorld, GridWorldConfig, StepResult,
    };
    pub use crate::error::{
        BufferError, EnvironmentError, GradientError, PolicyError, QarlpError, Result,
        TrainingError,
    };
    pub use crate::gradient::{
        compute_gae, normalize_advantages, Experience, GradientConfig, PolicyGradient, Trajectory,
    };
    pub use crate::policy::{PolicyConfig, QuantumPolicy};
    pub use crate::training::{
        Checkpoint, EvaluationResult, IterationMetrics, LoggingCallback, NoOpCallback, Trainer,
        TrainerConfig, TrainingCallback, TrainingOutcome,
    };
}

#[cfg(test)]
mod tests {
    use super::prelude::*;

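    // End-to-end smoke test: a small quantum policy on a 3x3 grid world.
    // Only checks that training completes and yields a finite average
    // reward, not that the policy actually solves the task.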
    #[test]
    fn test_end_to_end_gridworld() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 4,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        let env_config = GridWorldConfig {
            width: 3,
            height: 3,
            goal: (2, 2),
            start: Some((0, 0)),
            max_steps: 20,
            ..Default::default()
        };
        let env = GridWorld::new(env_config).unwrap();

        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 20,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        let result = trainer.train(5).unwrap();

        assert_eq!(result.total_iterations, 5);
        assert!(result.final_average_reward.is_finite());
    }

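    // End-to-end smoke test on cart-pole, with an explicit learning rate
    // supplied through the nested `gradient_config`.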
    #[test]
    fn test_end_to_end_cartpole() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 2,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        let env_config = CartPoleConfig {
            max_steps: 50,
            seed: Some(42),
            ..Default::default()
        };
        let env = CartPole::new(env_config).unwrap();

        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 50,
            verbose: false,
            seed: Some(42),
            gradient_config: GradientConfig {
                learning_rate: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        let result = trainer.train(3).unwrap();

        assert_eq!(result.total_iterations, 3);
        assert!(result.total_episodes > 0);
    }

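    // End-to-end smoke test on `BinaryChoice`, which is constructed from a
    // single numeric argument rather than a config struct.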
    #[test]
    fn test_end_to_end_binary_choice() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        let env = BinaryChoice::new(10).unwrap();

        let trainer_config = TrainerConfig {
            episodes_per_update: 10,
            max_steps_per_episode: 10,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        let result = trainer.train(5).unwrap();

        assert_eq!(result.total_iterations, 5);
    }

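    // The forward pass must yield a valid probability distribution over
    // actions: the probabilities should sum to 1 within floating-point
    // tolerance.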
    #[test]
    fn test_policy_forward_pass() {
        let config = PolicyConfig {
            num_qubits: 4,
            num_layers: 2,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(config).unwrap();

        let state = vec![0.5, -0.3, 0.1, 0.8];
        let probs = policy.forward(&state).unwrap();

        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

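    // Gradient estimation over a hand-built trajectory: every gradient
    // entry must be finite, with exactly one entry per policy parameter.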
    #[test]
    fn test_gradient_computation() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 2,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();

        let mut trajectory = Trajectory::new();
        for _ in 0..5 {
            trajectory.push(Experience {
                state: vec![0.1, 0.2, 0.3, 0.4],
                action: 0,
                reward: 1.0,
                next_state: vec![0.2, 0.3, 0.4, 0.5],
                done: false,
                log_prob: -0.5,
            });
        }

        let config = GradientConfig::default();
        let mut pg = PolicyGradient::new(config).unwrap();

        let gradients = pg.compute_gradient(&policy, &trajectory).unwrap();

        assert!(gradients.iter().all(|g| g.is_finite()));
        assert_eq!(gradients.len(), policy.num_parameters());
    }

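    // ReplayBuffer round-trip: half-fill a capacity-100 buffer, then sample
    // a batch of 10 and check its size.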
    #[test]
    fn test_replay_buffer() {
        let config = BufferConfig {
            capacity: 100,
            ..Default::default()
        };
        let mut buffer = ReplayBuffer::new(config).unwrap();

        for i in 0..50 {
            buffer.push(
                Experience {
                    state: vec![i as f64; 4],
                    action: i % 2,
                    reward: 1.0,
                    next_state: vec![(i + 1) as f64; 4],
                    done: false,
                    log_prob: -0.5,
                },
                None,
            );
        }

        let batch = buffer.sample(10).unwrap();
        assert_eq!(batch.experiences.len(), 10);
    }

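    // The shared `Environment` interface: both `reset` and `step` produce
    // states whose length matches `state_dim()`.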
    #[test]
    fn test_environment_interface() {
        let config = GridWorldConfig::default();
        let mut env = GridWorld::new(config).unwrap();

        let state = env.reset().unwrap();
        assert_eq!(state.len(), env.state_dim());

        let result = env.step(0).unwrap();
        assert_eq!(result.state.len(), env.state_dim());
    }

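    // Learning-signal check: after a few training iterations, at least one
    // policy parameter must have moved from its initial value.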
    #[test]
    fn test_learning_signal() {
        let policy_config = PolicyConfig {
            num_qubits: 4,
            num_layers: 1,
            num_actions: 4,
            seed: Some(42),
            ..Default::default()
        };
        let policy = QuantumPolicy::new(policy_config).unwrap();
        let initial_params = policy.get_parameters_flat();

        let env = GridWorld::new(GridWorldConfig::default()).unwrap();

        let trainer_config = TrainerConfig {
            episodes_per_update: 5,
            max_steps_per_episode: 20,
            verbose: false,
            seed: Some(42),
            ..Default::default()
        };
        let mut trainer = Trainer::new(trainer_config, policy, env).unwrap();

        trainer.train(3).unwrap();

        let final_params = trainer.policy().get_parameters_flat();

        let param_changed = initial_params
            .iter()
            .zip(final_params.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);

        assert!(param_changed, "Policy parameters should change during training");
    }
}