1use crate::error::Result;
10use crate::filter::RemoteTarget;
11use crate::graph::NodeId;
12use crate::value::Value;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18#[serde(tag = "type")]
19#[non_exhaustive]
20pub enum TrainingStrategy {
21 #[default]
23 Local,
24
25 DataParallel {
28 num_replicas: usize,
29 aggregation: GradientAggregation,
30 },
31
32 ModelParallel {
35 partitions: Vec<Partition>,
36 communication: CommunicationProtocol,
37 },
38
39 Federated {
42 num_clients: usize,
43 rounds: usize,
44 aggregation: FederatedAggregation,
45 client_selection: ClientSelection,
46 },
47
48 PopulationBased {
51 population_size: usize,
52 generations: usize,
53 exploit: ExploitStrategy,
54 explore: ExploreStrategy,
55 },
56
57 Custom {
59 coordinator: String,
60 config: serde_json::Value,
61 },
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(tag = "method")]
67#[non_exhaustive]
68pub enum GradientAggregation {
69 AllReduce,
71 ParameterServer,
73 Decentralized { topology: String },
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct Partition {
83 pub node_ids: Vec<NodeId>,
84 pub target: RemoteTarget,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89#[serde(tag = "protocol")]
90#[non_exhaustive]
91pub enum CommunicationProtocol {
92 DataStore,
94 Direct,
96 Pipeline { micro_batch_size: usize },
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102#[serde(tag = "method")]
103#[non_exhaustive]
104pub enum FederatedAggregation {
105 FedAvg,
107 FedProx { mu: f64 },
109 FedYogi { beta1: f64, beta2: f64, tau: f64 },
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115#[serde(tag = "method")]
116#[non_exhaustive]
117pub enum ClientSelection {
118 All,
120 Random { fraction: f64 },
122 ByCapability { required_tags: Vec<String> },
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128#[serde(tag = "method")]
129#[non_exhaustive]
130pub enum ExploitStrategy {
131 Truncation { fraction: f64 },
133 Binary { threshold: f64 },
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139#[serde(tag = "method")]
140#[non_exhaustive]
141pub enum ExploreStrategy {
142 Perturbation { factor: f64 },
144 Resample,
146}
147
148pub trait StrategyContext {
153 fn num_workers(&self) -> usize;
155
156 fn execute_on_worker(
158 &self,
159 worker_idx: usize,
160 plan: &serde_json::Value,
161 input: &Value,
162 y: Option<&Value>,
163 ) -> Result<HashMap<String, Value>>;
164
165 fn get_state(&self, worker_idx: usize, node_ids: &[String]) -> Result<HashMap<String, Value>>;
167
168 fn set_state(&self, worker_idx: usize, states: &HashMap<String, Value>) -> Result<()>;
170
171 fn get_gradients(
173 &self,
174 worker_idx: usize,
175 node_ids: &[String],
176 ) -> Result<HashMap<String, Value>>;
177
178 fn apply_gradients(&self, worker_idx: usize, gradients: &HashMap<String, Value>) -> Result<()>;
180}
181
182pub trait StrategyExecutor {
185 fn fit(
187 &self,
188 ctx: &dyn StrategyContext,
189 input: &Value,
190 y: Option<&Value>,
191 node_ids: &[String],
192 ) -> Result<HashMap<String, Value>>;
193}
194
195pub trait GradientAggregator {
197 fn aggregate(&self, gradients: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>>;
198}
199
200pub trait StateAggregator {
202 fn aggregate(&self, states: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>>;
203}
204
205impl StrategyExecutor for TrainingStrategy {
208 fn fit(
209 &self,
210 ctx: &dyn StrategyContext,
211 input: &Value,
212 y: Option<&Value>,
213 node_ids: &[String],
214 ) -> Result<HashMap<String, Value>> {
215 match self {
216 TrainingStrategy::Local => {
217 ctx.execute_on_worker(0, &serde_json::json!({}), input, y)
219 }
220
221 TrainingStrategy::DataParallel {
222 num_replicas,
223 aggregation,
224 } => {
225 let n = (*num_replicas).min(ctx.num_workers());
226 let shards = shard_value(input, n);
227
228 for (i, shard) in shards.iter().enumerate() {
230 ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
231 }
232
233 let mut all_grads = Vec::new();
235 for i in 0..n {
236 all_grads.push(ctx.get_gradients(i, node_ids)?);
237 }
238 let averaged = aggregation.aggregate(&all_grads)?;
239
240 for i in 0..n {
242 ctx.apply_gradients(i, &averaged)?;
243 }
244
245 ctx.get_state(0, node_ids)
247 }
248
249 TrainingStrategy::Federated {
250 num_clients,
251 rounds,
252 aggregation,
253 ..
254 } => {
255 let n = (*num_clients).min(ctx.num_workers());
256 let shards = shard_value(input, n);
257
258 for _round in 0..*rounds {
259 for (i, shard) in shards.iter().enumerate().take(n) {
261 ctx.execute_on_worker(i, &serde_json::json!({}), shard, y)?;
262 }
263
264 let mut all_states = Vec::new();
266 for i in 0..n {
267 all_states.push(ctx.get_state(i, node_ids)?);
268 }
269 let aggregated = aggregation.aggregate(&all_states)?;
270
271 for i in 0..n {
273 ctx.set_state(i, &aggregated)?;
274 }
275 }
276
277 ctx.get_state(0, node_ids)
278 }
279
280 TrainingStrategy::ModelParallel { .. } => {
281 Err(crate::error::SomaError::Other(
283 "ModelParallel strategy execution not yet implemented".into(),
284 ))
285 }
286
287 TrainingStrategy::PopulationBased { .. } => {
288 Err(crate::error::SomaError::Other(
290 "PopulationBased strategy execution not yet implemented".into(),
291 ))
292 }
293
294 TrainingStrategy::Custom { .. } => Err(crate::error::SomaError::Other(
295 "Custom strategy requires a user-provided coordinator".into(),
296 )),
297 }
298 }
299}
300
301impl GradientAggregator for GradientAggregation {
302 fn aggregate(&self, gradients: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
303 match self {
304 GradientAggregation::AllReduce | GradientAggregation::ParameterServer => {
305 Ok(gradients.first().cloned().unwrap_or_default())
307 }
308 GradientAggregation::Decentralized { .. } => {
309 Ok(gradients.first().cloned().unwrap_or_default())
310 }
311 }
312 }
313}
314
315impl StateAggregator for FederatedAggregation {
316 fn aggregate(&self, states: &[HashMap<String, Value>]) -> Result<HashMap<String, Value>> {
317 match self {
318 FederatedAggregation::FedAvg
319 | FederatedAggregation::FedProx { .. }
320 | FederatedAggregation::FedYogi { .. } => {
321 Ok(states.first().cloned().unwrap_or_default())
323 }
324 }
325 }
326}
327
328fn shard_value(value: &Value, n: usize) -> Vec<Value> {
330 match value {
331 Value::Tensor { values, shape } if !shape.is_empty() && shape[0] >= n => {
332 let rows = shape[0];
333 let row_size: usize = shape[1..].iter().product::<usize>().max(1);
334 let shard_rows = rows / n;
335 let mut shards = Vec::new();
336 for i in 0..n {
337 let start = i * shard_rows;
338 let end = if i == n - 1 { rows } else { start + shard_rows };
339 let flat_start = start * row_size;
340 let flat_end = end * row_size;
341 let shard_vals = values[flat_start..flat_end].to_vec();
342 let mut shard_shape = shape.clone();
343 shard_shape[0] = end - start;
344 shards.push(Value::tensor(shard_vals, shard_shape));
345 }
346 shards
347 }
348 _ => (0..n).map(|_| value.clone()).collect(),
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `strategy` as JSON and decodes it back.
    fn roundtrip(strategy: &TrainingStrategy) -> TrainingStrategy {
        let encoded = serde_json::to_string(strategy).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn default_is_local() {
        let strategy = TrainingStrategy::default();
        assert!(matches!(strategy, TrainingStrategy::Local));
    }

    #[test]
    fn serde_roundtrip_data_parallel() {
        let original = TrainingStrategy::DataParallel {
            num_replicas: 4,
            aggregation: GradientAggregation::AllReduce,
        };
        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::DataParallel { num_replicas: 4, .. }
        ));
    }

    #[test]
    fn serde_roundtrip_model_parallel() {
        let first = Partition {
            node_ids: vec!["embed".into(), "backbone".into()],
            target: RemoteTarget::Tag("gpu-0".into()),
        };
        let second = Partition {
            node_ids: vec!["head_a".into()],
            target: RemoteTarget::Tag("gpu-1".into()),
        };
        let original = TrainingStrategy::ModelParallel {
            partitions: vec![first, second],
            communication: CommunicationProtocol::Pipeline {
                micro_batch_size: 4,
            },
        };
        let decoded = roundtrip(&original);
        assert!(matches!(decoded, TrainingStrategy::ModelParallel { .. }));
    }

    #[test]
    fn serde_roundtrip_federated() {
        let original = TrainingStrategy::Federated {
            num_clients: 10,
            rounds: 50,
            aggregation: FederatedAggregation::FedProx { mu: 0.01 },
            client_selection: ClientSelection::Random { fraction: 0.3 },
        };
        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::Federated {
                num_clients: 10,
                rounds: 50,
                ..
            }
        ));
    }

    #[test]
    fn serde_roundtrip_pbt() {
        let original = TrainingStrategy::PopulationBased {
            population_size: 20,
            generations: 50,
            exploit: ExploitStrategy::Truncation { fraction: 0.2 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
        };
        let decoded = roundtrip(&original);
        assert!(matches!(
            decoded,
            TrainingStrategy::PopulationBased {
                population_size: 20,
                ..
            }
        ));
    }
}