rust_lstm_1025/
lib.rs

1use bmi_rs::bmi::{Bmi, BmiResult, Location, RefValues, ValueType, Values, register_model};
2use burn::nn::LstmState;
3use burn::prelude::*;
4use burn::record::{FullPrecisionSettings, Recorder};
5use burn_import::pytorch::PyTorchFileRecorder;
6use glob::glob;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10
11use std::path::Path;
12
13mod nextgen_lstm;
14mod python;
15use nextgen_lstm::{NextgenLstm, vec_to_tensor};
16use python::convert_model;
17
/// Shape and naming metadata for one converted model.
///
/// Deserialized from the `.json` file that sits next to the converted weight file.
#[derive(Debug, Serialize, Deserialize)]
struct ModelMetadata {
    /// Number of input features; passed to `NextgenLstm::init` and used as the last
    /// dimension of the per-step input tensor.
    input_size: usize,
    /// Hidden size passed to `NextgenLstm::init`.
    hidden_size: usize,
    /// Output size passed to `NextgenLstm::init`.
    output_size: usize,
    /// Model-internal input names, in the order in which input values are gathered for
    /// each step. They are translated to BMI names before the variable lookup.
    input_names: Vec<String>,
    /// Output names stored with the model. NOTE(review): not read by any code visible in
    /// this file.
    output_names: Vec<String>,
}
26
/// Normalization constants for one model.
///
/// Deserialized from `train_data_scaler.json` in the model's `burn` directory.
#[derive(Debug, Serialize, Deserialize)]
struct TrainingScalars {
    /// Per-input mean subtracted from each raw input before the forward pass.
    input_mean: Vec<f32>,
    /// Per-input standard deviation the centered input is divided by; an entry of exactly
    /// `0.0` makes the scaled input `0.0` instead of dividing.
    input_std: Vec<f32>,
    /// Added to the model output after it has been multiplied by `output_std`.
    output_mean: f32,
    /// Multiplier applied to the raw model output when denormalizing.
    output_std: f32,
}
34
/// Everything needed to run one member of the ensemble.
struct ModelInstance<B: Backend> {
    /// The network with its loaded weights.
    model: NextgenLstm<B>,
    /// Sizes and input/output names loaded alongside the weights.
    metadata: ModelMetadata,
    /// Input/output normalization constants.
    scalars: TrainingScalars,
    /// Recurrent state returned by the previous forward pass; `None` until the first step.
    lstm_state: Option<LstmState<B, 2>>,
}
42
/// Looks up a variable by name in a fixed table of `pattern => result` arms.
///
/// Expands to a `match` on the first argument. Any value that is not covered by one of the
/// supplied arms yields `Err` holding a boxed `std::io::Error` of kind `NotFound` with the
/// message `Variable <name> not found`. A trailing comma after the last arm is accepted.
macro_rules! match_var {
    ($key:expr, $($known:pat => $outcome:expr),+ $(,)?) => {
        match $key {
            $($known => $outcome,)+
            _ => Err(Box::new(::std::io::Error::new(
                ::std::io::ErrorKind::NotFound,
                format!("Variable {} not found", $key),
            ))),
        }
    };
}
55
/// BMI-facing wrapper around an ensemble of LSTM models.
pub struct LstmBmi<B: Backend> {
    /// Ensemble members, filled in by `initialize`.
    models: Vec<ModelInstance<B>>,
    /// Device used when loading weights and when building input tensors.
    device: B::Device,

    /// Path of the configuration file most recently passed to `initialize`.
    config_path: String,
    /// Value of the `area_sqkm` key read from the configuration file.
    area_sqkm: f32,
    /// Factor the ensemble-mean runoff is multiplied by to obtain
    /// `land_surface_water__runoff_volume_flux`; computed in `initialize` as
    /// `(1/1000) * (area_sqkm * 1000 * 1000) * (1/3600)`.
    output_scale_factor_cms: f32,

    /// Current values of every variable, keyed by BMI name. Also holds the two basin
    /// attributes that are not listed in `input_var_names`.
    variables: HashMap<String, Vec<f64>>,

    /// Names reported through `get_input_var_names`.
    input_var_names: Vec<&'static str>,
    /// Names reported through `get_output_var_names`.
    output_var_names: Vec<&'static str>,

    /// Model time; advanced by `time_step` on every `update`.
    current_time: f64,
    /// Time that `current_time` is reset to by `initialize`.
    start_time: f64,
    /// Value reported by `get_end_time`; set once in `new`.
    end_time: f64,
    /// Increment applied to `current_time` per update.
    time_step: f64,
}
79
80impl<B: Backend> LstmBmi<B> {
81    pub fn new(device: B::Device) -> Self {
82        // Define all BMI variable names
83        let input_vars = vec![
84            "atmosphere_water__liquid_equivalent_precipitation_rate",
85            "land_surface_air__temperature",
86            "land_surface_radiation~incoming~longwave__energy_flux",
87            "land_surface_air__pressure",
88            "atmosphere_air_water~vapor__relative_saturation",
89            "land_surface_radiation~incoming~shortwave__energy_flux",
90            "land_surface_wind__x_component_of_velocity",
91            "land_surface_wind__y_component_of_velocity",
92            // "basin__mean_of_elevation",
93            // "basin__mean_of_slope",
94        ];
95
96        let output_vars = vec![
97            "land_surface_water__runoff_volume_flux",
98            "land_surface_water__runoff_depth",
99        ];
100
101        // Initialize variables HashMap
102        let mut variables = HashMap::new();
103        for var in &input_vars {
104            variables.insert(var.to_string(), vec![0.0]);
105        }
106        // add variables here so bmi can't see them
107        variables.insert("basin__mean_of_elevation".to_string(), vec![0.0]);
108        variables.insert("basin__mean_of_slope".to_string(), vec![0.0]);
109
110        for var in &output_vars {
111            variables.insert(var.to_string(), vec![0.0]);
112        }
113
114        LstmBmi {
115            models: Vec::new(),
116            device,
117            config_path: String::new(),
118            area_sqkm: 0.0,
119            output_scale_factor_cms: 0.0,
120            variables,
121            input_var_names: input_vars,
122            output_var_names: output_vars,
123            current_time: 0.0,
124            start_time: 0.0,
125            end_time: 36000.0, // 10 hours default
126            time_step: 3600.0, // 1 hour
127        }
128    }
129
130    fn internal_to_external_name(&self, internal: &str) -> String {
131        let mapping = [
132            (
133                "DLWRF_surface",
134                "land_surface_radiation~incoming~longwave__energy_flux",
135            ),
136            ("PRES_surface", "land_surface_air__pressure"),
137            (
138                "SPFH_2maboveground",
139                "atmosphere_air_water~vapor__relative_saturation",
140            ),
141            (
142                "APCP_surface",
143                "atmosphere_water__liquid_equivalent_precipitation_rate",
144            ),
145            (
146                "DSWRF_surface",
147                "land_surface_radiation~incoming~shortwave__energy_flux",
148            ),
149            ("TMP_2maboveground", "land_surface_air__temperature"),
150            (
151                "UGRD_10maboveground",
152                "land_surface_wind__x_component_of_velocity",
153            ),
154            (
155                "VGRD_10maboveground",
156                "land_surface_wind__y_component_of_velocity",
157            ),
158            ("elev_mean", "basin__mean_of_elevation"),
159            ("slope_mean", "basin__mean_of_slope"),
160        ];
161
162        mapping
163            .iter()
164            .find(|(k, _)| *k == internal)
165            .map(|(_, v)| v.to_string())
166            .unwrap_or_else(|| internal.to_string())
167    }
168
169    fn load_single_model(&self, training_config_path: &Path) -> BmiResult<ModelInstance<B>> {
170        let training_config = fs::read_to_string(training_config_path)?;
171        let training_config: serde_yaml::Value = serde_yaml::from_str(&training_config)?;
172
173        // Find model path
174        let model_dir = training_config["run_dir"]
175            .as_str()
176            .ok_or("Missing run_dir")?
177            .replace(
178                "..",
179                training_config_path
180                    .parent()
181                    .unwrap()
182                    .parent()
183                    .unwrap()
184                    .parent()
185                    .unwrap()
186                    .to_str()
187                    .unwrap(),
188            );
189
190        let model_path = glob(&format!("{}/model_*.pt", model_dir))?
191            .next()
192            .ok_or("No model file found")??;
193
194        // Convert weights if needed
195        let model_folder = model_path.parent().unwrap();
196        let burn_dir = model_folder.join("burn");
197        let converted_path = burn_dir.join(model_path.file_name().unwrap());
198        let lock_file_path = burn_dir.join(".conversion.lock");
199
200        // Create burn directory if it doesn't exist
201        if !burn_dir.exists() {
202            fs::create_dir_all(&burn_dir)?;
203        }
204
205        // Check if conversion is needed
206        let needs_conversion = !converted_path.exists()
207            || !converted_path.with_extension("json").exists()
208            || !burn_dir.join("train_data_scaler.json").exists()
209            || !burn_dir.join("weights.json").exists();
210
211        if needs_conversion {
212            // Try to acquire lock
213            let mut lock_acquired = false;
214            let process_id = std::process::id();
215
216            loop {
217                // Try to create lock file atomically
218                match fs::OpenOptions::new()
219                    .write(true)
220                    .create_new(true)
221                    .open(&lock_file_path)
222                {
223                    Ok(mut file) => {
224                        // Write process ID to lock file for debugging
225                        use std::io::Write;
226                        writeln!(file, "Locked by process {}", process_id)?;
227                        lock_acquired = true;
228                        // println!("Process {} acquired conversion lock", process_id);
229                        break;
230                    }
231                    Err(_) => {
232                        // Lock file exists, another process is converting
233                        // println!(
234                        //     "Process {} waiting for model conversion (lock held by another process)...",
235                        //     process_id
236                        // );
237                        std::thread::sleep(std::time::Duration::from_millis(1000));
238
239                        // Check if conversion is complete
240                        if converted_path.exists()
241                            && converted_path.with_extension("json").exists()
242                            && burn_dir.join("train_data_scaler.json").exists()
243                            && burn_dir.join("weights.json").exists()
244                        {
245                            // println!("Process {} found completed conversion", process_id);
246                            break;
247                        }
248                    }
249                }
250            }
251
252            // If we acquired the lock, do the conversion
253            if lock_acquired {
254                println!(
255                    "Process {} converting PyTorch weights to Burn format for model: {}",
256                    process_id,
257                    model_path.display()
258                );
259
260                // Perform conversion
261                match convert_model(&model_path, &training_config_path) {
262                    Ok(_) => {
263                        println!("Process {} completed model conversion", process_id);
264                    }
265                    Err(e) => {
266                        // Clean up lock file on error
267                        let _ = fs::remove_file(&lock_file_path);
268                        return Err(Box::new(std::io::Error::new(
269                            std::io::ErrorKind::Other,
270                            format!("Model conversion failed: {}", e),
271                        )));
272                    }
273                }
274
275                // Remove lock file after successful conversion
276                fs::remove_file(&lock_file_path)?;
277                println!("Process {} released conversion lock", process_id);
278            }
279        } else {
280            // println!("Model already converted, skipping conversion");
281        }
282
283        // Load metadata
284        let metadata_str = fs::read_to_string(converted_path.with_extension("json"))?;
285        let metadata: ModelMetadata = serde_json::from_str(&metadata_str)?;
286
287        // Load scalars
288        let scalars_str = fs::read_to_string(burn_dir.join("train_data_scaler.json"))?;
289        let scalars: TrainingScalars = serde_json::from_str(&scalars_str)?;
290
291        // Load model
292        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
293            .load(converted_path.into(), &self.device)?;
294
295        let mut model = NextgenLstm::init(
296            &self.device,
297            metadata.input_size,
298            metadata.hidden_size,
299            metadata.output_size,
300        );
301        model = model.load_record(record);
302        model.load_json_weights(
303            &self.device,
304            burn_dir.join("weights.json").to_str().unwrap(),
305        );
306
307        Ok(ModelInstance {
308            model,
309            metadata,
310            scalars,
311            lstm_state: None,
312        })
313    }
314
315    fn run_single_model(&mut self, model_idx: usize, inputs: &[f32]) -> BmiResult<f32> {
316        let model_instance = &mut self.models[model_idx];
317
318        // Scale inputs
319        let scaled_inputs: Vec<f32> = inputs
320            .iter()
321            .zip(&model_instance.scalars.input_mean)
322            .zip(&model_instance.scalars.input_std)
323            .map(
324                |((val, mean), std)| {
325                    if *std != 0.0 { (val - mean) / std } else { 0.0 }
326                },
327            )
328            .collect();
329
330        // Create input tensor
331        let input_tensor_data = vec_to_tensor(
332            &scaled_inputs,
333            vec![1, 1, model_instance.metadata.input_size],
334        );
335        let input_tensor = Tensor::from_data(input_tensor_data, &self.device);
336
337        // Forward pass
338        let (output, new_state) = model_instance
339            .model
340            .forward(input_tensor, model_instance.lstm_state.take());
341        model_instance.lstm_state = Some(new_state);
342
343        // Process output
344        let output_vec: Vec<f32> = output.into_data().to_vec().unwrap();
345        let output_value = output_vec[0];
346
347        // Denormalize
348        let surface_runoff_mm = (output_value * model_instance.scalars.output_std
349            + model_instance.scalars.output_mean)
350            .max(0.0);
351
352        Ok(surface_runoff_mm)
353    }
354
355    fn run_ensemble(&mut self) -> BmiResult<()> {
356        if self.models.is_empty() {
357            return Err("No models in ensemble")?;
358        }
359        // Run all models and collect outputs
360        let mut ensemble_outputs = Vec::new();
361        for i in 0..self.models.len() {
362            let input_names = self.models[i].metadata.input_names.clone();
363            // Gather inputs in the correct order
364            let mut inputs = Vec::new();
365            for name in &input_names {
366                let bmi_name = self.internal_to_external_name(name);
367                let value = self
368                    .variables
369                    .get(&bmi_name)
370                    .and_then(|v| v.first())
371                    .copied()
372                    .unwrap_or(0.0) as f32;
373                inputs.push(value);
374            }
375            let output = self.run_single_model(i, &inputs)?;
376            ensemble_outputs.push(output);
377        }
378
379        // Calculate mean of ensemble outputs
380        let mean_surface_runoff_mm = if !ensemble_outputs.is_empty() {
381            ensemble_outputs.iter().sum::<f32>() / ensemble_outputs.len() as f32
382        } else {
383            0.0
384        };
385
386        // Convert to output units
387        let surface_runoff_m = mean_surface_runoff_mm / 1000.0;
388        let surface_runoff_volume_m3_s = mean_surface_runoff_mm * self.output_scale_factor_cms;
389
390        // Set outputs
391        self.variables.insert(
392            "land_surface_water__runoff_depth".to_string(),
393            vec![surface_runoff_m as f64],
394        );
395        self.variables.insert(
396            "land_surface_water__runoff_volume_flux".to_string(),
397            vec![surface_runoff_volume_m3_s as f64],
398        );
399        Ok(())
400    }
401}
402
403impl<B: Backend> Bmi for LstmBmi<B> {
404    fn initialize(&mut self, config_file: &str) -> BmiResult<()> {
405        // println!("Initializing LSTM BMI with config: {}", config_file);
406        self.config_path = config_file.to_string();
407
408        // Load configuration
409        let config_path = Path::new(config_file);
410        let config_str = fs::read_to_string(config_path)?;
411        let config: serde_yaml::Value = serde_yaml::from_str(&config_str)?;
412
413        // Get all training config paths (ensemble)
414        let training_configs = config["train_cfg_file"]
415            .as_sequence()
416            .ok_or("train_cfg_file should be an array")?;
417
418        // println!("Loading ensemble of {} models", training_configs.len());
419
420        // Load each model in the ensemble
421        for (idx, config_value) in training_configs.iter().enumerate() {
422            let training_config_path = Path::new(
423                config_value
424                    .as_str()
425                    .ok_or(format!("train_cfg_file[{}] not a string", idx))?,
426            );
427
428            // println!(
429            //     "Loading model {}/{}: {}",
430            //     idx + 1,
431            //     training_configs.len(),
432            //     training_config_path.display()
433            // );
434
435            let model_instance = self.load_single_model(training_config_path)?;
436            self.models.push(model_instance);
437        }
438
439        // Get area from config
440        self.area_sqkm = config
441            .get("area_sqkm")
442            .ok_or("Missing area_sqkm")?
443            .as_f64()
444            .ok_or("area_sqkm not a number")? as f32;
445
446        self.output_scale_factor_cms =
447            (1.0 / 1000.0) * (self.area_sqkm * 1000.0 * 1000.0) * (1.0 / 3600.0);
448
449        // Set static inputs from config
450        let elevation = config
451            .get("elev_mean")
452            .and_then(|v| v.as_f64())
453            .unwrap_or(0.0);
454        let slope = config
455            .get("slope_mean")
456            .and_then(|v| v.as_f64())
457            .unwrap_or(0.0);
458
459        self.variables
460            .insert("basin__mean_of_elevation".to_string(), vec![elevation]);
461        self.variables
462            .insert("basin__mean_of_slope".to_string(), vec![slope]);
463
464        // Reset time
465        self.current_time = self.start_time;
466
467        // println!(
468        //     "LSTM BMI ensemble initialized successfully with {} models",
469        //     self.models.len()
470        // );
471        Ok(())
472    }
473
474    fn update(&mut self) -> BmiResult<()> {
475        self.run_ensemble()?;
476        self.current_time += self.time_step;
477        Ok(())
478    }
479
480    fn update_until(&mut self, then: f64) -> BmiResult<()> {
481        if then < self.current_time {
482            return Err(Box::new(std::io::Error::new(
483                std::io::ErrorKind::InvalidInput,
484                format!(
485                    "Target time {} is before current time {}",
486                    then, self.current_time
487                ),
488            )));
489        }
490
491        while self.current_time < then {
492            self.update()?;
493            if self.current_time > then {
494                self.current_time = then;
495            }
496        }
497        Ok(())
498    }
499
    /// Drops every loaded ensemble member. Always succeeds.
    fn finalize(&mut self) -> BmiResult<()> {
        self.models.clear();
        Ok(())
    }
504
505    fn get_component_name(&self) -> &str {
506        "NextGen LSTM BMI Ensemble"
507    }
508
509    fn get_input_item_count(&self) -> u32 {
510        self.input_var_names.len() as u32
511    }
512
513    fn get_output_item_count(&self) -> u32 {
514        self.output_var_names.len() as u32
515    }
516
517    fn get_input_var_names(&self) -> &[&str] {
518        &self.input_var_names
519    }
520
521    fn get_output_var_names(&self) -> &[&str] {
522        &self.output_var_names
523    }
524
525    fn get_var_grid(&self, name: &str) -> BmiResult<i32> {
526        if self.variables.contains_key(name) {
527            Ok(0) // Scalar grid
528        } else {
529            Err(Box::new(std::io::Error::new(
530                std::io::ErrorKind::NotFound,
531                format!("Variable {} not found", name),
532            )))
533        }
534    }
535
536    fn get_var_type(&self, name: &str) -> BmiResult<ValueType> {
537        if self.variables.contains_key(name) {
538            Ok(ValueType::F64)
539        } else {
540            Err(Box::new(std::io::Error::new(
541                std::io::ErrorKind::NotFound,
542                format!("Variable {} not found", name),
543            )))
544        }
545    }
546
    /// Returns the unit string registered for a variable.
    ///
    /// The table is a fixed list of names handled through `match_var!`; any other name
    /// produces the macro's `NotFound` error (`Variable <name> not found`).
    fn get_var_units(&self, name: &str) -> BmiResult<&str> {
        match_var!(name,
            "atmosphere_water__liquid_equivalent_precipitation_rate" => Ok("mm h-1"),
            "land_surface_air__temperature" => Ok("degK"),
            "land_surface_radiation~incoming~longwave__energy_flux" => Ok("W m-2"),
            "land_surface_air__pressure" => Ok("Pa"),
            "atmosphere_air_water~vapor__relative_saturation" => Ok("kg kg-1"),
            "land_surface_radiation~incoming~shortwave__energy_flux" => Ok("W m-2"),
            "land_surface_wind__x_component_of_velocity" => Ok("m s-1"),
            "land_surface_wind__y_component_of_velocity" => Ok("m s-1"),
            "basin__mean_of_elevation" => Ok("m"),
            "basin__mean_of_slope" => Ok("m km-1"),
            "land_surface_water__runoff_volume_flux" => Ok("m3 s-1"),
            "land_surface_water__runoff_depth" => Ok("m")
        )
    }
563
564    fn get_var_nbytes(&self, name: &str) -> BmiResult<u32> {
565        let itemsize = self.get_var_itemsize(name)?;
566        let values = self.get_value_ptr(name)?;
567        Ok(values.len() as u32 * itemsize)
568    }
569
570    fn get_var_location(&self, name: &str) -> BmiResult<Location> {
571        if self.variables.contains_key(name) {
572            Ok(Location::Node)
573        } else {
574            Err(Box::new(std::io::Error::new(
575                std::io::ErrorKind::NotFound,
576                format!("Variable {} not found", name),
577            )))
578        }
579    }
580
    /// Returns the current model time, in the units reported by `get_time_units`.
    fn get_current_time(&self) -> f64 {
        self.current_time
    }
584
    /// Returns the model start time, in the units reported by `get_time_units`.
    fn get_start_time(&self) -> f64 {
        self.start_time
    }
588
    /// Returns the stored end time, in the units reported by `get_time_units`.
    fn get_end_time(&self) -> f64 {
        self.end_time
    }
592
593    fn get_time_units(&self) -> &str {
594        "seconds"
595    }
596
    /// Returns the amount by which `update` advances the current time.
    fn get_time_step(&self) -> f64 {
        self.time_step
    }
600
601    fn get_value_ptr(&self, name: &str) -> BmiResult<RefValues<'_>> {
602        Ok(self
603            .variables
604            .get(name)
605            .map(|v| RefValues::F64(v))
606            .ok_or_else(|| {
607                Box::new(std::io::Error::new(
608                    std::io::ErrorKind::NotFound,
609                    format!("Variable {} not found", name),
610                ))
611            })?)
612    }
613
614    fn get_value_at_indices(&self, name: &str, inds: &[u32]) -> BmiResult<Values> {
615        let values = self.variables.get(name).ok_or_else(|| {
616            std::io::Error::new(
617                std::io::ErrorKind::NotFound,
618                format!("Variable {} not found", name),
619            )
620        })?;
621
622        let mut result = Vec::with_capacity(inds.len());
623        for &idx in inds {
624            if (idx as usize) >= values.len() {
625                return Err(Box::new(std::io::Error::new(
626                    std::io::ErrorKind::InvalidInput,
627                    format!("Index {} out of bounds", idx),
628                )));
629            }
630            result.push(values[idx as usize]);
631        }
632        Ok(Values::F64(result))
633    }
634
635    fn set_value(&mut self, name: &str, src: RefValues) -> BmiResult<()> {
636        if let RefValues::F64(values) = src {
637            if let Some(var) = self.variables.get_mut(name) {
638                if values.len() != var.len() {
639                    return Err(Box::new(std::io::Error::new(
640                        std::io::ErrorKind::InvalidInput,
641                        "Source array size mismatch",
642                    )));
643                }
644                var.copy_from_slice(values);
645                Ok(())
646            } else {
647                Err(Box::new(std::io::Error::new(
648                    std::io::ErrorKind::NotFound,
649                    format!("Variable {} not found", name),
650                )))
651            }
652        } else {
653            Err(Box::new(std::io::Error::new(
654                std::io::ErrorKind::InvalidInput,
655                "Type mismatch: expected F64",
656            )))
657        }
658    }
659
660    fn set_value_at_indices(&mut self, name: &str, inds: &[u32], src: RefValues) -> BmiResult<()> {
661        if let RefValues::F64(values) = src {
662            if values.len() != inds.len() {
663                return Err(Box::new(std::io::Error::new(
664                    std::io::ErrorKind::InvalidInput,
665                    "Source array size doesn't match indices count",
666                )));
667            }
668
669            let var = self.variables.get_mut(name).ok_or_else(|| {
670                std::io::Error::new(
671                    std::io::ErrorKind::NotFound,
672                    format!("Variable {} not found", name),
673                )
674            })?;
675
676            for (i, &idx) in inds.iter().enumerate() {
677                if (idx as usize) >= var.len() {
678                    return Err(Box::new(std::io::Error::new(
679                        std::io::ErrorKind::InvalidInput,
680                        format!("Index {} out of bounds", idx),
681                    )));
682                }
683                var[idx as usize] = values[i];
684            }
685            Ok(())
686        } else {
687            Err(Box::new(std::io::Error::new(
688                std::io::ErrorKind::InvalidInput,
689                "Type mismatch: expected F64",
690            )))
691        }
692    }
693}
694
695// Export function for C binding
696#[unsafe(no_mangle)]
697pub extern "C" fn register_bmi_lstm(handle: *mut ffi::Bmi) -> *mut ffi::Bmi {
698    // type Backend = burn::backend::NdArray;
699    type Backend = burn::backend::Candle;
700    let device = Default::default();
701
702    let model = LstmBmi::<Backend>::new(device);
703    register_model(handle, model);
704    handle
705}