Skip to main content

somatize_core/
strategy.rs

//! Training strategies for distributed execution.
//!
//! A [`TrainingStrategy`] is a graph-level attribute that controls how the
//! scheduler distributes work across workers and how those workers coordinate
//! during training (gradient aggregation, state synchronization, communication).
//!
//! Subgraphs inherit the parent's strategy unless they override it.
8
9use crate::error::Result;
10use crate::filter::RemoteTarget;
11use crate::graph::NodeId;
12use crate::value::Value;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16/// Training strategy — graph-level attribute, inherited by subgraphs.
17#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18#[serde(tag = "type")]
19#[non_exhaustive]
20pub enum TrainingStrategy {
21    /// All nodes execute locally (default).
22    #[default]
23    Local,
24
25    /// Replicate the entire graph on N workers, each sees a data shard.
26    /// Gradients are aggregated after each step.
27    DataParallel {
28        num_replicas: usize,
29        aggregation: GradientAggregation,
30    },
31
32    /// Arbitrary model partitioning: each Partition maps a set of
33    /// node IDs to a worker target. Any topology is supported.
34    ModelParallel {
35        partitions: Vec<Partition>,
36        communication: CommunicationProtocol,
37    },
38
39    /// Federated learning: data stays on workers, only model updates
40    /// are shared. The coordinator aggregates after each round.
41    Federated {
42        num_clients: usize,
43        rounds: usize,
44        aggregation: FederatedAggregation,
45        client_selection: ClientSelection,
46    },
47
48    /// Population-Based Training: evolutionary hyperparameter optimization.
49    /// Each generation trains a population, evaluates, then evolves.
50    PopulationBased {
51        population_size: usize,
52        generations: usize,
53        exploit: ExploitStrategy,
54        explore: ExploreStrategy,
55    },
56
57    /// User-defined strategy with a registered coordinator.
58    Custom {
59        coordinator: String,
60        config: serde_json::Value,
61    },
62}
63
64/// How gradients are aggregated across workers in data-parallel training.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(tag = "method")]
67#[non_exhaustive]
68pub enum GradientAggregation {
69    /// All workers exchange gradients (ring or tree reduction).
70    AllReduce,
71    /// A central parameter server collects and distributes updates.
72    ParameterServer,
73    /// Decentralized gossip-based aggregation.
74    Decentralized { topology: String },
75}
76
77/// A partition maps a set of node IDs to a worker target.
78///
79/// Used in `ModelParallel` to define which nodes run on which worker.
80/// The user has full control over the partitioning.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct Partition {
83    pub node_ids: Vec<NodeId>,
84    pub target: RemoteTarget,
85}
86
87/// How model-parallel partitions communicate activations and gradients.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89#[serde(tag = "protocol")]
90#[non_exhaustive]
91pub enum CommunicationProtocol {
92    /// Intermediate values flow via DataStore (S3, shared disk).
93    DataStore,
94    /// Direct point-to-point streaming between workers.
95    Direct,
96    /// Pipeline parallelism with micro-batching for overlap.
97    Pipeline { micro_batch_size: usize },
98}
99
100/// Aggregation method for federated learning rounds.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102#[serde(tag = "method")]
103#[non_exhaustive]
104pub enum FederatedAggregation {
105    /// Federated Averaging: weighted mean of client updates.
106    FedAvg,
107    /// FedProx: adds proximal term to prevent client drift.
108    FedProx { mu: f64 },
109    /// FedYogi: adaptive federated optimization.
110    FedYogi { beta1: f64, beta2: f64, tau: f64 },
111}
112
113/// How clients are selected per federated round.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115#[serde(tag = "method")]
116#[non_exhaustive]
117pub enum ClientSelection {
118    /// All available clients participate.
119    All,
120    /// Random subset of clients.
121    Random { fraction: f64 },
122    /// Only clients matching specific tags.
123    ByCapability { required_tags: Vec<String> },
124}
125
126/// PBT exploit strategy: how underperformers learn from top performers.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128#[serde(tag = "method")]
129#[non_exhaustive]
130pub enum ExploitStrategy {
131    /// Bottom fraction copies weights+hyperparams from top fraction.
132    Truncation { fraction: f64 },
133    /// Each member is compared to a random other; loser copies winner.
134    Binary { threshold: f64 },
135}
136
137/// PBT explore strategy: how hyperparameters are mutated after exploit.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139#[serde(tag = "method")]
140#[non_exhaustive]
141pub enum ExploreStrategy {
142    /// Multiply each hyperparameter by a random factor in [1-factor, 1+factor].
143    Perturbation { factor: f64 },
144    /// Resample hyperparameters from the original search space.
145    Resample,
146}
147
148// ── Traits: execution contracts for strategies and aggregation ──
149
150/// Context provided to strategy executors.
151/// Abstracts worker communication — the strategy doesn't know about WS/HTTP.
152pub trait StrategyContext {
153    /// Number of available workers.
154    fn num_workers(&self) -> usize;
155
156    /// Execute a plan on a specific worker (by index). Returns trained states.
157    fn execute_on_worker(
158        &self,
159        worker_idx: usize,
160        plan: &serde_json::Value,
161        input: &Value,
162        y: Option<&Value>,
163    ) -> Result<HashMap<String, Value>>;
164
165    /// Get trained states from a worker.
166    fn get_state(&self, worker_idx: usize, node_ids: &[String]) -> Result<HashMap<String, Value>>;
167
168    /// Set states on a worker (e.g. after aggregation).
169    fn set_state(&self, worker_idx: usize, states: &HashMap<String, Value>) -> Result<()>;
170
171    /// Get gradients from a worker.
172    fn get_gradients(
173        &self,
174        worker_idx: usize,
175        node_ids: &[String],
176    ) -> Result<HashMap<String, Value>>;
177
178    /// Apply gradients on a worker.
179    fn apply_gradients(&self, worker_idx: usize, gradients: &HashMap<String, Value>) -> Result<()>;
180}
181
182/// Contract for training strategy execution.
183/// Every TrainingStrategy variant implements this — including Local.
184pub trait StrategyExecutor {
185    /// Train the model according to this strategy.
186    fn fit(
187        &self,
188        ctx: &dyn StrategyContext,
189        input: &Value,
190        y: Option<&Value>,
191        node_ids: &[String],
192    ) -> Result<HashMap<String, Value>>;
193}
194
195/// Contract for gradient aggregation across workers.
196pub trait GradientAggregator {
197    fn aggregate(&self, gradients: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>>;
198}
199
200/// Contract for federated state aggregation.
201pub trait StateAggregator {
202    fn aggregate(&self, states: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>>;
203}
204
205// ── Trait implementations ──
206
207impl StrategyExecutor for TrainingStrategy {
208    fn fit(
209        &self,
210        ctx: &dyn StrategyContext,
211        input: &Value,
212        y: Option<&Value>,
213        node_ids: &[String],
214    ) -> Result<HashMap<String, Value>> {
215        match self {
216            TrainingStrategy::Local => {
217                // Single worker, full dataset
218                ctx.execute_on_worker(0, &serde_json::json!({}), input, y)
219            }
220
221            TrainingStrategy::DataParallel {
222                num_replicas,
223                aggregation,
224            } => {
225                let n = (*num_replicas).min(ctx.num_workers());
226                let shards = shard_value(input, n);
227
228                // Fit on each worker with its shard
229                for (i, shard) in shards.iter().enumerate() {
230                    ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
231                }
232
233                // Collect and aggregate gradients
234                let mut all_grads = Vec::new();
235                for i in 0..n {
236                    all_grads.push(ctx.get_gradients(i, node_ids)?);
237                }
238                let averaged = aggregation.aggregate(&all_grads)?;
239
240                // Apply to all workers
241                for i in 0..n {
242                    ctx.apply_gradients(i, &averaged)?;
243                }
244
245                // Return states from first worker
246                ctx.get_state(0, node_ids)
247            }
248
249            TrainingStrategy::Federated {
250                num_clients,
251                rounds,
252                aggregation,
253                ..
254            } => {
255                let n = (*num_clients).min(ctx.num_workers());
256                let shards = shard_value(input, n);
257
258                for _round in 0..*rounds {
259                    // Each client trains on its shard
260                    for (i, shard) in shards.iter().enumerate().take(n) {
261                        ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
262                    }
263
264                    // Collect and aggregate states
265                    let mut all_states = Vec::new();
266                    for i in 0..n {
267                        all_states.push(ctx.get_state(i, node_ids)?);
268                    }
269                    let aggregated = aggregation.aggregate(&all_states)?;
270
271                    // Distribute back
272                    for i in 0..n {
273                        ctx.set_state(i, &aggregated)?;
274                    }
275                }
276
277                ctx.get_state(0, node_ids)
278            }
279
280            TrainingStrategy::ModelParallel { .. } => {
281                // TODO: forward/backward across partitions
282                Err(crate::error::SomaError::Other(
283                    "ModelParallel strategy execution not yet implemented".into(),
284                ))
285            }
286
287            TrainingStrategy::PopulationBased { .. } => {
288                // TODO: PBT cycle
289                Err(crate::error::SomaError::Other(
290                    "PopulationBased strategy execution not yet implemented".into(),
291                ))
292            }
293
294            TrainingStrategy::Custom { .. } => Err(crate::error::SomaError::Other(
295                "Custom strategy requires a user-provided coordinator".into(),
296            )),
297        }
298    }
299}
300
301impl GradientAggregator for GradientAggregation {
302    fn aggregate(&self, gradients: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
303        match self {
304            GradientAggregation::AllReduce | GradientAggregation::ParameterServer => {
305                // TODO: proper tensor averaging
306                Ok(gradients.first().cloned().unwrap_or_default())
307            }
308            GradientAggregation::Decentralized { .. } => {
309                Ok(gradients.first().cloned().unwrap_or_default())
310            }
311        }
312    }
313}
314
315impl StateAggregator for FederatedAggregation {
316    fn aggregate(&self, states: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
317        match self {
318            FederatedAggregation::FedAvg
319            | FederatedAggregation::FedProx { .. }
320            | FederatedAggregation::FedYogi { .. } => {
321                // TODO: proper tensor averaging
322                Ok(states.first().cloned().unwrap_or_default())
323            }
324        }
325    }
326}
327
328/// Split a Value::Tensor along the first dimension into N shards.
329fn shard_value(value: &Value, n: usize) -> Vec<Value> {
330    match value {
331        Value::Tensor { values, shape } if !shape.is_empty() && shape[0] >= n => {
332            let rows = shape[0];
333            let row_size: usize = shape[1..].iter().product::<usize>().max(1);
334            let shard_rows = rows / n;
335            let mut shards = Vec::new();
336            for i in 0..n {
337                let start = i * shard_rows;
338                let end = if i == n - 1 { rows } else { start + shard_rows };
339                let flat_start = start * row_size;
340                let flat_end = end * row_size;
341                let shard_vals = values[flat_start..flat_end].to_vec();
342                let mut shard_shape = shape.clone();
343                shard_shape[0] = end - start;
344                shards.push(Value::tensor(shard_vals, shard_shape));
345            }
346            shards
347        }
348        _ => (0..n).map(|_| value.clone()).collect(),
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a strategy to JSON and parses it straight back.
    fn roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let encoded = serde_json::to_string(strategy).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// The `Default` impl must pick the local strategy.
    #[test]
    fn default_is_local() {
        let strategy = TrainingStrategy::default();
        assert!(matches!(strategy, TrainingStrategy::Local));
    }

    /// A data-parallel strategy survives a JSON round trip.
    #[test]
    fn serde_roundtrip_data_parallel() {
        let parsed = roundtrip(&TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        });
        assert!(matches!(
            parsed,
            TrainingStrategy::DataParallel {
                num_replicas: 4,
                ..
            }
        ));
    }

    /// A model-parallel strategy with two partitions survives a JSON round trip.
    #[test]
    fn serde_roundtrip_model_parallel() {
        let trunk = Partition {
            node_ids: vec!["embed".into(), "backbone".into()],
            target: RemoteTarget::Tag("gpu-0".into()),
        };
        let head = Partition {
            node_ids: vec!["head_a".into()],
            target: RemoteTarget::Tag("gpu-1".into()),
        };
        let parsed = roundtrip(&TrainingStrategy::ModelParallel {
            partitions: vec![trunk, head],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        });
        assert!(matches!(parsed, TrainingStrategy::ModelParallel { .. }));
    }

    /// A federated strategy survives a JSON round trip.
    #[test]
    fn serde_roundtrip_federated() {
        let parsed = roundtrip(&TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        });
        assert!(matches!(
            parsed,
            TrainingStrategy::Federated {
                num_clients: 10,
                rounds: 50,
                ..
            }
        ));
    }

    /// A population-based strategy survives a JSON round trip.
    #[test]
    fn serde_roundtrip_pbt() {
        let parsed = roundtrip(&TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        });
        assert!(matches!(
            parsed,
            TrainingStrategy::PopulationBased {
                population_size: 20,
                ..
            }
        ));
    }
}