use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

use super::core::SerializableModel;

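/// A checkpointable clustering workflow: an ordered list of training steps
/// executed sequentially, with dependency checking, progress tracking, and a
/// full execution history.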
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClusteringWorkflow {
    pub workflow_id: String,
    pub current_step: usize,
    pub steps: Vec<TrainingStep>,
    pub current_state: AlgorithmState,
    pub config: WorkflowConfig,
    pub execution_history: Vec<ExecutionRecord>,
    pub intermediate_results: HashMap<String, serde_json::Value>,
}

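/// Execution state of a workflow. Timestamps are seconds since the Unix
/// epoch and progress is a percentage in `[0.0, 100.0]`.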
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum AlgorithmState {
    NotStarted,
    Running {
        iteration: usize,
        start_time: u64,
        progress: f32,
    },
    Completed {
        iterations: usize,
        execution_time: f64,
        final_metrics: HashMap<String, f64>,
    },
    Failed {
        error: String,
        failure_time: u64,
    },
    Paused {
        pause_time: u64,
        paused_at_iteration: usize,
    },
}

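/// A single step in a workflow: which algorithm to run, its parameters, and
/// the names of steps that must complete before it may run.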
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TrainingStep {
    pub name: String,
    pub algorithm: String,
    pub parameters: HashMap<String, serde_json::Value>,
    pub dependencies: Vec<String>,
    pub completed: bool,
    pub execution_time: Option<f64>,
    pub results: Option<serde_json::Value>,
}

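/// Execution options for a workflow. `auto_save_interval` is in seconds and
/// `checkpoint_dir` must be set for checkpointing to have any effect;
/// `max_retries`, `step_timeout`, and `parallel_execution` are recorded but
/// not consulted by `execute` in this module.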
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct WorkflowConfig {
    pub auto_save_interval: Option<u64>,
    pub max_retries: usize,
    pub step_timeout: Option<u64>,
    pub parallel_execution: bool,
    pub checkpoint_dir: Option<PathBuf>,
}

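/// One entry in a workflow's execution history: what was done, when, and
/// with what outcome.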
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ExecutionRecord {
    pub timestamp: u64,
    pub step_name: String,
    pub action: String,
    pub result: ExecutionResult,
    pub metadata: HashMap<String, serde_json::Value>,
}

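/// The recorded outcome of a single step execution; durations are in seconds.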
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ExecutionResult {
    Success {
        duration: f64,
        output: Option<serde_json::Value>,
    },
    Failure {
        error: String,
        error_code: Option<String>,
    },
    Skipped {
        reason: String,
    },
}

impl ClusteringWorkflow {
    pub fn new(workflow_id: String, config: WorkflowConfig) -> Self {
        Self {
            workflow_id,
            current_step: 0,
            steps: Vec::new(),
            current_state: AlgorithmState::NotStarted,
            config,
            execution_history: Vec::new(),
            intermediate_results: HashMap::new(),
        }
    }

    pub fn add_step(&mut self, step: TrainingStep) {
        self.steps.push(step);
    }

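    /// Executes all steps in order. Each step's dependencies must already be
    /// completed; results are stored on the step and in
    /// `intermediate_results`, and a checkpoint is saved after any step whose
    /// duration exceeds `auto_save_interval`. On success the state becomes
    /// `Completed` with aggregate metrics.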
    pub fn execute(&mut self) -> Result<()> {
        self.current_state = AlgorithmState::Running {
            iteration: 0,
            start_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            progress: 0.0,
        };

        let start_time = std::time::Instant::now();

        let steps_len = self.steps.len();
        for i in 0..steps_len {
            self.current_step = i;

            // Every dependency must name a step that has already completed.
            let dependencies = self.steps[i].dependencies.clone();
            if !self.check_dependencies(&dependencies)? {
                return Err(ClusteringError::InvalidInput(format!(
                    "Dependencies not satisfied for step: {}",
                    self.steps[i].name
                )));
            }

            let step_start = std::time::Instant::now();
            let step_clone = self.steps[i].clone();
            let result = self.execute_step(&step_clone)?;
            let step_duration = step_start.elapsed().as_secs_f64();

            self.steps[i].completed = true;
            self.steps[i].execution_time = Some(step_duration);
            self.steps[i].results = Some(result.clone());

            let step_name = self.steps[i].name.clone();
            self.record_execution(
                &step_name,
                "execute",
                ExecutionResult::Success {
                    duration: step_duration,
                    output: Some(result),
                },
            );

            let progress = ((i + 1) as f32 / steps_len as f32) * 100.0;
            self.update_progress(progress);

            // Checkpoint after any step that ran longer than the
            // auto-save interval (in seconds).
            if let Some(interval) = self.config.auto_save_interval {
                if step_duration > interval as f64 {
                    self.save_checkpoint()?;
                }
            }
        }

        let total_time = start_time.elapsed().as_secs_f64();
        self.current_state = AlgorithmState::Completed {
            iterations: self.steps.len(),
            execution_time: total_time,
            final_metrics: self.collect_final_metrics(),
        };

        Ok(())
    }

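    /// Runs a single step. Currently a placeholder: "kmeans" and "dbscan"
    /// return canned JSON results instead of invoking the real algorithms,
    /// and any other algorithm name is rejected.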
    fn execute_step(&mut self, step: &TrainingStep) -> Result<serde_json::Value> {
        use serde_json::json;

        // Placeholder dispatch: each supported algorithm returns a canned
        // result rather than running a real clustering computation.
        let result = match step.algorithm.as_str() {
            "kmeans" => {
                json!({
                    "algorithm": "kmeans",
                    "centroids": [[0.0, 0.0], [1.0, 1.0]],
                    "inertia": 0.5,
                    "iterations": 10
                })
            }
            "dbscan" => {
                json!({
                    "algorithm": "dbscan",
                    "n_clusters": 2,
                    "core_samples": [0, 1, 2],
                    "noise_points": []
                })
            }
            _ => {
                return Err(ClusteringError::InvalidInput(format!(
                    "Unknown algorithm: {}",
                    step.algorithm
                )));
            }
        };

        self.intermediate_results
            .insert(step.name.clone(), result.clone());

        Ok(result)
    }

    fn check_dependencies(&self, dependencies: &[String]) -> Result<bool> {
        for dep in dependencies {
            if !self.steps.iter().any(|s| s.name == *dep && s.completed) {
                return Ok(false);
            }
        }
        Ok(true)
    }

    fn update_progress(&mut self, progress: f32) {
        if let AlgorithmState::Running {
            iteration,
            start_time,
            ..
        } = &mut self.current_state
        {
            self.current_state = AlgorithmState::Running {
                iteration: *iteration + 1,
                start_time: *start_time,
                progress,
            };
        }
    }

    fn record_execution(&mut self, step_name: &str, action: &str, result: ExecutionResult) {
        let record = ExecutionRecord {
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            step_name: step_name.to_string(),
            action: action.to_string(),
            result,
            metadata: HashMap::new(),
        };

        self.execution_history.push(record);
    }

    fn collect_final_metrics(&self) -> HashMap<String, f64> {
        let mut metrics = HashMap::new();

        let total_steps = self.steps.len() as f64;
        let completed_steps = self.steps.iter().filter(|s| s.completed).count() as f64;
        let total_time: f64 = self.steps.iter().filter_map(|s| s.execution_time).sum();

        // Guard against division by zero for an empty workflow.
        let completion_rate = if total_steps > 0.0 {
            completed_steps / total_steps
        } else {
            0.0
        };

        metrics.insert("total_steps".to_string(), total_steps);
        metrics.insert("completed_steps".to_string(), completed_steps);
        metrics.insert("completion_rate".to_string(), completion_rate);
        metrics.insert("total_execution_time".to_string(), total_time);

        metrics
    }

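    /// Serializes the workflow to
    /// `<checkpoint_dir>/<workflow_id>_checkpoint.json`. A no-op when no
    /// checkpoint directory is configured.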
    pub fn save_checkpoint(&self) -> Result<()> {
        if let Some(ref checkpoint_dir) = self.config.checkpoint_dir {
            std::fs::create_dir_all(checkpoint_dir)
                .map_err(|e| ClusteringError::InvalidInput(e.to_string()))?;

            let checkpoint_file =
                checkpoint_dir.join(format!("{}_checkpoint.json", self.workflow_id));
            self.save_to_file(checkpoint_file)?;
        }

        Ok(())
    }

    pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::load_from_file(path)
    }

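    /// Pauses the workflow, recording the current iteration (0 if it was not
    /// running) and the pause time.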
    pub fn pause(&mut self) {
        // If the workflow is not currently running, record iteration 0.
        let current_iteration = match &self.current_state {
            AlgorithmState::Running { iteration, .. } => *iteration,
            _ => 0,
        };

        self.current_state = AlgorithmState::Paused {
            pause_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            paused_at_iteration: current_iteration,
        };
    }

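    /// Resumes a paused workflow from the recorded iteration and executes the
    /// remaining steps. Returns an error if the workflow is not paused.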
    pub fn resume(&mut self) -> Result<()> {
        if let AlgorithmState::Paused {
            paused_at_iteration,
            ..
        } = &self.current_state
        {
            self.current_state = AlgorithmState::Running {
                iteration: *paused_at_iteration,
                start_time: SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
                progress: (*paused_at_iteration as f32 / self.steps.len() as f32) * 100.0,
            };

            self.execute_remaining_steps()
        } else {
            Err(ClusteringError::InvalidInput(
                "Workflow is not in paused state".to_string(),
            ))
        }
    }

    fn execute_remaining_steps(&mut self) -> Result<()> {
        let start_index = self.current_step;

        let steps_len = self.steps.len();
        for i in start_index..steps_len {
            if !self.steps[i].completed {
                self.current_step = i;
                let step_start = std::time::Instant::now();
                let step_clone = self.steps[i].clone();
                let result = self.execute_step(&step_clone)?;
                let step_duration = step_start.elapsed().as_secs_f64();

                self.steps[i].completed = true;
                self.steps[i].execution_time = Some(step_duration);
                self.steps[i].results = Some(result.clone());

                let step_name = self.steps[i].name.clone();
                self.record_execution(
                    &step_name,
                    "resume_execute",
                    ExecutionResult::Success {
                        duration: step_duration,
                        output: Some(result),
                    },
                );
            }
        }

        let final_metrics = self.collect_final_metrics();
        self.current_state = AlgorithmState::Completed {
            iterations: self.steps.len(),
            execution_time: final_metrics
                .get("total_execution_time")
                .copied()
                .unwrap_or(0.0),
            final_metrics,
        };

        Ok(())
    }

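    /// Returns overall progress as a percentage in `[0.0, 100.0]`.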
    pub fn get_progress(&self) -> f32 {
        match &self.current_state {
            AlgorithmState::Running { progress, .. } => *progress,
            AlgorithmState::Completed { .. } => 100.0,
            AlgorithmState::Failed { .. } => 0.0,
            // Guard against division by zero when no steps are registered.
            AlgorithmState::Paused {
                paused_at_iteration,
                ..
            } if !self.steps.is_empty() => {
                (*paused_at_iteration as f32 / self.steps.len() as f32) * 100.0
            }
            AlgorithmState::Paused { .. } => 0.0,
            AlgorithmState::NotStarted => 0.0,
        }
    }

    pub fn get_status(&self) -> WorkflowStatus {
        WorkflowStatus {
            workflow_id: self.workflow_id.clone(),
            current_step: self.current_step,
            total_steps: self.steps.len(),
            state: self.current_state.clone(),
            progress: self.get_progress(),
            completed_steps: self.steps.iter().filter(|s| s.completed).count(),
        }
    }
}

// `save_to_file` / `load_from_file` come from `SerializableModel`'s
// default methods.
impl SerializableModel for ClusteringWorkflow {}

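/// A point-in-time snapshot of a workflow's progress, suitable for reporting.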
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkflowStatus {
    pub workflow_id: String,
    pub current_step: usize,
    pub total_steps: usize,
    pub state: AlgorithmState,
    pub progress: f32,
    pub completed_steps: usize,
}

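/// Owns a set of workflows keyed by id and enforces a cap on how many may
/// exist concurrently.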
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClusteringWorkflowManager {
    pub workflows: HashMap<String, ClusteringWorkflow>,
    pub config: ManagerConfig,
}

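/// Manager-level settings for workflow capacity and auto-saving.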
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ManagerConfig {
    pub max_concurrent_workflows: usize,
    pub default_checkpoint_dir: Option<PathBuf>,
    pub global_auto_save_interval: Option<u64>,
}

impl Default for ManagerConfig {
    fn default() -> Self {
        Self {
            max_concurrent_workflows: 10,
            default_checkpoint_dir: None,
            global_auto_save_interval: Some(300), // seconds (five minutes)
        }
    }
}

impl ClusteringWorkflowManager {
    pub fn new(config: ManagerConfig) -> Self {
        Self {
            workflows: HashMap::new(),
            config,
        }
    }

    pub fn add_workflow(&mut self, workflow: ClusteringWorkflow) -> Result<()> {
        if self.workflows.len() >= self.config.max_concurrent_workflows {
            return Err(ClusteringError::InvalidInput(
                "Maximum number of concurrent workflows reached".to_string(),
            ));
        }

        self.workflows
            .insert(workflow.workflow_id.clone(), workflow);
        Ok(())
    }

    pub fn get_workflow(&self, workflow_id: &str) -> Option<&ClusteringWorkflow> {
        self.workflows.get(workflow_id)
    }

    pub fn get_workflow_mut(&mut self, workflow_id: &str) -> Option<&mut ClusteringWorkflow> {
        self.workflows.get_mut(workflow_id)
    }

    pub fn execute_workflow(&mut self, workflow_id: &str) -> Result<()> {
        if let Some(workflow) = self.workflows.get_mut(workflow_id) {
            workflow.execute()
        } else {
            Err(ClusteringError::InvalidInput(format!(
                "Workflow not found: {}",
                workflow_id
            )))
        }
    }

    pub fn get_all_statuses(&self) -> HashMap<String, WorkflowStatus> {
        self.workflows
            .iter()
            .map(|(id, workflow)| (id.clone(), workflow.get_status()))
            .collect()
    }

    pub fn cleanup_completed(&mut self) {
        self.workflows
            .retain(|_, workflow| !matches!(workflow.current_state, AlgorithmState::Completed { .. }));
    }
}

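/// Standalone auto-save settings: whether to save, how often (in seconds),
/// and where.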
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AutoSaveConfig {
    pub enabled: bool,
    pub interval_seconds: u64,
    pub save_directory: PathBuf,
}

impl Default for AutoSaveConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            interval_seconds: 300, // five minutes
            save_directory: PathBuf::from("./checkpoints"),
        }
    }
}

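/// A simplified workflow lifecycle state; unlike `AlgorithmState`, it carries
/// only an error message on failure.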
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum WorkflowState {
    Created,
    Running,
    Paused,
    Completed,
    Failed(String),
    Cancelled,
}

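/// Outcome of a step in a richer form than `ExecutionResult`: success carries
/// both an output value and named metrics.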
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StepResult {
    Success {
        output: serde_json::Value,
        metrics: HashMap<String, f64>,
    },
    Failure {
        error: String,
        details: Option<serde_json::Value>,
    },
    Skipped {
        reason: String,
    },
}

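/// A step description independent of execution state, unlike `TrainingStep`,
/// which also tracks completion and results.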
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkflowStep {
    pub name: String,
    pub step_type: String,
    pub parameters: HashMap<String, serde_json::Value>,
    pub dependencies: Vec<String>,
    pub expected_duration: Option<f64>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_creation() {
        let config = WorkflowConfig {
            auto_save_interval: Some(60),
            max_retries: 3,
            step_timeout: Some(300),
            parallel_execution: false,
            checkpoint_dir: None,
        };

        let workflow = ClusteringWorkflow::new("test_workflow".to_string(), config);
        assert_eq!(workflow.workflow_id, "test_workflow");
        assert_eq!(workflow.current_step, 0);
        assert!(workflow.steps.is_empty());
    }

    #[test]
    fn test_workflow_step_addition() {
        let config = WorkflowConfig {
            auto_save_interval: None,
            max_retries: 1,
            step_timeout: None,
            parallel_execution: false,
            checkpoint_dir: None,
        };

        let mut workflow = ClusteringWorkflow::new("test".to_string(), config);

        let step = TrainingStep {
            name: "kmeans_step".to_string(),
            algorithm: "kmeans".to_string(),
            parameters: HashMap::new(),
            dependencies: Vec::new(),
            completed: false,
            execution_time: None,
            results: None,
        };

        workflow.add_step(step);
        assert_eq!(workflow.steps.len(), 1);
        assert_eq!(workflow.steps[0].name, "kmeans_step");
    }

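    // A sketch of a full run through the stub algorithm handlers: a kmeans
    // step followed by a dbscan step that depends on it. The canned results
    // returned by `execute_step` make this deterministic.
    #[test]
    fn test_workflow_execution_with_dependencies() {
        let mut workflow =
            ClusteringWorkflow::new("exec_test".to_string(), WorkflowConfig::default());

        workflow.add_step(TrainingStep {
            name: "kmeans_step".to_string(),
            algorithm: "kmeans".to_string(),
            parameters: HashMap::new(),
            dependencies: Vec::new(),
            completed: false,
            execution_time: None,
            results: None,
        });
        workflow.add_step(TrainingStep {
            name: "dbscan_step".to_string(),
            algorithm: "dbscan".to_string(),
            parameters: HashMap::new(),
            dependencies: vec!["kmeans_step".to_string()],
            completed: false,
            execution_time: None,
            results: None,
        });

        workflow.execute().unwrap();

        assert!(matches!(
            workflow.current_state,
            AlgorithmState::Completed { .. }
        ));
        assert_eq!(workflow.get_progress(), 100.0);
        assert!(workflow.intermediate_results.contains_key("kmeans_step"));
        assert!(workflow.intermediate_results.contains_key("dbscan_step"));
    }
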
    #[test]
    fn test_workflow_manager() {
        let config = ManagerConfig::default();
        let mut manager = ClusteringWorkflowManager::new(config);

        let workflow_config = WorkflowConfig {
            auto_save_interval: None,
            max_retries: 1,
            step_timeout: None,
            parallel_execution: false,
            checkpoint_dir: None,
        };

        let workflow = ClusteringWorkflow::new("test_workflow".to_string(), workflow_config);
        manager.add_workflow(workflow).unwrap();

        assert!(manager.get_workflow("test_workflow").is_some());
        assert_eq!(manager.workflows.len(), 1);
    }
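
    // A sketch of the pause/resume bookkeeping: pausing a workflow that is
    // not running records iteration 0, and `resume` then executes every
    // remaining step to completion. Resuming a non-paused workflow errors.
    #[test]
    fn test_workflow_pause_and_resume() {
        let mut workflow =
            ClusteringWorkflow::new("pause_test".to_string(), WorkflowConfig::default());
        workflow.add_step(TrainingStep {
            name: "kmeans_step".to_string(),
            algorithm: "kmeans".to_string(),
            parameters: HashMap::new(),
            dependencies: Vec::new(),
            completed: false,
            execution_time: None,
            results: None,
        });

        workflow.pause();
        assert!(matches!(
            workflow.current_state,
            AlgorithmState::Paused { .. }
        ));

        workflow.resume().unwrap();
        assert!(matches!(
            workflow.current_state,
            AlgorithmState::Completed { .. }
        ));
        assert!(workflow.resume().is_err());
    }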
}