atelier/klout/src/gen_data.rs
2025-01-16 20:05:09 -05:00

173 lines
3.5 KiB
Rust

use std::{
cmp::Ordering,
ops::{Add, Div, Mul, Sub},
path::Path,
};
use eyre::{eyre, Result};
use rustc_hash::FxHashMap;
/// Hash map keyed by a fixed-size `N`-byte gram, storing one value of type `T` per gram.
type GramMap<const N: usize, T> = FxHashMap<[u8; N], T>;
/// Per-gram values of type `T`, with one table for each gram length from one to four bytes.
#[derive(Default, Debug)]
struct Grams<T> {
/// Values keyed by 1-byte grams.
grams1: GramMap<1, T>,
/// Values keyed by 2-byte grams.
grams2: GramMap<2, T>,
/// Values keyed by 3-byte grams.
grams3: GramMap<3, T>,
/// Values keyed by 4-byte grams.
grams4: GramMap<4, T>,
}
impl<T> Grams<T>
where
    T: Div<Output = T> + Copy,
{
    /// Divides every stored value, in each of the four gram tables, by `n`.
    fn divide_by(&mut self, n: T) {
        let Self {
            grams1,
            grams2,
            grams3,
            grams4,
        } = self;
        divide_by(grams1, n);
        divide_by(grams2, n);
        divide_by(grams3, n);
        divide_by(grams4, n);
    }
}
impl<T> Grams<T>
where
    T: Copy
        + PartialOrd
        + Default
        + Sub<Output = T>
        + Add<Output = T>
        + Div<Output = T>
        + Mul<Output = T>,
{
    /// Rescales each of the four gram tables into the range `[omin, omax]`.
    ///
    /// Every table is normalized on its own, using that table's own smallest
    /// and largest value as the input range.
    fn normalize(&mut self, omin: T, omax: T) {
        let Self {
            grams1,
            grams2,
            grams3,
            grams4,
        } = self;
        normalize(grams1, omin, omax);
        normalize(grams2, omin, omax);
        normalize(grams3, omin, omax);
        normalize(grams4, omin, omax);
    }
}
/// Replaces every value in `grams` with that value divided by `n`.
fn divide_by<const N: usize, T: Div<Output = T> + Copy>(
    grams: &mut GramMap<N, T>,
    n: T,
) {
    grams.values_mut().for_each(|slot| *slot = *slot / n);
}
/// Linearly rescales every value in `grams` so that the smallest value maps to
/// `omin` and the largest maps to `omax`.
///
/// An empty map is left untouched. When all values are equal (which includes a
/// map with a single entry) the input range has zero width and a linear
/// mapping would divide by zero, so every value is set to `omin` instead.
fn normalize<const N: usize, T>(grams: &mut GramMap<N, T>, omin: T, omax: T)
where
    T: Copy
        + PartialOrd
        + Default
        + Sub<Output = T>
        + Add<Output = T>
        + Div<Output = T>
        + Mul<Output = T>,
{
    // Pairs that cannot be ordered (e.g. a float NaN) are treated as equal.
    let by_value = |a: &T, b: &T| a.partial_cmp(b).unwrap_or(Ordering::Equal);
    let min = grams.values().copied().min_by(by_value);
    let max = grams.values().copied().max_by(by_value);
    let (min, max) = match (min, max) {
        (Some(min), Some(max)) => (min, max),
        // Empty map: there is nothing to rescale.
        _ => return,
    };
    // Only a non-empty input range can be mapped without dividing by zero.
    let has_spread = max > min;
    for v in grams.values_mut() {
        *v = if has_spread {
            map_to_range(*v, min, max, omin, omax)
        } else {
            omin
        };
    }
}
#[test]
fn test_normalize() {
    // (key, raw value, expected value after rescaling into [0, 100])
    let cases = [
        (b'a', 500., 100.),
        (b'b', 300., 50.),
        (b'c', 100., 0.),
        (b'd', 125., 6.25),
    ];
    let mut input = GramMap::<1, f64>::default();
    for &(key, raw, _) in cases.iter() {
        input.insert([key], raw);
    }
    normalize(&mut input, 0., 100.);
    for &(key, _, expected) in cases.iter() {
        assert_eq!(input[&[key]], expected);
    }
}
/// Maps `v` from the range `[amin, amax]` onto the range `[bmin, bmax]`.
///
/// The offset of `v` within the input range is multiplied by the output width
/// before it is divided by the input width, so integer types only truncate
/// once, at the final division.
fn map_to_range<V>(v: V, amin: V, amax: V, bmin: V, bmax: V) -> V
where
    V: Sub<Output = V>
        + Add<Output = V>
        + Mul<Output = V>
        + Div<Output = V>
        + Copy,
{
    let offset = v - amin;
    let out_width = bmax - bmin;
    let scaled = offset * out_width;
    let in_width = amax - amin;
    bmin + scaled / in_width
}
#[test]
fn test_map_to_range() {
    // Integer arithmetic: 40 within [0, 100] lands on 4 within [0, 10].
    let scaled_down = map_to_range(40, 0, 100, 0, 10);
    assert_eq!(scaled_down, 4);

    // Both ranges start away from zero.
    let shifted = map_to_range(60, 50, 100, 5, 10);
    assert_eq!(shifted, 6);

    // Floating point: the midpoint of [55, 56] maps to the midpoint of [0, 1].
    let midpoint = map_to_range(55.5, 55., 56., 0., 1.);
    assert_eq!(midpoint, 0.5);
}
/// Occurrence counts per gram.
type GramsCounts = Grams<usize>;
/// NOTE(review): presumably meant to hold per-gram frequencies, but it is currently the same
/// type as `GramsCounts` — confirm whether a floating-point value type was intended.
type GramsFreqs = Grams<usize>;
fn gen_data_file(path: &Path) -> Result<GramsCounts> {
let data = std::fs::read_to_string(path)?;
let mut grams = Grams::default();
for win in data.as_bytes().windows(4) {
*grams.grams1.entry([win[0]]).or_insert(0) += 1;
*grams.grams2.entry([win[0], win[1]]).or_insert(0) += 1;
*grams.grams3.entry([win[0], win[1], win[2]]).or_insert(0) += 1;
*grams
.grams4
.entry([win[0], win[1], win[2], win[3]])
.or_insert(0) += 1;
}
// TODO: We lose a few N<4 grams here, but it's probably not that big of a deal
Ok(grams)
}
/// Walks every directory in `inputs` and merges the gram counts of each
/// regular file found into a single result.
///
/// # Errors
///
/// Stops at, and returns, the first error raised while walking a directory or
/// while processing a file.
fn gen_data(inputs: Vec<String>) -> Result<GramsCounts> {
    let mut grams = Grams::default();
    for dir in inputs {
        for entry in walkdir::WalkDir::new(dir) {
            let entry = entry?;
            // Entries that are not regular files contribute nothing.
            if !entry.file_type().is_file() {
                continue;
            }
            let file_grams = gen_data_file(entry.path())?;
            grams = grams.combine(file_grams);
        }
    }
    Ok(grams)
}
/// Entry point: treats every command-line argument as an input directory and
/// gathers gram counts from the files beneath them.
///
/// # Errors
///
/// Propagates any error returned by `gen_data`.
fn main() -> Result<()> {
    // `dirs` is never mutated, so it does not need `mut`.
    let dirs: Vec<String> = std::env::args().skip(1).collect();
    // The counts are not consumed yet; the underscore marks the binding as intentionally unused.
    let _grams = gen_data(dirs)?;
    Ok(())
}