atelier/klout/src/main.rs

445 lines
9.4 KiB
Rust

use core::f64;
use eyre::Result;
use rand::{rngs::ThreadRng, Rng};
use std::{
collections::HashMap,
io::{self, Write},
sync::mpsc::{Receiver, Sender},
};
/// Which hand strikes a key.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Hand {
    Left,
    Right,
}

/// Which finger of a hand strikes a key (only four digits are modelled; no thumbs).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Digit {
    Pinky,
    Ring,
    Middle,
    Index,
}

/// One specific finger: the hand plus the digit on that hand.
///
/// Equality is what identifies same-finger key pairs.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct Finger {
    hand: Hand,
    digit: Digit,
}
const MATRIX_COUNT: usize = 30;
macro_rules! F {
($h:ident, $d:ident) => {
Finger {
hand: Hand::$h,
digit: Digit::$d,
}
};
}
const LPINKY: Finger = F!(Left, Pinky);
const LRING: Finger = F!(Left, Ring);
const LMIDDLE: Finger = F!(Left, Middle);
const LINDEX: Finger = F!(Left, Index);
const RPINKY: Finger = F!(Right, Pinky);
const RRING: Finger = F!(Right, Ring);
const RMIDDLE: Finger = F!(Right, Middle);
const RINDEX: Finger = F!(Right, Index);
/// Effort cost of each key position, row-major: three rows of ten columns
/// (position `i` is column `i % 10`, row `i / 10`). `key_effort` multiplies a
/// character's frequency by the cost of the position it occupies, and the
/// search keeps the lowest total, so larger numbers mark less desirable keys.
#[rustfmt::skip]
const MATRIX_EFFORT: &[f64; MATRIX_COUNT] = &[
// row 0
100., 6., 10., 6., 40., 40., 10., 10., 10., 100.,
// row 1
50., 2., 0.1, 0.1, 15., 15., 0.1, 0.1, 2., 50.,
// row 2
10., 14., 10., 20., 40., 30., 6., 14., 6., 100.,
];
/// Finger that strikes each key position, in the same row-major order as
/// `MATRIX_EFFORT`. Columns 0-4 belong to the left hand, 5-9 to the right.
/// In row 2 the left-hand assignment is shifted one finger inward (ring,
/// middle, then index for columns 2-4). NOTE(review): this looks like an
/// "angle mod" finger assignment; confirm it is intentional.
#[rustfmt::skip]
const MATRIX_FINGERS: &[Finger; MATRIX_COUNT] = &[
LPINKY, LRING, LMIDDLE, LINDEX, LINDEX, RINDEX, RINDEX, RMIDDLE, RRING, RPINKY,
LPINKY, LRING, LMIDDLE, LINDEX, LINDEX, RINDEX, RINDEX, RMIDDLE, RRING, RPINKY,
LRING, LMIDDLE, LINDEX, LINDEX, LINDEX, RINDEX, RINDEX, RMIDDLE, RRING, RPINKY,
];
/// Every character placed by the optimizer, one per matrix position.
const CHARACTERS: &str = "abcdefghijklmnopqrstuvwxyz.,/'";
// Compile-time guarantee that the characters exactly fill the key matrix.
const _: () = assert!(CHARACTERS.len() == MATRIX_COUNT);

/// Length of a lookup table indexed by any byte of `CHARACTERS`: the largest
/// byte value present, plus one.
///
/// Uses a `while` loop because iterators are not available in `const fn`.
const fn char_data_max_index() -> usize {
    let bytes = CHARACTERS.as_bytes();
    let mut highest: u8 = 0;
    let mut idx = 0;
    while idx < bytes.len() {
        if bytes[idx] > highest {
            highest = bytes[idx];
        }
        idx += 1;
    }
    highest as usize + 1
}
/// Relative frequency of each letter, scaled by 1000 and indexed by the
/// character's byte value. Bytes not listed (including the punctuation in
/// `CHARACTERS`) stay at `0.0`.
///
/// Source: https://www3.nd.edu/~busiforc/handouts/cryptography/Letter%20Frequencies.html#Relative_frequencies_of_letters
const CHAR_FREQ_LOOKUP: [f64; char_data_max_index()] = {
    const LETTER_FREQS: [(u8, f64); 26] = [
        (b'a', 0.08167),
        (b'b', 0.01492),
        (b'c', 0.02782),
        (b'd', 0.04253),
        (b'e', 0.12702),
        (b'f', 0.02228),
        (b'g', 0.02015),
        (b'h', 0.06094),
        (b'i', 0.06966),
        (b'j', 0.00153),
        (b'k', 0.00772),
        (b'l', 0.04025),
        (b'm', 0.02406),
        (b'n', 0.06749),
        (b'o', 0.07507),
        (b'p', 0.01929),
        (b'q', 0.00095),
        (b'r', 0.05987),
        (b's', 0.06327),
        (b't', 0.09056),
        (b'u', 0.02758),
        (b'v', 0.00978),
        (b'w', 0.02360),
        (b'x', 0.00150),
        (b'y', 0.01974),
        (b'z', 0.00074),
    ];
    let mut table = [0.; char_data_max_index()];
    let mut i = 0;
    // `while` because `for` loops over iterators are not allowed in const evaluation.
    while i < LETTER_FREQS.len() {
        let (letter, freq) = LETTER_FREQS[i];
        table[letter as usize] = freq * 1000.;
        i += 1;
    }
    table
};
/// A keyboard layout stored as the inverse of the key matrix: for every
/// character byte, the matrix position (`0..MATRIX_COUNT`) it occupies.
#[derive(Clone, Copy)]
struct Layout {
    /// `key_indices[ch as usize]` is the matrix slot holding `ch`. Entries for
    /// bytes outside `CHARACTERS` carry no meaning.
    key_indices: [usize; char_data_max_index()],
}
impl Layout {
    /// Builds a layout from a row-major key matrix (three rows of ten columns).
    ///
    /// `m[i]` is the character at matrix position `i`; the result maps each of
    /// those characters back to `i`. Byte values that do not appear in `m` keep
    /// index 0, so `m` should contain every character of `CHARACTERS` exactly
    /// once. This is not checked.
    const fn from_key_matrix(m: &[char; MATRIX_COUNT]) -> Self {
        let mut key_indices = [0; char_data_max_index()];
        let mut i = 0;
        // `while` rather than `for`: iterators are unavailable in `const fn`.
        while i < m.len() {
            key_indices[m[i] as usize] = i;
            i += 1;
        }
        Self { key_indices }
    }

    /// Inverse of [`Layout::from_key_matrix`]: the character at each matrix position.
    ///
    /// Positions not claimed by any character of `CHARACTERS` are filled with `'0'`.
    fn to_key_matrix(self) -> [char; MATRIX_COUNT] {
        let mut out = ['0'; MATRIX_COUNT];
        for ch in CHARACTERS.chars() {
            out[self.key_indices[ch as usize]] = ch;
        }
        out
    }

    /// Column (`x`, 0..10) and row (`y`, 0..3) of `ch` in the key matrix.
    fn key_to_xy(&self, ch: char) -> (u8, u8) {
        let index = self.key_indices[ch as usize];
        ((index % 10) as u8, (index / 10) as u8)
    }

    /// Multiplicative penalty for unequal left/right hand load.
    ///
    /// Sums the letter frequencies placed in columns 0-4 (left hand) and 5-9
    /// (right hand) and returns `1 + 10 * |left - right|`.
    ///
    /// The `1 +` floor keeps this a true penalty (always >= 1). Without it, a
    /// nearly balanced layout got a factor close to 0. That erased its key and
    /// bigram costs once the caller multiplied them by this value, and exact
    /// balance scored zero effort.
    fn balance_penalty(&self) -> f64 {
        let km = self.to_key_matrix();
        // Column-major accumulation over the three rows, as before, so the
        // floating-point summation order is unchanged.
        let side_freq = |columns: std::ops::Range<usize>| -> f64 {
            let mut total = 0.;
            for x in columns {
                for y in 0..3 {
                    total += CHAR_FREQ_LOOKUP[km[y * 10 + x] as usize];
                }
            }
            total
        };
        let left_freq = side_freq(0..5);
        let right_freq = side_freq(5..10);
        1. + (left_freq - right_freq).abs() * 10.
    }
}
/// In `INITIAL_LAYOUT`, `g` sits in column 4 of row 1.
#[test]
fn test_key_to_xy() {
    let xy = INITIAL_LAYOUT.key_to_xy('g');
    assert_eq!(xy, (4, 1));
}
/// Cost of typing `ch` on its own in layout `l`: the character's frequency
/// weighted by the effort of the matrix position it occupies.
fn key_effort(ch: char, l: &Layout) -> f64 {
    let frequency = CHAR_FREQ_LOOKUP[ch as usize];
    let slot = l.key_indices[ch as usize];
    frequency * MATRIX_EFFORT[slot]
}
// Alternative starting layout, currently disabled. NOTE(review): it appears to
// be a Colemak-DH-style arrangement with `'` at the end of the top row; confirm.
// To search from it instead, swap it with the active `INITIAL_LAYOUT` below.
/*
#[rustfmt::skip]
const INITIAL_LAYOUT: Layout = Layout::from_key_matrix(&[
'q', 'w', 'f', 'p', 'b', 'j', 'l', 'u', 'y', '\'',
'a', 'r', 's', 't', 'g', 'm', 'n', 'e', 'i', 'o',
'x', 'c', 'd', 'v', 'z', 'k', 'h', ',', '.', '/',
]);
*/
/// Layout every worker starts its search from, also used as the aggregator's
/// initial best: QWERTY letter rows with `'` at the end of the middle row.
///
/// Must list each character of `CHARACTERS` exactly once, because
/// `from_key_matrix` does not check for duplicates or omissions.
#[rustfmt::skip]
const INITIAL_LAYOUT: Layout = Layout::from_key_matrix(&[
'q', 'w', 'e', 'r', 't', 'y', 'u', 'i', 'o', 'p',
'a', 's', 'd', 'f', 'g', 'h', 'j', 'k', 'l', '\'',
'z', 'x', 'c', 'v', 'b', 'n', 'm', ',', '.', '/',
]);
/// Bigram frequency table, keyed by the two-character bigram text.
type NGramFreqs = HashMap<&'static str, f64>;

/// Shared bigram frequency table, built once on first use by `bigram_freqs`.
static BIGRAM_FREQS: std::sync::OnceLock<NGramFreqs> = std::sync::OnceLock::new();

/// Returns the bigram frequency table: twenty bigrams ("th", "he", "in", ...)
/// with their relative frequencies scaled by 4000. Built on first call.
fn bigram_freqs() -> &'static NGramFreqs {
    BIGRAM_FREQS.get_or_init(|| {
        [
            ("th", 0.03882543),
            ("he", 0.03681),
            ("in", 0.02283899),
            ("er", 0.02178042),
            ("an", 0.02140460),
            ("re", 0.01749394),
            ("nd", 0.01571977),
            ("on", 0.01418244),
            ("en", 0.01383239),
            ("at", 0.01335523),
            ("ou", 0.01285485),
            ("ed", 0.01275779),
            ("ha", 0.01274742),
            ("to", 0.01169655),
            ("or", 0.01151094),
            ("it", 0.01134891),
            ("is", 0.01109877),
            ("hi", 0.01092302),
            ("es", 0.01092301),
            ("ng", 0.01053385),
        ]
        .into_iter()
        .map(|(s, v)| (s, v * 4000.))
        .collect()
    })
}

/// Builds the bigram table up front so the first effort evaluation does not
/// pay for it.
///
/// # Safety
///
/// There are no preconditions: the table lives in a `std::sync::OnceLock`, so
/// this may be called any number of times, from any thread, before or after
/// `bigram_effort`. The function stays `unsafe` only so its signature matches
/// existing `unsafe { ... }` call sites.
unsafe fn init_bigram_freqs() {
    let _ = bigram_freqs();
}

/// Effort of typing `bigram` (two ASCII characters) in layout `l`.
///
/// Starts from the bigram's scaled frequency, or `1.0` if it is not in the
/// table, then applies these rules:
/// - Both keys on the same finger: ×1000 when the rows are more than one
///   apart, ×500 otherwise.
/// - Different hands: ÷2.
/// - Same hand, different fingers, adjacent columns: ÷1000 on the same row,
///   ÷250 across rows. The divisor is a further ×10 when the motion goes
///   toward the centre (rightward on the left hand, leftward on the right).
/// - Any other same-hand pair: ×(column distance + 1) × (row distance + 1).
///
/// The adjacent-column discount is deliberately withheld from same-finger
/// pairs. Previously two neighbouring index-finger columns (e.g. `f`→`g` on
/// QWERTY) got the discount and ended up cheaper than alternating hands.
///
/// # Panics
///
/// If `bigram` is shorter than two bytes.
fn bigram_effort(bigram: &str, l: &Layout) -> f64 {
    let bytes = bigram.as_bytes();
    let (ch1, ch2) = (bytes[0], bytes[1]);
    let finger1 = MATRIX_FINGERS[l.key_indices[ch1 as usize]];
    let finger2 = MATRIX_FINGERS[l.key_indices[ch2 as usize]];
    let (x1, y1) = l.key_to_xy(ch1 as char);
    let (x2, y2) = l.key_to_xy(ch2 as char);
    let x_diff = x1.abs_diff(x2);
    let y_diff = y1.abs_diff(y2);
    let same_finger = finger1 == finger2;

    let mut eff = bigram_freqs().get(bigram).copied().unwrap_or(1.);
    // NOTE(review): a repeated key (e.g. "ee") also counts as same-finger and
    // is charged ×500 here; confirm that is intended.
    if same_finger {
        eff *= if y_diff > 1 { 1000. } else { 500. };
    }
    if finger1.hand != finger2.hand {
        eff /= 2.;
    } else if x_diff == 1 && !same_finger {
        let toward_centre = match finger1.hand {
            Hand::Left => x1 < x2,
            Hand::Right => x1 > x2,
        };
        let mult = if toward_centre { 10. } else { 1. };
        let base = if y_diff == 0 { 1000. } else { 250. };
        eff /= base * mult;
    } else {
        eff *= f64::from(x_diff + 1);
        eff *= f64::from(y_diff + 1);
    }
    eff
}
/// Every ordered pair of characters from `CHARACTERS`, built once on first use.
static ALL_BIGRAMS: std::sync::OnceLock<Vec<String>> = std::sync::OnceLock::new();

/// Builds the bigram list up front so the first effort evaluation does not
/// pay for it.
///
/// # Safety
///
/// There are no preconditions: the list lives in a `std::sync::OnceLock`, so
/// this may be called any number of times, from any thread, before or after
/// `all_bigrams`. The function stays `unsafe` only so its signature matches
/// existing `unsafe { ... }` call sites.
unsafe fn init_all_bigrams() {
    let _ = all_bigrams();
}

/// All `CHARACTERS.len()²` two-character strings, first character major:
/// `"aa"`, `"ab"`, ..., `"''"`.
///
/// Built lazily on first call. The previous `static mut` + `unwrap_unchecked`
/// version was undefined behaviour if called before `init_all_bigrams`.
fn all_bigrams() -> &'static [String] {
    ALL_BIGRAMS.get_or_init(|| {
        CHARACTERS
            .chars()
            .flat_map(|first| {
                CHARACTERS
                    .chars()
                    .map(move |second| String::from_iter([first, second]))
            })
            .collect()
    })
}
/// Number of independent search threads spawned by `main`.
const N_WORKERS: usize = 6;

/// Applies between 1 and `max_swaps` (inclusive) random swaps to `layout`.
///
/// Each swap exchanges the matrix positions of two characters drawn uniformly
/// from `CHARACTERS`. Drawing the same character twice makes that swap a no-op.
/// `gen_range` panics on an empty range, so `max_swaps` must be at least 1.
fn mutate_layout(layout: &mut Layout, rng: &mut ThreadRng, max_swaps: usize) {
    let charset = CHARACTERS.as_bytes();
    let swap_count = rng.gen_range(1..=max_swaps);
    (0..swap_count).for_each(|_| {
        let a = charset[rng.gen_range(0..charset.len())];
        let b = charset[rng.gen_range(0..charset.len())];
        layout.key_indices.swap(usize::from(a), usize::from(b));
    });
}
/// Configuration handed to each search thread.
struct WorkerThreadInit {
    /// Layout the search starts from.
    initial_layout: Layout,
    /// Upper bound on the random swaps applied between consecutive evaluations.
    max_swaps: usize,
    /// How far (in effort units) the current layout may rise above the best
    /// effort found so far before the search jumps back to the best layout.
    max_dev: f64,
    /// Channel on which each new personal-best `(effort, layout)` is reported.
    report: Sender<(f64, Layout)>,
}
fn worker_thread_run(init: WorkerThreadInit) {
let mut rng = rand::thread_rng();
let mut layout = init.initial_layout;
let mut best_layout = layout;
let mut min_effort: f64 = f64::MAX;
loop {
let mut eff: f64 =
CHARACTERS.chars().map(|v| key_effort(v, &layout)).sum();
eff += all_bigrams()
.iter()
.map(|b| bigram_effort(b, &layout))
.sum::<f64>();
eff *= layout.balance_penalty();
if eff < min_effort {
min_effort = eff;
best_layout = layout;
init.report.send((eff, layout)).unwrap();
} else if (eff - min_effort) > init.max_dev {
layout = best_layout;
}
mutate_layout(&mut layout, &mut rng, init.max_swaps);
}
}
/// Prints `layout` to stdout as three rows of characters, one row per line.
fn print_layout(layout: Layout) {
    let matrix = layout.to_key_matrix();
    let row_len = MATRIX_COUNT / 3;
    for row in matrix.chunks(row_len) {
        let line: String = row.iter().collect();
        println!("{line}");
    }
}
/// Clears the terminal and moves the cursor to the top-left corner using ANSI
/// escape sequences, then flushes stdout so the clear takes effect immediately.
fn clear_screen() {
    let mut stdout = io::stdout();
    print!("\x1B[2J\x1B[H");
    stdout.flush().unwrap();
}
/// Redraws the screen for every `(effort, layout)` received: the layout, a
/// blank line, then the effort. Returns once every sender has been dropped.
fn ui_thread_run(rx: Receiver<(f64, Layout)>) {
    while let Ok((effort, layout)) = rx.recv() {
        clear_screen();
        print_layout(layout);
        println!();
        println!("{effort}");
    }
}
/// Merges results from all workers and forwards a result to `tx` only when it
/// strictly beats every earlier one.
///
/// `layout` seeds the initial best but is never itself forwarded. Returns when
/// every sender feeding `rx` has been dropped. Panics if `tx`'s receiver is
/// gone when an improvement arrives.
fn aggregator_thread_run(
    layout: Layout,
    rx: Receiver<(f64, Layout)>,
    tx: Sender<(f64, Layout)>,
) {
    let mut best: (f64, Layout) = (f64::MAX, layout);
    for candidate in &rx {
        if candidate.0 < best.0 {
            best = candidate;
            tx.send(best).unwrap();
        }
    }
}
/// Starts the UI thread, the aggregator thread and `N_WORKERS` search threads,
/// then waits on them.
///
/// Data flow: workers -> aggregator (global best only) -> UI. Under normal
/// operation the workers never stop, so this runs until the process is killed.
fn main() -> Result<()> {
    // SAFETY: called once, on the main thread, before any other thread is
    // spawned, so no reader can observe the bigram globals mid-initialization.
    unsafe {
        init_all_bigrams();
        init_bigram_freqs();
    }

    let (tx_agg, rx_agg) = std::sync::mpsc::channel();
    let (tx_ui, rx_ui) = std::sync::mpsc::channel();

    let ui_thread = std::thread::spawn(move || ui_thread_run(rx_ui));
    let aggregator_thread =
        std::thread::spawn(move || aggregator_thread_run(INITIAL_LAYOUT, rx_agg, tx_ui));

    let worker_threads: Vec<_> = (0..N_WORKERS)
        .map(|_| {
            let report = tx_agg.clone();
            std::thread::spawn(move || {
                worker_thread_run(WorkerThreadInit {
                    initial_layout: INITIAL_LAYOUT,
                    report,
                    max_dev: 1000.,
                    max_swaps: 10,
                })
            })
        })
        .collect();
    // Only the workers should keep the aggregator's channel open. Holding this
    // sender here meant that, if every worker died, the aggregator (and then
    // the UI and this function) would block forever instead of shutting down
    // and reporting the worker panic below.
    drop(tx_agg);

    ui_thread.join().expect("UI thread panicked");
    aggregator_thread.join().expect("aggregator thread panicked");
    for worker in worker_threads {
        worker.join().expect("worker thread panicked");
    }
    Ok(())
}