Major day 2 optimization

bram-benchmarks
Bram 2025-12-03 23:59:56 +01:00
parent 94201f6868
commit 32242a43b9
2 changed files with 180 additions and 37 deletions

View File

@ -1,50 +1,189 @@
use crate::utils;
use std::ops::RangeInclusive;
use std::collections::HashMap;
pub fn answer(text : String) -> ( u64, u64 ) {
( text.split(",").flat_map(to_range).filter(|x| is_invalid_by(x, 2)).sum()
, text.split(",").flat_map(to_range).filter(is_invalid).sum()
)
use crate::utils;
pub fn answer(text : String) -> ( u128, u128 ) {
text.trim()
.split(",")
.flat_map(Range::from)
.map(|range| range.invalid_id_sum())
.fold((0, 0), |(x1, y1), (x2, y2)| (x1 + x2, y1 + y2))
}
fn is_invalid(n : &u64) -> bool {
for b in 2..=10 {
if is_invalid_by(n, b) {
return true;
struct Range {
low : String,
high : String,
}
impl Range {
fn from(s : &str) -> Vec<Range> {
let mut ranges : Vec<Range> = Vec::new();
let (low, high) = match s.split_once("-") {
Some(pair) => pair,
None => return Vec::new(),
};
// Sanitize input - for AoC, optional
if low.len() > high.len() { return Vec::new() }
if low.bytes().any(|c| c < b'0' || c > b'9') { return Vec::new() }
if high.bytes().any(|c| c < b'0' || c > b'9') { return Vec::new() }
for size in low.len()..=high.len() {
let range = Range::new(
if low.len() == size {
String::from(low)
} else {
std::iter::once('1')
.chain(std::iter::repeat('0').take(size - 1))
.collect()
},
if high.len() == size {
String::from(high)
} else {
std::iter::repeat('9').take(size).collect()
},
);
ranges.push(range);
}
ranges
}
false
}
fn has_pattern(&self, s : &str) -> bool {
assert!(s.len() > 0);
assert!(s.len() < self.len());
assert_eq!(self.len() % s.len(), 0);
fn is_invalid_by(n : &u64, by : usize) -> bool {
let s : String = n.to_string();
let size : usize = s.len() / by;
for (l, c) in self.low.chars().zip(s.chars().cycle()) {
if l < c {
break;
} else if l > c {
return false;
}
}
if s.len() % by == 0 {
for i in 1..by {
if s[0..size] != s[i * size..(i + 1) * size] {
for (h, c) in self.high.chars().zip(s.chars().cycle()) {
if c < h {
break;
} else if c > h {
return false;
}
}
true
} else {
false
}
fn invalid_id_sum(&self) -> ( u128, u128 ) {
// Part 1
( if self.len() % 2 == 0 {
self.total_of_pattern_size(self.len() / 2)
} else {
0
}
// Part 2
, self.pattern_hashmap().values().sum()
)
}
fn len(&self) -> usize {
self.low.len()
}
fn new(low : String, high : String) -> Range {
// Verify they're the same length
assert_eq!(low.len(), high.len());
// Verify they're all 0-9
assert!(low.chars().all(|c| '0' <= c && c <= '9'));
assert!(high.chars().all(|c| '0' <= c && c <= '9'));
// Verify they're in order
for (l, h) in low.chars().zip(high.chars()) {
assert!(l <= h);
if l < h {
break;
}
}
Range{ low : low, high : high }
}
// Create a HashMap that counts the number of patterns for each value
fn pattern_hashmap(&self) -> HashMap<usize, u128> {
let mut h : HashMap<usize, u128> = HashMap::new();
// Populate the HashMap
for size in 1..=(self.len() / 2) {
if self.len() % size == 0 {
let mut patterns : u128 = self.total_of_pattern_size(size);
// We'll double count a few patterns.
// For example,
//
// 247247 247247
// |------|------|
//
// this number is a pattern of 6, but it is also a pattern of 3! (lol)
// 6-patterns always catch ALL 3-patterns. In fact, all
// k-patterns are always caught by n-patterns iff k | n.
// Therefore, if we want to accurately count the number of
// 6-patterns that AREN'T already counted as 1-patterns,
// 2-patterns or 3-patterns, we must subtract our total by
// those amounts.
for (key, value) in h.iter() {
if size % key == 0 {
patterns -= value;
}
}
h.insert(size, patterns);
}
}
h
}
// Count all invalid IDs where the pattern is of a given size.
// Then count the numbers of those chunks.
// These can then be re-used to count the full value efficiently.
fn total_of_pattern_size(&self, size : usize) -> u128 {
// 1. Find the range of all patterns of size `size`
let inf_str: &str = self.low.get(0..size).unwrap();
let sup_str: &str = self.high.get(0..size).unwrap();
let inf : u128 = if self.has_pattern(inf_str) {
utils::str_to_u128(inf_str).unwrap()
} else {
utils::str_to_u128(inf_str).unwrap() + 1
};
let sup : u128 = if self.has_pattern(sup_str) {
utils::str_to_u128(sup_str).unwrap()
} else {
utils::str_to_u128(sup_str).unwrap() - 1
};
// No patterns exist!
if inf > sup {
return 0;
}
// 2. Calculate the sum of the pattern snippets
let pattern_sum : u128 = ((sup - inf + 1) * (inf + sup)) / 2;
// 3. Calculate the sum of the patterns when they repeat themselves.
// For example:
// Step 2 calculated the sum of x = 123 + 124 + 125
// Step 3 calculates the sum of
// y = 123123123 + 124124124 + 125125125
// by realizing that:
// y = 1001001 * (123 + 124 + 125) = 100100100 * x
// = (10^6 + 10^3 + 10^0) * x
let layers : usize = self.len() / size;
let factor : u128 = (0..layers).map(|layer| 10u128.pow((size * layer) as u32)).sum();
factor * pattern_sum
}
}
fn to_range(s : &str) -> RangeInclusive<u64> {
let empty = 1..=0;
if let Some((low, high)) = s.split_once("-") {
match ( utils::str_to_u64(low), utils::str_to_u64(high) ) {
( Some(l), Some(h) ) =>
l..=h,
_ =>
// Empty iterator
empty,
}
} else {
empty
}
}

View File

@ -2,8 +2,8 @@ use std::fs;
pub mod diagnostics;
pub fn char_to_u8(s : char) -> Option<u8> {
s.to_string().parse::<u8>().ok()
pub fn char_to_u8(c : char) -> Option<u8> {
c.to_digit(10).and_then(|n| u8::try_from(n).ok())
}
// pub fn char_to_u64(s : char) -> Option<u64> {
@ -43,3 +43,7 @@ pub fn str_to_i16(s : &str) -> Option<i16> {
pub fn str_to_u64(s : &str) -> Option<u64> {
s.trim().to_string().parse::<u64>().ok()
}
pub fn str_to_u128(s : &str) -> Option<u128> {
s.trim().to_string().parse::<u128>().ok()
}