tangram_features/
compute.rs

1use crate::{
2	bag_of_words::BagOfWordsFeatureGroup,
3	bag_of_words_cosine_similarity::BagOfWordsCosineSimilarityFeatureGroup,
4	identity::IdentityFeatureGroup, normalized::NormalizedFeatureGroup,
5	one_hot_encoded::OneHotEncodedFeatureGroup, word_embedding::WordEmbeddingFeatureGroup,
6	FeatureGroup,
7};
8use ndarray::prelude::*;
9use tangram_table::prelude::*;
10
11/// Compute features as an `Array` of `f32`s.
12pub fn compute_features_array_f32(
13	table: &TableView,
14	feature_groups: &[FeatureGroup],
15	progress: &impl Fn(),
16) -> Array2<f32> {
17	let n_features = feature_groups
18		.iter()
19		.map(|feature_group| feature_group.n_features())
20		.sum::<usize>();
21	let mut features = Array::zeros((table.nrows(), n_features));
22	let mut feature_index = 0;
23	for feature_group in feature_groups.iter() {
24		let n_features_in_group = feature_group.n_features();
25		let slice = s![.., feature_index..feature_index + n_features_in_group];
26		let features = features.slice_mut(slice);
27		compute_features_array_f32_for_feature_group(table, feature_group, features, progress);
28		feature_index += n_features_in_group;
29	}
30	features
31}
32
33fn compute_features_array_f32_for_feature_group(
34	table: &TableView,
35	feature_group: &FeatureGroup,
36	features: ArrayViewMut2<f32>,
37	progress: &impl Fn(),
38) {
39	match &feature_group {
40		FeatureGroup::Identity(feature_group) => {
41			compute_features_array_f32_for_identity_feature_group(
42				table,
43				feature_group,
44				features,
45				progress,
46			)
47		}
48		FeatureGroup::Normalized(feature_group) => {
49			compute_features_array_f32_for_normalized_feature_group(
50				table,
51				feature_group,
52				features,
53				progress,
54			)
55		}
56		FeatureGroup::OneHotEncoded(feature_group) => {
57			compute_features_array_f32_for_one_hot_encoded_feature_group(
58				table,
59				feature_group,
60				features,
61				progress,
62			)
63		}
64		FeatureGroup::BagOfWords(feature_group) => {
65			compute_features_array_f32_for_bag_of_words_feature_group(
66				table,
67				feature_group,
68				features,
69				progress,
70			)
71		}
72		FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
73			compute_features_array_f32_for_bag_of_words_cosine_similarity_feature_group(
74				table,
75				feature_group,
76				features,
77				progress,
78			)
79		}
80		FeatureGroup::WordEmbedding(feature_group) => {
81			compute_features_array_f32_for_word_embedding_feature_group(
82				table,
83				feature_group,
84				features,
85				progress,
86			)
87		}
88	}
89}
90
91fn compute_features_array_f32_for_identity_feature_group(
92	table: &TableView,
93	feature_group: &IdentityFeatureGroup,
94	features: ArrayViewMut2<f32>,
95	progress: &impl Fn(),
96) {
97	// Get the source column.
98	let source_column = table
99		.columns()
100		.iter()
101		.find(|column| column.name() == Some(&feature_group.source_column_name))
102		.unwrap();
103	feature_group.compute_array_f32(features, source_column.view(), progress);
104}
105
106fn compute_features_array_f32_for_normalized_feature_group(
107	table: &TableView,
108	feature_group: &NormalizedFeatureGroup,
109	features: ArrayViewMut2<f32>,
110	progress: &impl Fn(),
111) {
112	// Get the source column.
113	let source_column = table
114		.columns()
115		.iter()
116		.find(|column| column.name() == Some(&feature_group.source_column_name))
117		.unwrap();
118	feature_group.compute_array_f32(features, source_column.view(), progress)
119}
120
121fn compute_features_array_f32_for_one_hot_encoded_feature_group(
122	table: &TableView,
123	feature_group: &OneHotEncodedFeatureGroup,
124	features: ArrayViewMut2<f32>,
125	progress: &impl Fn(),
126) {
127	// Get the source column.
128	let source_column = table
129		.columns()
130		.iter()
131		.find(|column| column.name() == Some(&feature_group.source_column_name))
132		.unwrap();
133	feature_group.compute_array_f32(features, source_column.view(), progress);
134}
135
136fn compute_features_array_f32_for_bag_of_words_feature_group(
137	table: &TableView,
138	feature_group: &BagOfWordsFeatureGroup,
139	features: ArrayViewMut2<f32>,
140	progress: &impl Fn(),
141) {
142	// Get the source column.
143	let source_column = table
144		.columns()
145		.iter()
146		.find(|column| column.name() == Some(&feature_group.source_column_name))
147		.unwrap();
148	feature_group.compute_array_f32(features, source_column.view(), progress);
149}
150
151fn compute_features_array_f32_for_bag_of_words_cosine_similarity_feature_group(
152	table: &TableView,
153	feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
154	features: ArrayViewMut2<f32>,
155	progress: &impl Fn(),
156) {
157	// Get the source column.
158	// Get the data for the source column.
159	let source_column_a = table
160		.columns()
161		.iter()
162		.find(|column| column.name().unwrap() == feature_group.source_column_name_a)
163		.unwrap();
164	let source_column_b = table
165		.columns()
166		.iter()
167		.find(|column| column.name().unwrap() == feature_group.source_column_name_b)
168		.unwrap();
169	feature_group.compute_array_f32(
170		features,
171		source_column_a.view(),
172		source_column_b.view(),
173		progress,
174	);
175}
176
177fn compute_features_array_f32_for_word_embedding_feature_group(
178	table: &TableView,
179	feature_group: &WordEmbeddingFeatureGroup,
180	features: ArrayViewMut2<f32>,
181	progress: &impl Fn(),
182) {
183	// Get the source column.
184	let source_column = table
185		.columns()
186		.iter()
187		.find(|column| column.name() == Some(&feature_group.source_column_name))
188		.unwrap();
189	feature_group.compute_array_f32(features, source_column.view(), progress);
190}
191
192/// Compute features as a `Table`.
193pub fn compute_features_table(
194	table: &TableView,
195	feature_groups: &[FeatureGroup],
196	progress: &impl Fn(u64),
197) -> Table {
198	let mut features = Table::new(Vec::new(), Vec::new());
199	for feature_group in feature_groups.iter() {
200		compute_features_table_for_feature_group(table, feature_group, &mut features, progress)
201	}
202	features
203}
204
205fn compute_features_table_for_feature_group(
206	table: &TableView,
207	feature_group: &FeatureGroup,
208	features: &mut Table,
209	progress: &impl Fn(u64),
210) {
211	match &feature_group {
212		FeatureGroup::Identity(feature_group) => compute_features_table_for_identity_feature_group(
213			table,
214			feature_group,
215			features,
216			progress,
217		),
218		FeatureGroup::Normalized(feature_group) => {
219			compute_features_table_for_normalized_feature_group(
220				table,
221				feature_group,
222				features,
223				progress,
224			)
225		}
226		FeatureGroup::OneHotEncoded(_) => unimplemented!(),
227		FeatureGroup::BagOfWords(feature_group) => {
228			compute_features_table_for_bag_of_words_feature_group(
229				table,
230				feature_group,
231				features,
232				progress,
233			)
234		}
235		FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
236			compute_features_table_for_bag_of_words_cosine_similarity_feature_group(
237				table,
238				feature_group,
239				features,
240				progress,
241			)
242		}
243		FeatureGroup::WordEmbedding(feature_group) => {
244			compute_features_table_for_word_embedding_feature_group(
245				table,
246				feature_group,
247				features,
248				progress,
249			)
250		}
251	};
252}
253
254fn compute_features_table_for_identity_feature_group(
255	table: &TableView,
256	feature_group: &IdentityFeatureGroup,
257	features: &mut Table,
258	progress: &impl Fn(u64),
259) {
260	let column = table
261		.columns()
262		.iter()
263		.find(|column| column.name().unwrap() == feature_group.source_column_name)
264		.unwrap();
265	let feature_column = feature_group.compute_table(column.view(), progress);
266	features.columns_mut().push(feature_column);
267}
268
269fn compute_features_table_for_normalized_feature_group(
270	table: &TableView,
271	feature_group: &NormalizedFeatureGroup,
272	features: &mut Table,
273	progress: &impl Fn(u64),
274) {
275	let column = table
276		.columns()
277		.iter()
278		.find(|column| column.name().unwrap() == feature_group.source_column_name)
279		.unwrap();
280	let feature_column = feature_group.compute_table(column.view(), progress);
281	features.columns_mut().push(feature_column);
282}
283
284fn compute_features_table_for_bag_of_words_feature_group(
285	table: &TableView,
286	feature_group: &BagOfWordsFeatureGroup,
287	features: &mut Table,
288	progress: &impl Fn(u64),
289) {
290	// Get the data for the source column.
291	let source_column = table
292		.columns()
293		.iter()
294		.find(|column| column.name().unwrap() == feature_group.source_column_name)
295		.unwrap();
296	let columns = feature_group.compute_table(source_column.view(), progress);
297	for column in columns {
298		features.columns_mut().push(column);
299	}
300}
301
302fn compute_features_table_for_bag_of_words_cosine_similarity_feature_group(
303	table: &TableView,
304	feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
305	features: &mut Table,
306	progress: &impl Fn(u64),
307) {
308	// Get the data for the source column.
309	let source_column_a = table
310		.columns()
311		.iter()
312		.find(|column| column.name().unwrap() == feature_group.source_column_name_a)
313		.unwrap();
314	let source_column_b = table
315		.columns()
316		.iter()
317		.find(|column| column.name().unwrap() == feature_group.source_column_name_b)
318		.unwrap();
319	let column =
320		feature_group.compute_table(source_column_a.view(), source_column_b.view(), progress);
321	features.columns_mut().push(column);
322}
323
324fn compute_features_table_for_word_embedding_feature_group(
325	table: &TableView,
326	feature_group: &WordEmbeddingFeatureGroup,
327	features: &mut Table,
328	progress: &impl Fn(u64),
329) {
330	// Get the data for the source column.
331	let source_column = table
332		.columns()
333		.iter()
334		.find(|column| column.name().unwrap() == feature_group.source_column_name)
335		.unwrap();
336	let columns = feature_group.compute_table(source_column.view(), progress);
337	for column in columns {
338		features.columns_mut().push(column);
339	}
340}
341
342pub fn compute_features_array_value<'a>(
343	table: &TableView<'a>,
344	feature_groups: &[FeatureGroup],
345	progress: &impl Fn(),
346) -> Array2<TableValue<'a>> {
347	let n_features = feature_groups.iter().map(|g| g.n_features()).sum::<usize>();
348	let mut features = Array::from_elem((table.nrows(), n_features), TableValue::Unknown);
349	let mut feature_index = 0;
350	for feature_group in feature_groups.iter() {
351		let n_features_in_group = feature_group.n_features();
352		let slice = s![.., feature_index..feature_index + n_features_in_group];
353		let features = features.slice_mut(slice);
354		compute_features_array_value_for_feature_group(table, feature_group, features, progress);
355		feature_index += n_features_in_group;
356	}
357	features
358}
359
360fn compute_features_array_value_for_feature_group(
361	table: &TableView,
362	feature_group: &FeatureGroup,
363	features: ArrayViewMut2<tangram_table::TableValue>,
364	progress: &impl Fn(),
365) {
366	match &feature_group {
367		FeatureGroup::Identity(feature_group) => {
368			compute_features_array_value_for_identity_feature_group(
369				table,
370				feature_group,
371				features,
372				progress,
373			)
374		}
375		FeatureGroup::Normalized(feature_group) => {
376			compute_features_array_value_for_normalized_feature_group(
377				table,
378				feature_group,
379				features,
380				progress,
381			)
382		}
383		FeatureGroup::OneHotEncoded(_) => unimplemented!(),
384		FeatureGroup::BagOfWords(feature_group) => {
385			compute_features_array_value_for_bag_of_words_feature_group(
386				table,
387				feature_group,
388				features,
389				progress,
390			)
391		}
392		FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
393			compute_features_array_value_for_bag_of_words_cosine_similarity_feature_group(
394				table,
395				feature_group,
396				features,
397				progress,
398			)
399		}
400		FeatureGroup::WordEmbedding(feature_group) => {
401			compute_features_array_value_for_word_embedding_feature_group(
402				table,
403				feature_group,
404				features,
405				progress,
406			)
407		}
408	}
409}
410
411fn compute_features_array_value_for_identity_feature_group(
412	table: &TableView,
413	feature_group: &IdentityFeatureGroup,
414	features: ArrayViewMut2<tangram_table::TableValue>,
415	progress: &impl Fn(),
416) {
417	let source_column = table
418		.columns()
419		.iter()
420		.find(|column| column.name().unwrap() == feature_group.source_column_name)
421		.unwrap();
422	feature_group.compute_array_value(features, source_column.view(), progress);
423}
424
425fn compute_features_array_value_for_normalized_feature_group(
426	table: &TableView,
427	feature_group: &NormalizedFeatureGroup,
428	features: ArrayViewMut2<tangram_table::TableValue>,
429	progress: &impl Fn(),
430) {
431	let source_column = table
432		.columns()
433		.iter()
434		.find(|column| column.name().unwrap() == feature_group.source_column_name)
435		.unwrap();
436	feature_group.compute_array_value(features, source_column.view(), progress);
437}
438
439fn compute_features_array_value_for_bag_of_words_feature_group(
440	table: &TableView,
441	feature_group: &BagOfWordsFeatureGroup,
442	features: ArrayViewMut2<tangram_table::TableValue>,
443	progress: &impl Fn(),
444) {
445	// Get the data for the source column.
446	let source_column = table
447		.columns()
448		.iter()
449		.find(|column| column.name().unwrap() == feature_group.source_column_name)
450		.unwrap();
451	feature_group.compute_array_value(features, source_column.view(), progress);
452}
453
454fn compute_features_array_value_for_bag_of_words_cosine_similarity_feature_group(
455	table: &TableView,
456	feature_group: &BagOfWordsCosineSimilarityFeatureGroup,
457	features: ArrayViewMut2<tangram_table::TableValue>,
458	progress: &impl Fn(),
459) {
460	// Get the data for the source column.
461	let source_column_a = table
462		.columns()
463		.iter()
464		.find(|column| column.name().unwrap() == feature_group.source_column_name_a)
465		.unwrap();
466	let source_column_b = table
467		.columns()
468		.iter()
469		.find(|column| column.name().unwrap() == feature_group.source_column_name_b)
470		.unwrap();
471	feature_group.compute_array_value(
472		features,
473		source_column_a.view(),
474		source_column_b.view(),
475		progress,
476	);
477}
478
479fn compute_features_array_value_for_word_embedding_feature_group(
480	table: &TableView,
481	feature_group: &WordEmbeddingFeatureGroup,
482	features: ArrayViewMut2<tangram_table::TableValue>,
483	progress: &impl Fn(),
484) {
485	// Get the data for the source column.
486	let source_column = table
487		.columns()
488		.iter()
489		.find(|column| column.name().unwrap() == feature_group.source_column_name)
490		.unwrap();
491	feature_group.compute_array_value(features, source_column.view(), progress);
492}