tf_binding_rs/
occupancy.rs

1use crate::error::MotifError;
2use crate::fasta::reverse_complement;
3use crate::types::*;
4use polars::lazy::dsl::*;
5use polars::prelude::*;
6use std::collections::HashMap;
7use std::fmt::format;
8use std::fs::File;
9use std::io::{BufRead, BufReader};
10use std::iter::Peekable;
11
/// Pseudocount added to every PWM probability before normalization, so that a zero
/// probability does not turn into an infinite energy once the logarithm is taken.
const PSEUDOCOUNT: f64 = 1e-4;
/// The RT factor of the energy conversion `ddG = -RT ln(p_b / p_c)` (documented as kJ/mol).
const RT: f64 = 2.5;
14
/// Advances `lines` so that the next item it yields is a `MOTIF` header line.
///
/// Every `Ok` line that does not start with `"MOTIF"` is consumed and discarded. The `MOTIF`
/// line itself is only peeked, never consumed, so the caller reads it next. The function also
/// returns without consuming anything when the iterator is exhausted or when the next item is
/// an I/O error.
fn skip_until_motif<I>(lines: &mut Peekable<I>)
where
    I: Iterator<Item = Result<String, std::io::Error>>,
{
    // `next_if` consumes only items that satisfy the predicate: non-MOTIF `Ok` lines.
    while lines
        .next_if(|next| matches!(next, Ok(text) if !text.starts_with("MOTIF")))
        .is_some()
    {}
}
27
28/// Parses a single PWM from the iterator
29fn parse_pwm<I>(lines: &mut I) -> Result<Option<(String, PWM)>, MotifError>
30where
31    I: Iterator<Item = Result<String, std::io::Error>>,
32{
33    // Get motif ID from MOTIF line
34    let motif_line = match lines.next() {
35        Some(Ok(line)) if line.starts_with("MOTIF") => line,
36        _ => return Ok(None),
37    };
38
39    let motif_id = motif_line
40        .split_whitespace()
41        .nth(1)
42        .ok_or_else(|| MotifError::InvalidFileFormat("Missing motif ID".into()))?
43        .to_string();
44
45    // Skip header lines
46    for _ in 0..2 {
47        lines.next();
48    }
49
50    // Read PWM rows until we hit a non-PWM line
51    let pwm_rows: Vec<Vec<f64>> = lines
52        .take_while(|line| {
53            line.as_ref()
54                .map(|l| l.starts_with(|c: char| c.is_whitespace() || c == '0' || c == '1'))
55                .unwrap_or(false)
56        })
57        .map(|line| {
58            let line = line.map_err(|e| MotifError::Io(e))?;
59            let values: Vec<f64> = line
60                .split_whitespace()
61                .map(|s| s.parse::<f64>())
62                .collect::<Result<Vec<_>, _>>()
63                .map_err(|e| MotifError::InvalidFileFormat(format!("Invalid PWM value: {}", e)))?;
64
65            Ok(values)
66        })
67        .collect::<Result<Vec<_>, MotifError>>()?;
68
69    if pwm_rows.is_empty() {
70        return Err(MotifError::InvalidFileFormat("Empty PWM".into()));
71    }
72
73    // Create PWM DataFrame
74    let pwm = DataFrame::new(vec![
75        Column::new(
76            "A".into(),
77            pwm_rows.iter().map(|row| row[0]).collect::<Vec<f64>>(),
78        ),
79        Column::new(
80            "C".into(),
81            pwm_rows.iter().map(|row| row[1]).collect::<Vec<f64>>(),
82        ),
83        Column::new(
84            "G".into(),
85            pwm_rows.iter().map(|row| row[2]).collect::<Vec<f64>>(),
86        ),
87        Column::new(
88            "T".into(),
89            pwm_rows.iter().map(|row| row[3]).collect::<Vec<f64>>(),
90        ),
91    ])
92    .map_err(|e| MotifError::DataError(e.to_string()))?;
93
94    Ok(Some((motif_id, pwm)))
95}
96
97/// Reads Position Weight Matrices (PWMs) from a MEME format file
98///
99/// This function parses a MEME-formatted file containing one or more Position Weight Matrices,
100/// each identified by a unique motif ID. The PWMs represent DNA binding motifs where each position
101/// contains probabilities for the four nucleotides (A, C, G, T).
102///
103/// # Arguments
104/// * `filename` - Path to the MEME format file to read
105///
106/// # Returns
107/// * `Result<PWMCollection, MotifError>` - A HashMap where keys are motif IDs and values are their corresponding PWMs
108///
109/// # Errors
110/// * `MotifError::Io` - If the file cannot be opened or read
111/// * `MotifError::InvalidFileFormat` - If the file format is invalid or no PWMs are found
112/// * `MotifError::DataError` - If there are issues creating the PWM DataFrame
113///
114/// # Example
115/// ```ignore
116/// use tf_binding_rs::occupancy::read_pwm_files;
117///
118/// let pwms = read_pwm_files("path/to/motifs.meme").unwrap();
119/// for (motif_id, pwm) in pwms {
120///     println!("Found motif: {}", motif_id);
121/// }
122/// ```
123///
124/// # Format
125/// The input file should be in MEME format, where each PWM is preceded by a "MOTIF" line
126/// containing the motif ID, followed by the matrix values.
127pub fn read_pwm_files(filename: &str) -> Result<PWMCollection, MotifError> {
128    let file = File::open(filename)?;
129    let reader = BufReader::new(file);
130    let mut lines = reader.lines().peekable();
131    let mut pwms = HashMap::new();
132
133    // Skip header until first MOTIF
134    skip_until_motif(&mut lines);
135
136    // Parse all PWMs
137    while let Some((id, pwm)) = parse_pwm(&mut lines)? {
138        pwms.insert(id, pwm);
139        skip_until_motif(&mut lines);
140    }
141
142    if pwms.is_empty() {
143        return Err(MotifError::InvalidFileFormat("No PWMs found".into()));
144    }
145
146    Ok(pwms)
147}
148
149/// Reads Position Weight Matrices (PWMs) from a MEME format file and converts them to Energy Weight Matrices (EWMs)
150///
151/// This function reads PWMs and converts them to EWMs using the formula ddG = -RT ln(p_b,i / p_c,i), where:
152/// - p_b,i is the probability of base b
153/// - p_c,i is the probability of the consensus base
154/// - ddG is relative free energy
155///
156/// The conversion process:
157/// 1. Reads PWMs from the MEME file
158/// 2. Adds pseudocounts to handle zeros in the PWM
159/// 3. Normalizes each position by the most frequent letter to get relative Kd
160/// 4. Converts to EWM using the formula above
161///
162/// # Arguments
163/// * `filename` - Path to the MEME format file containing PWMs
164///
165/// # Returns
166/// * `Result<EWMCollection, MotifError>` - A HashMap where keys are motif IDs and values are their corresponding EWMs
167///
168/// # Errors
169/// * `MotifError::Io` - If the file cannot be opened or read
170/// * `MotifError::InvalidFileFormat` - If the file format is invalid or no PWMs are found
171/// * `MotifError::DataError` - If there are issues creating or manipulating the matrices
172///
173/// # Constants
174/// * `PSEUDOCOUNT` - Value (default: 0.0001) added to every matrix position to handle zeros
175/// * `RT` - The RT value (default: 2.5) used in the ddG formula in kJ/mol
176///
177/// # Example
178/// ```ignore
179/// use tf_binding_rs::occupancy::read_pwm_to_ewm;
180///
181/// let ewms = read_pwm_to_ewm("path/to/motifs.meme").unwrap();
182/// for (motif_id, ewm) in ewms {
183///     println!("Processed EWM for motif: {}", motif_id);
184/// }
185/// ```
186pub fn read_pwm_to_ewm(filename: &str) -> Result<EWMCollection, MotifError> {
187    let pwms = read_pwm_files(filename)?;
188
189    let ewms: EWMCollection = pwms
190        .into_iter()
191        .map(|(id, pwm)| {
192            let normalized = pwm
193                .clone()
194                .lazy()
195                .select([
196                    (col("A") + lit(PSEUDOCOUNT)).alias("A_pseudo"),
197                    (col("C") + lit(PSEUDOCOUNT)).alias("C_pseudo"),
198                    (col("G") + lit(PSEUDOCOUNT)).alias("G_pseudo"),
199                    (col("T") + lit(PSEUDOCOUNT)).alias("T_pseudo"),
200                ])
201                .with_column(
202                    max_horizontal([
203                        col("A_pseudo"),
204                        col("C_pseudo"),
205                        col("G_pseudo"),
206                        col("T_pseudo"),
207                    ])
208                    .unwrap()
209                    .alias("max_val"),
210                )
211                .select([
212                    (col("A_pseudo") / col("max_val")).alias("A_norm"),
213                    (col("C_pseudo") / col("max_val")).alias("C_norm"),
214                    (col("G_pseudo") / col("max_val")).alias("G_norm"),
215                    (col("T_pseudo") / col("max_val")).alias("T_norm"),
216                ])
217                .select([
218                    (-lit(RT) * col("A_norm").log(std::f64::consts::E)).alias("A"),
219                    (-lit(RT) * col("C_norm").log(std::f64::consts::E)).alias("C"),
220                    (-lit(RT) * col("G_norm").log(std::f64::consts::E)).alias("G"),
221                    (-lit(RT) * col("T_norm").log(std::f64::consts::E)).alias("T"),
222                ])
223                .collect()
224                .map_err(|e| MotifError::DataError(e.to_string()))?;
225
226            Ok((id, normalized))
227        })
228        .collect::<Result<HashMap<_, _>, MotifError>>()?;
229
230    Ok(ewms)
231}
232
233/// Scans both strands of a sequence with an energy matrix to compute binding energies
234///
235/// This function calculates the energy score for each possible k-mer in the sequence on both
236/// forward and reverse strands. For each position, it extracts the k-mer subsequence and
237/// calculates the total energy score by summing individual nucleotide contributions.
238///
239/// # Arguments
240/// * `seq` - The DNA sequence to scan
241/// * `ewm` - Energy Weight Matrix as a DataFrame where columns represent A,C,G,T and rows are positions
242///
243/// # Returns
244/// * `Result<(Vec<f64>, Vec<f64>), MotifError>` - A tuple containing forward and reverse strand scores
245///
246/// # Errors
247/// * `MotifError::DataError` - If there are issues extracting values from the EWM DataFrame
248///
249/// # Example
250/// ```ignore
251/// use tf_binding_rs::occupancy::energy_landscape;
252///
253/// let seq = "ATCGATCG";
254/// let (fwd_scores, rev_scores) = energy_landscape(&seq, &ewm).unwrap();
255/// println!("Forward strand scores: {:?}", fwd_scores);
256/// println!("Reverse strand scores: {:?}", rev_scores);
257/// ```
258pub fn energy_landscape(seq: &str, ewm: &EWM) -> Result<(Vec<f64>, Vec<f64>), MotifError> {
259    let motif_len = ewm.height();
260    let n_scores = seq.len() - motif_len + 1;
261    let r_seq = reverse_complement(seq)?;
262
263    let mut fscores = vec![0.0; n_scores];
264    let mut rscores = vec![0.0; n_scores];
265
266    for (pos, (fscore, rscore)) in fscores.iter_mut().zip(rscores.iter_mut()).enumerate() {
267        let f_kmer = &seq[pos..pos + motif_len];
268        let r_kmer = &r_seq[pos..pos + motif_len];
269
270        *fscore = (0..motif_len)
271            .map(|i| {
272                ewm.column(&f_kmer[i..i + 1])
273                    .unwrap()
274                    .get(i)
275                    .unwrap()
276                    .try_extract::<f64>()
277                    .map_err(|e| MotifError::DataError(e.to_string()))
278            })
279            .sum::<Result<f64, MotifError>>()?;
280
281        *rscore = (0..motif_len)
282            .map(|i| {
283                ewm.column(&r_kmer[i..i + 1])
284                    .unwrap()
285                    .get(i)
286                    .unwrap()
287                    .try_extract::<f64>()
288                    .map_err(|e| MotifError::DataError(e.to_string()))
289            })
290            .sum::<Result<f64, MotifError>>()?;
291    }
292
293    rscores.reverse();
294    Ok((fscores, rscores))
295}
296
297/// Computes the occupancy landscape by scanning sequence with the energy matrix
298///
299/// This function calculates the probability of TF binding at each position by:
300/// 1. Computing energy scores using `energy_landscape()`
301/// 2. Converting energy scores to occupancy probabilities using the formula:
302///    occupancy = 1 / (1 + exp(energy - mu))
303/// where mu is the chemical potential of the transcription factor.
304///
305/// # Arguments
306/// * `seq` - The DNA sequence to scan
307/// * `ewm` - Energy Weight Matrix as a DataFrame
308/// * `mu` - Chemical potential of the transcription factor
309///
310/// # Returns
311/// * `Result<(Vec<f64>, Vec<f64>), MotifError>` - A tuple containing forward and reverse strand occupancies
312///
313/// # Errors
314/// * `MotifError::DataError` - If there are issues calculating energy scores
315///
316/// # Example
317/// ```ignore
318/// use tf_binding_rs::occupancy::occupancy_landscape;
319///
320/// let seq = "ATCGATCG";
321/// let mu = -3.0;
322/// let (fwd_occ, rev_occ) = occupancy_landscape(&seq, &ewm, mu).unwrap();
323/// println!("Forward strand occupancy: {:?}", fwd_occ);
324/// ```
325pub fn occupancy_landscape(
326    seq: &str,
327    ewm: &EWM,
328    mu: f64,
329) -> Result<(Vec<f64>, Vec<f64>), MotifError> {
330    let (fscores, rscores) = energy_landscape(seq, ewm)?;
331
332    let foccupancies: Vec<f64> = fscores
333        .into_iter()
334        .map(|s| 1.0 / (1.0 + (s - mu).exp()))
335        .collect();
336
337    let roccupancies: Vec<f64> = rscores
338        .into_iter()
339        .map(|s| 1.0 / (1.0 + (s - mu).exp()))
340        .collect();
341
342    Ok((foccupancies, roccupancies))
343}
344
345/// Computes the occupancy landscape for multiple transcription factors
346///
347/// This function calculates binding probabilities for each TF in the collection and combines
348/// them into a single DataFrame. The results include both forward and reverse strand occupancies
349/// for each TF, with values padded to match the sequence length.
350///
351/// # Arguments
352/// * `seq` - The DNA sequence to scan
353/// * `ewms` - Collection of Energy Weight Matrices, where keys are TF names
354/// * `mu` - Chemical potential of the transcription factors
355///
356/// # Returns
357/// * `Result<DataFrame, MotifError>` - DataFrame containing occupancy predictions where:
358///   - Rows represent positions in the sequence
359///   - Columns are named "{TF_NAME}_F" and "{TF_NAME}_R" for forward/reverse orientations
360///   - Values indicate predicted occupancy (0-1) at each position
361///
362/// # Errors
363/// * `MotifError::DataError` - If there are issues creating the DataFrame or calculating occupancies
364///
365/// # Example
366/// ```ignore
367/// use tf_binding_rs::occupancy::total_landscape;
368///
369/// let seq = "ATCGATCG";
370/// let mu = -3.0;
371/// let landscape = total_landscape(&seq, &ewm_collection, mu).unwrap();
372/// println!("Combined occupancy landscape:\n{}", landscape);
373/// ```
374pub fn total_landscape(seq: &str, ewms: &EWMCollection, mu: f64) -> Result<DataFrame, MotifError> {
375    let seq_len = seq.len();
376    let mut columns: Vec<Column> = Vec::new();
377    let mut names: Vec<String> = Vec::new();
378
379    for (name, ewm) in ewms {
380        let (fscores, rscores) = occupancy_landscape(seq, ewm, mu)?;
381
382        // pad scores to sequence length
383        let amount_to_add = seq_len - fscores.len();
384        let mut fscores_padded = fscores.clone();
385        let mut rscores_padded = rscores.clone();
386        fscores_padded.extend(vec![0.0; amount_to_add]);
387        rscores_padded.extend(vec![0.0; amount_to_add]);
388
389        // create series for forward and reverse scores
390        columns.push(Column::new(format!("{}_F", name).into(), fscores_padded));
391        columns.push(Column::new(format!("{}_R", name).into(), rscores_padded));
392        names.push(name.to_string());
393    }
394
395    DataFrame::new(columns).map_err(|e| MotifError::DataError(e.to_string()))
396}