Skip to main content

somatize_core/
strategy.rs

//! Training strategies for distributed execution.
//!
//! A [`TrainingStrategy`] is a graph-level attribute that controls *how* the
//! Scheduler distributes work across workers and *how* workers coordinate
//! during training (gradient aggregation, state sync, communication).
//!
//! Subgraphs inherit the parent's strategy unless they override it.
8
9use crate::filter::RemoteTarget;
10use crate::graph::NodeId;
11use serde::{Deserialize, Serialize};
12
13/// Training strategy — graph-level attribute, inherited by subgraphs.
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15#[serde(tag = "type")]
16#[non_exhaustive]
17pub enum TrainingStrategy {
18    /// All nodes execute locally (default).
19    #[default]
20    Local,
21
22    /// Replicate the entire graph on N workers, each sees a data shard.
23    /// Gradients are aggregated after each step.
24    DataParallel {
25        num_replicas: usize,
26        aggregation: GradientAggregation,
27    },
28
29    /// Arbitrary model partitioning: each Partition maps a set of
30    /// node IDs to a worker target. Any topology is supported.
31    ModelParallel {
32        partitions: Vec<Partition>,
33        communication: CommunicationProtocol,
34    },
35
36    /// Federated learning: data stays on workers, only model updates
37    /// are shared. The coordinator aggregates after each round.
38    Federated {
39        num_clients: usize,
40        rounds: usize,
41        aggregation: FederatedAggregation,
42        client_selection: ClientSelection,
43    },
44
45    /// Population-Based Training: evolutionary hyperparameter optimization.
46    /// Each generation trains a population, evaluates, then evolves.
47    PopulationBased {
48        population_size: usize,
49        generations: usize,
50        exploit: ExploitStrategy,
51        explore: ExploreStrategy,
52    },
53
54    /// User-defined strategy with a registered coordinator.
55    Custom {
56        coordinator: String,
57        config: serde_json::Value,
58    },
59}
60
61/// How gradients are aggregated across workers in data-parallel training.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(tag = "method")]
64#[non_exhaustive]
65pub enum GradientAggregation {
66    /// All workers exchange gradients (ring or tree reduction).
67    AllReduce,
68    /// A central parameter server collects and distributes updates.
69    ParameterServer,
70    /// Decentralized gossip-based aggregation.
71    Decentralized { topology: String },
72}
73
74/// A partition maps a set of node IDs to a worker target.
75///
76/// Used in `ModelParallel` to define which nodes run on which worker.
77/// The user has full control over the partitioning.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Partition {
80    pub node_ids: Vec<NodeId>,
81    pub target: RemoteTarget,
82}
83
84/// How model-parallel partitions communicate activations and gradients.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(tag = "protocol")]
87#[non_exhaustive]
88pub enum CommunicationProtocol {
89    /// Intermediate values flow via DataStore (S3, shared disk).
90    DataStore,
91    /// Direct point-to-point streaming between workers.
92    Direct,
93    /// Pipeline parallelism with micro-batching for overlap.
94    Pipeline { micro_batch_size: usize },
95}
96
97/// Aggregation method for federated learning rounds.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99#[serde(tag = "method")]
100#[non_exhaustive]
101pub enum FederatedAggregation {
102    /// Federated Averaging: weighted mean of client updates.
103    FedAvg,
104    /// FedProx: adds proximal term to prevent client drift.
105    FedProx { mu: f64 },
106    /// FedYogi: adaptive federated optimization.
107    FedYogi { beta1: f64, beta2: f64, tau: f64 },
108}
109
110/// How clients are selected per federated round.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112#[serde(tag = "method")]
113#[non_exhaustive]
114pub enum ClientSelection {
115    /// All available clients participate.
116    All,
117    /// Random subset of clients.
118    Random { fraction: f64 },
119    /// Only clients matching specific tags.
120    ByCapability { required_tags: Vec<String> },
121}
122
123/// PBT exploit strategy: how underperformers learn from top performers.
124#[derive(Debug, Clone, Serialize, Deserialize)]
125#[serde(tag = "method")]
126#[non_exhaustive]
127pub enum ExploitStrategy {
128    /// Bottom fraction copies weights+hyperparams from top fraction.
129    Truncation { fraction: f64 },
130    /// Each member is compared to a random other; loser copies winner.
131    Binary { threshold: f64 },
132}
133
134/// PBT explore strategy: how hyperparameters are mutated after exploit.
135#[derive(Debug, Clone, Serialize, Deserialize)]
136#[serde(tag = "method")]
137#[non_exhaustive]
138pub enum ExploreStrategy {
139    /// Multiply each hyperparameter by a random factor in [1-factor, 1+factor].
140    Perturbation { factor: f64 },
141    /// Resample hyperparameters from the original search space.
142    Resample,
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing `f64` fields after a JSON round trip.
    const EPS: f64 = 1e-12;

    /// Serializes `strategy` to a JSON string and parses it back.
    fn roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let json = serde_json::to_string(strategy).expect("strategy should serialize");
        serde_json::from_str(&json).expect("serialized strategy should deserialize")
    }

    #[test]
    fn default_is_local() {
        assert!(matches!(
            TrainingStrategy::default(),
            TrainingStrategy::Local
        ));
    }

    /// Pins the outer wire format: the variant name lives under `"type"`.
    #[test]
    fn local_wire_format_is_tagged_by_type() {
        let value = serde_json::to_value(TrainingStrategy::Local).unwrap();
        assert_eq!(value, serde_json::json!({ "type": "Local" }));
    }

    /// Pins the nested wire format: inner enums are tagged by `"method"`.
    #[test]
    fn data_parallel_wire_format_nests_method_tag() {
        let strategy = TrainingStrategy::DataParallel {
            num_replicas: 2,
            aggregation: GradientAggregation::Decentralized {
                topology: "ring".into(),
            },
        };
        let value = serde_json::to_value(&strategy).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "DataParallel",
                "num_replicas": 2,
                "aggregation": { "method": "Decentralized", "topology": "ring" }
            })
        );
    }

    #[test]
    fn deserializes_hand_written_json() {
        let json = r#"{"type":"DataParallel","num_replicas":8,"aggregation":{"method":"ParameterServer"}}"#;
        let parsed: TrainingStrategy = serde_json::from_str(json).unwrap();
        assert!(matches!(
            parsed,
            TrainingStrategy::DataParallel {
                num_replicas: 8,
                aggregation: GradientAggregation::ParameterServer,
            }
        ));
    }

    #[test]
    fn rejects_unknown_strategy_type() {
        assert!(serde_json::from_str::<TrainingStrategy>(r#"{"type":"Bogus"}"#).is_err());
    }

    #[test]
    fn serde_roundtrip_data_parallel() {
        let strategy = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        assert!(matches!(
            roundtrip(&strategy),
            TrainingStrategy::DataParallel {
                num_replicas: 4,
                aggregation: GradientAggregation::AllReduce,
            }
        ));
    }

    #[test]
    fn serde_roundtrip_model_parallel() {
        let strategy = TrainingStrategy::ModelParallel {
            partitions: vec![
                Partition {
                    node_ids: vec!["embed".into(), "backbone".into()],
                    target: RemoteTarget::Tag("gpu-0".into()),
                },
                Partition {
                    node_ids: vec!["head_a".into()],
                    target: RemoteTarget::Tag("gpu-1".into()),
                },
            ],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::ModelParallel {
                partitions,
                communication,
            } => {
                assert_eq!(partitions.len(), 2);
                assert_eq!(partitions[0].node_ids.len(), 2);
                assert_eq!(partitions[1].node_ids.len(), 1);
                assert!(partitions
                    .iter()
                    .all(|p| matches!(p.target, RemoteTarget::Tag(_))));
                assert!(matches!(
                    communication,
                    CommunicationProtocol::Pipeline {
                        micro_batch_size: 4
                    }
                ));
            }
            other => panic!("expected ModelParallel, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_federated() {
        let strategy = TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Federated {
                num_clients,
                rounds,
                aggregation,
                client_selection,
            } => {
                assert_eq!(num_clients, 10);
                assert_eq!(rounds, 50);
                assert!(matches!(
                    aggregation,
                    FederatedAggregation::FedProx { mu } if (mu - 0.01).abs() < EPS
                ));
                assert!(matches!(
                    client_selection,
                    ClientSelection::Random { fraction } if (fraction - 0.3).abs() < EPS
                ));
            }
            other => panic!("expected Federated, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_federated_yogi_by_capability() {
        let strategy = TrainingStrategy::Federated {
            num_clients: 3,
            rounds: 1,
            aggregation: FederatedAggregation::FedYogi {
                beta1: 0.9,
                beta2: 0.99,
                tau: 0.001,
            },
            client_selection: ClientSelection::ByCapability {
                required_tags: vec!["gpu".into(), "eu".into()],
            },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Federated {
                aggregation,
                client_selection,
                ..
            } => {
                assert!(matches!(
                    aggregation,
                    FederatedAggregation::FedYogi { beta1, beta2, tau }
                        if (beta1 - 0.9).abs() < EPS
                            && (beta2 - 0.99).abs() < EPS
                            && (tau - 0.001).abs() < EPS
                ));
                match client_selection {
                    ClientSelection::ByCapability { required_tags } => {
                        assert_eq!(required_tags, ["gpu", "eu"]);
                    }
                    other => panic!("expected ByCapability, got {other:?}"),
                }
            }
            other => panic!("expected Federated, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_pbt() {
        let strategy = TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        };
        match roundtrip(&strategy) {
            TrainingStrategy::PopulationBased {
                population_size,
                generations,
                exploit,
                explore,
            } => {
                assert_eq!(population_size, 20);
                assert_eq!(generations, 50);
                assert!(matches!(
                    exploit,
                    ExploitStrategy::Truncation { fraction } if (fraction - 0.2).abs() < EPS
                ));
                assert!(matches!(
                    explore,
                    ExploreStrategy::Perturbation { factor } if (factor - 0.2).abs() < EPS
                ));
            }
            other => panic!("expected PopulationBased, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_custom() {
        // Integers and strings only, so the config survives the round trip exactly.
        let config = serde_json::json!({ "shards": 3, "workers": ["a", "b"] });
        let strategy = TrainingStrategy::Custom {
            coordinator: "my_coordinator".into(),
            config: config.clone(),
        };
        match roundtrip(&strategy) {
            TrainingStrategy::Custom {
                coordinator,
                config: parsed,
            } => {
                assert_eq!(coordinator, "my_coordinator");
                assert_eq!(parsed, config);
            }
            other => panic!("expected Custom, got {other:?}"),
        }
    }
}