§Weightlessly Build a Tree in Const Rust
2025-01-08, Daniel Pfeiffer
Life ain’t fair. Reality may be that you get more apples than oranges, and even fewer peaches, when you really wanted cherries. So, when life gives you lemons, plant a lemon tree!
For when you need to generate such uneven outcomes, the rand crate has its weighted methods. But they report O(n) time. Even CS students know that lookup should only be O(log n). The specialised choose_rand crate doesn’t do better.
To reduce this weighty processing, be prepared! Build a complete binary tree in advance, rather than going through your data each time. And for static data, there’s no better time to do this than compile time. That’s the challenge I gave myself. POV: I managed quite well.
Let’s say we want to return the letters of the alphabet by their actual frequency in some language. E.g. in English ‘E’ would be by far the most frequent result. For illustration, let’s imagine a wooden language that has only a few letters and bizarre frequencies. Assembling random three-letter words should give ‘CAB’ more often than ‘FED’, and rarely ‘JIG’.
```rust
weights! {
    49 => 'A',
    45 => 'B',
    35 => 'C',
    31 => 'D',
    27 => 'E',
    21 => 'F',
    17 => 'G',
    14 => 'H',
    10 => 'I',
    7 => 'J',
}
```
From this seed, let’s plant a tree.
Fun fact: the German ‘Buchstaben’ (i.e. ‘letters’) literally means ‘beech sticks’, as those were used to lay out ᚱᚢᚾᛖᛉ, I mean runes. ‘Buch’ and its cognate ‘book’ are derived from that. I digress…
Now, that tree on its own doesn’t buy us anything for easy lookup. So let’s traverse it depth first from left to right, i.e. in order: left subtree, then the node itself, then right subtree. We’ll add up the weights as we go and instead give each node a range from an inclusive lower to an exclusive upper bound. This is the opposite of how binary trees are usually built. Instead of sorting values into it, we project our input values across the whole range of `u8` to achieve this sorting.
We can now quickly find any random `u8` in our tree. This maps nicely to our numbers, except that the upper bound for 'G' (256) is not storable in a `u8`. We must prune our tree: if we made each upper bound inclusive, i.e. one less, that would work – for integers. For floats it would not, because `0.0 .. 1.0` is typically the float random range. And even if the range were bigger, we can’t subtract one without leaving gaps between the nodes: our floaty continuum would be interrupted. So we need to subtract one only for integer types. Technically, for floats the upper bound should then stay exclusive, but the difference is so tiny that it doesn’t matter.
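Here are the numbers behind this, as a small Rust sketch. The names `IN_ORDER`, `exclusive_ranges` and `inclusive_ranges` are made up for illustration.

Assume the example tree is filled row by row: 'A' at the root, then 'B' and 'C', then 'D' to 'G', then 'H' to 'J'. A depth-first, left-to-right walk then visits the letters as H, D, I, B, J, E, A, F, C, G. That order is hard-coded below rather than computed.

Accumulating in `u16` shows two things:
- The exclusive upper bound of 'G' is 256.
- The inclusive bounds, one less each, all fit in a `u8`.

```rust
/// Depth-first, left-to-right (in-order) sequence of the example tree as
/// (weight, letter), hard-coded here from the assumed tree shape.
const IN_ORDER: [(u16, char); 10] = [
    (14, 'H'), (31, 'D'), (10, 'I'), (45, 'B'), (7, 'J'),
    (27, 'E'), (49, 'A'), (21, 'F'), (35, 'C'), (17, 'G'),
];

/// (inclusive lower, exclusive upper, letter), accumulated in u16 so that 256 fits.
fn exclusive_ranges() -> Vec<(u16, u16, char)> {
    let mut acc = 0u16;
    IN_ORDER
        .iter()
        .map(|&(weight, letter)| {
            let lower = acc;
            acc += weight;
            (lower, acc, letter)
        })
        .collect()
}

/// The same ranges with inclusive upper bounds (one less), which all fit in a u8.
fn inclusive_ranges() -> Vec<(u8, u8, char)> {
    exclusive_ranges()
        .into_iter()
        .map(|(lower, upper, letter)| (lower as u8, (upper - 1) as u8, letter))
        .collect()
}
```

With inclusive bounds, each node’s lower bound is exactly one more than its predecessor’s upper bound, so no integer is left out.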
§Sawing Through Rust
Now, trees with references and maybe heap-based nodes are hard in Rust. In Constland such trees can’t grow at all. Luckily, binary trees can easily be flattened into arrays. Thanks to the regularity of the tree, which doubles in width at every level, it’s really easy to navigate as an array. We can read it row by row, like the lines of a book, where paragraphs virtually form one long line.
```rust
const fn parent(i: usize) -> usize {
    (i - 1) / 2
}
const fn left_child(i: usize) -> usize {
    2 * i + 1
}
const fn right_child(i: usize) -> usize {
    2 * i + 2
}
```
Here you see the flattened tree, with the upper bounds adapted. Arcs indicate where the rows have been crimped together.
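As a quick sanity check of the index arithmetic, the sketch below repeats the three helpers. It stores the example letters row by row in a hypothetical `LETTERS` array, so the root is 'A', its children are 'B' and 'C', and so on. The `children` function is likewise only for illustration. It uses `get`, so missing children come back as `None` instead of panicking.

```rust
const fn parent(i: usize) -> usize {
    (i - 1) / 2
}
const fn left_child(i: usize) -> usize {
    2 * i + 1
}
const fn right_child(i: usize) -> usize {
    2 * i + 2
}

/// The example tree flattened row by row: A | B C | D E F G | H I J
const LETTERS: [char; 10] = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'];

/// Left and right child of the node at index `i`, if they exist in the array.
fn children(i: usize) -> (Option<char>, Option<char>) {
    (
        LETTERS.get(left_child(i)).copied(),
        LETTERS.get(right_child(i)).copied(),
    )
}
```

Note that 'E' ends up with only a left child, 'J'. That is why separate guards for left and right children come in handy later.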
Down to business! Processing at compile time requires `const`. Until 1.79 rolled out, my converse assumption was that `const` also implies compile time. After all, the result should always be the same. So, parameters permitting, optimise it out! As so often, some annoying edge case must have prevented that intuitive automatism. But at least `const {}` now makes it possible, even though it introduced a surprising semantic difference to code in a `const fn`. More weirdly, what the block produces isn’t `const`. You can modify inside the object, if you `let mut` it. Misleading keyword chosen…
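To illustrate both points, here is a minimal sketch with a made-up `square` function:
- A plain call to a `const fn` may or may not be folded by the optimiser.
- An inline `const {}` block (stable since Rust 1.79) is guaranteed to be evaluated at compile time.
- The block’s result is still an ordinary value, which `let mut` happily lets us change.

```rust
const fn square(x: u32) -> u32 {
    x * x
}

/// Hypothetical demo: returns (plain call result, modified inline-const result).
fn demo() -> (u32, u32) {
    // An ordinary call: the optimiser may fold it, but nothing forces compile-time evaluation.
    let maybe_folded = square(12);
    // An inline const block must be evaluated at compile time.
    let mut forced = const { square(12) };
    // Yet its result is just a value: with `let mut` it can be modified afterwards.
    forced += 1;
    (maybe_folded, forced)
}
```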
Rust’s type system, powerful as it may be, is nonetheless rather limiting in what it can express. Our generic weight type `W` must be numeric. So it has `W::MIN` and `W::MAX`, which we need. But while every numeric primitive has these constants, there is no way to express that a generic type must expose them, so they can’t be read in generic code. Also, there’s no way of saying that our generic type is numeric, so we can’t use literals. Even if we could, for floats we’d have to write `0.0`, else `0`. Luckily we can get zero from `Default::default()`. Other numbers need something more substantial. We could get one by dividing any of our non-zero inputs by itself, and then work our way up from that – I hear the tree groaning in the wind.
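A minimal sketch of that workaround, with a hypothetical `zero_and_one` helper: zero comes from `Default`, and one from dividing a non-zero input by itself. There are two caveats:
- This only works at run time, since `Default::default()` and `/` on a generic type are trait methods, which can’t be called in a `const fn` on stable Rust.
- The caller must guarantee a non-zero input, because integer division by zero panics.

```rust
use std::ops::Div;

/// Hypothetical helper: obtains zero and one for a numeric `W` without any literals.
/// `nonzero` must not be zero.
fn zero_and_one<W: Default + Copy + Div<Output = W>>(nonzero: W) -> (W, W) {
    // Default::default() is 0 for all integer types and 0.0 for floats.
    let zero = W::default();
    // Any non-zero value divided by itself is one.
    let one = nonzero / nonzero;
    (zero, one)
}
```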
Independently of whether we can restrict the range of `W` to `W::MIN ..= W::MAX` or less, it could be useful to later have access to the chosen limits. As the defaults will be widely used, i.e. mostly the same, they might best be part of the type, rather than taking up memory in the `struct`. But `<W, const MIN: W, const MAX: W>` ruffles the feathers of Rust: “the type of const parameters must not depend on other generic parameters.” Putting that on the `impl` instead says: “the const parameter `MIN` is not constrained by the impl trait, self type, or predicates.” I don’t have a `trait`, and we were just barred from putting it on the self type. As for predicates, and what the notes say (“must map each value to a distinct output value”), I have no clue what that’s trying to get from me. I really appreciate Rust going to great lengths to prevent accidents. But often it’s hard to guess what might have blocked a seemingly useful way of doing things. Then it just feels arbitrary.
Thanks to operator overloading, we can do arithmetic in generics. OTOH, anything implemented directly on the primitive types, as inherent methods rather than trait methods, is not available. And that eliminates a lot of stuff, like `checked_add()` or `overflowing_add()`. So, colliding head on with generics makes non-trait impls a broken paradigm, expressively poorer than inheritance. Hopefully a way can be found to align the two worlds, rather than excluding maybe half the Rust library from generic code!
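The contrast fits in a few lines; `sum` and `checked_sum_u8` are illustrative names.
- The `+` in `sum` resolves through the `Add` trait bound.
- `checked_add` only exists as an inherent method of each concrete integer type, so there is no standard-library bound we could put on `W` to reach it.

```rust
use std::ops::Add;

/// Works in generic code: `+` is the `Add` trait, so a bound makes it available.
fn sum<W: Add<Output = W> + Copy>(a: W, b: W) -> W {
    a + b
}

/// `checked_add` is an inherent method of each integer type, not a trait method,
/// so it can only be called once the concrete type is known.
fn checked_sum_u8(a: u8, b: u8) -> Option<u8> {
    a.checked_add(b)
}
```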
Next shocker: operator overloading, even for primitive types, means implicitly going through traits. And those still do not allow `const` implementations. There are two related compiler milestones which were scheduled for 2024, but didn’t make it. Even once that becomes possible, it might take ages to make all those methods first unstably `const` and later stably so. For now, generic `const fn`s are like your neighbour’s huge tree: wonderful in principle, while in practice blocking out your sunlight. What else can we do?
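Concretely, in this minimal sketch `add_u8` compiles as a `const fn` because `u8 + u8` is handled by the compiler itself. The generic variant is kept as a comment because it doesn’t compile on stable Rust: it would need the non-const trait method `Add::add`.

```rust
// Arithmetic on a concrete primitive type is fine in a const fn...
const fn add_u8(a: u8, b: u8) -> u8 {
    a + b
}

// ...whereas the generic equivalent is rejected on stable Rust, because `+`
// would have to call the non-const trait method `Add::add`:
// const fn add<W: std::ops::Add<Output = W>>(a: W, b: W) -> W { a + b }

/// Evaluated entirely at compile time.
const FULL: u8 = add_u8(200, 55);
```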
Rust isn’t yet capable of overloading by parameter type, except for forms of `self` as the first parameter. However, we don’t own the primitive types, so we can’t just add `impl u8 { const fn weights(self, …) }`. Newtype to the rescue: a wrapper around each of the numeric primitives does allow this. Then, passing in one of the user-provided inputs for later type inference, we can write `let x = 0; Wrapper(x).weights(x, 1…)` and `let y = 0.0; Wrapper(y).weights(y, 1.0…)` – or so I thought. This worked nicely as long as I had only one integer and one float wrapper. As soon as I added more types, the compiler complained about ambiguity, even forgetting that it had known how to distinguish at least integers from floats.
The reason is that type inference has a weird gap. In `let x = 256`, the compiler can look into the future, i.e. at how `x` is used later, and know the exact type, complaining “literal out of range for `u8`.” Likewise, in `let x = 1; -x` the compiler can complain “the trait `Neg` is not implemented for `u8`.” Conversely, given a signed type, the `-` operator implicitly calls `Neg::neg()` just fine. An explicit call like `x.checked_add()`, however, is deemed ambiguous – the compiler suddenly forgets the type again. For explicit method calls, the type must annoyingly be established before the call. So this would have worked: `Wrapper(0_u8).weights(0, 1…)`. But then we lose the advantage of inferred types.
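Here is a stripped-down sketch of the newtype trick, with one inherent impl per wrapped primitive. The method `delta` is hypothetical and stands in for `weights`. It returns the gap between adjacent ranges: 1 for integers, 0.0 for floats.

Calling it requires a literal suffix, e.g. `Wrapper(0_u8).delta()`. That is exactly the explicitness we’d rather avoid. With both a `u8` and a `u16` impl present, an unsuffixed `let x = 0; Wrapper(x).delta()` is the kind of call that gets rejected as ambiguous.

```rust
/// Hypothetical newtype: since we can't add inherent methods to u8 or f32,
/// we wrap them and add one inherent impl per primitive type.
struct Wrapper<T>(T);

impl Wrapper<u8> {
    /// Integers: adjacent ranges are one apart.
    const fn delta(self) -> u8 {
        1
    }
}

impl Wrapper<u16> {
    const fn delta(self) -> u16 {
        1
    }
}

impl Wrapper<f32> {
    /// Floats form a continuum, so there is no gap between ranges.
    const fn delta(self) -> f32 {
        0.0
    }
}
```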
I wish for something like `<T: {number}>` or `<T: {float}>`. And it should then know to access `T::MAX`, whatever it takes to get there. Alternatively or additionally, in the future we’ll hopefully be able to say `<type_of(x)>::MAX`. The num crate to the rescue, I hear you whisper in my ear. Alas, they don’t dare to take the leap of faith to a stable version 1.0. And even if they did, they operate by trait methods – so not for `const`, yada, yada, yada.
Originally I had an `enum Way` to tell the traversal which step to take next. Even though it’s just a plain `enum`, for `==` we must derive `PartialEq` on it. But then, even though its `repr` is a primitive type, comparison falls back to that `trait` method – and we fall off the tree again, because of `const`. At least this time we can get lucky, because pattern matching (the same comparison, duh) does work. And there’s a nice wrapper, `matches!()`, less readable than `==`, but at least it gets the job done. On second thought, by pulling the initial deep descent out of the loop, we can eliminate the `enum`. All further iterations can be told apart by a single condition. While our tree may shake in the storm, we won’t be thrown off so easily!
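Here is a sketch of the `matches!` workaround. The variants `Left` and `Right` and the function `is_left` are hypothetical, since the original `enum Way` isn’t shown. Comparing with `==` would go through `PartialEq::eq` and be rejected inside a `const fn` on stable Rust. The pattern match compiles, and can even initialise a `const`.

```rust
/// Hypothetical variants: the real `Way` told the traversal where to go next.
#[derive(Clone, Copy, PartialEq)]
enum Way {
    Left,
    Right,
}

/// `way == Way::Left` would call `PartialEq::eq`, which can't be called
/// in a const fn on stable Rust. Pattern matching is allowed, though.
const fn is_left(way: Way) -> bool {
    matches!(way, Way::Left)
}

/// Evaluated at compile time.
const GOES_LEFT: bool = is_left(Way::Left);
```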
And in hindsight, when testing for the minimum supported Rust version, it turns out I naturally used features that only just became stable in 1.83. That’s a happy coincidence, but at the same time limiting for aspiring users.
§The Solution
I find the crate name `weight_matchers` catchy. However, I’m still struggling with internal naming, given that `match` is a keyword. Anyway, assuming well-behaved values, the only thing that must happen at run time is very easy and O(log n). Even better: the more weight is concentrated near the root of the tree, the more this tends towards O(1). For now there is no node type, but it might be introduced for better readability. Or, if the optimiser doesn’t do so already, the bounds might be put into SIMD arrays, for faster lookup.
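To put a number on “tends towards O(1)”, assume the random `u8` is uniformly distributed. Then each node is hit with probability weight/256. Reaching a node on row r (the root being row 0) takes r + 1 node visits.

The hypothetical `weighted_visits` below sums weight × (r + 1) over the example in its flattened order:
- As given, the sum is 621, i.e. about 2.43 visits per lookup on average, versus 4 for the deepest row.
- With the same weights in increasing order, the sum would be 859, i.e. about 3.36.

```rust
/// Example weights in flattened (row-by-row) order: index 0 is the root 'A'.
const WEIGHTS: [u32; 10] = [49, 45, 35, 31, 27, 21, 17, 14, 10, 7];

/// Sum of weight × (row + 1) over all nodes, where row = ilog2(index + 1).
/// Dividing by the total weight gives the expected number of nodes visited per lookup.
fn weighted_visits(weights: &[u32]) -> u32 {
    weights
        .iter()
        .enumerate()
        .map(|(i, &w)| w * ((i as u32 + 1).ilog2() + 1))
        .sum()
}
```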
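```rust
pub struct Weights<W, V, const LEN: usize> {
    // (lower, upper, value) per node, flattened row by row
    nodes: [(W, W, V); LEN],
}

impl<W: PartialOrd, V, const LEN: usize> Weights<W, V, LEN> {
    pub fn get(&self, matching: W) -> &V {
        // start comparing at the root
        let mut i = 0;
        loop {
            let node = &self.nodes[i];
            if matching < node.0 {
                // below this node's range: continue in the left subtree
                i = left_child(i);
            } else if matching > node.1 {
                // above this node's range: continue in the right subtree
                i = right_child(i);
            } else {
                // within lower ..= upper: found
                return &node.2;
            }
        }
    }
}
```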
There is also a corresponding `get_mut`. Then there are `get_clamped` and `get_clamped_mut`, with bounds checks, for when you’re not sure of input validity. These variants give the nearest value, whereas the unclamped variants have only the normal array bounds checks, which panic.
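How such clamping could work is sketched below as a free function. This is a guess, not the crate’s actual `get_clamped`.

Instead of indexing past the end of the array, the descent stops at the current node when the child it would need doesn’t exist:
- An input below the overall minimum keeps going left, so it ends at the node with the lowest range.
- An input above the maximum ends at the node with the highest range.

The sketch assumes a non-empty array.

```rust
const fn left_child(i: usize) -> usize {
    2 * i + 1
}
const fn right_child(i: usize) -> usize {
    2 * i + 2
}

/// Hypothetical clamped lookup over (lower, upper, value) nodes in flattened order.
/// When the needed child doesn't exist, the current node is the nearest match.
fn get_clamped<W: PartialOrd, V>(nodes: &[(W, W, V)], matching: W) -> &V {
    let mut i = 0;
    loop {
        let node = &nodes[i];
        let next = if matching < node.0 {
            left_child(i)
        } else if matching > node.1 {
            right_child(i)
        } else {
            return &node.2;
        };
        if next >= nodes.len() {
            // no such child: stop here instead of panicking on the array index
            return &node.2;
        }
        i = next;
    }
}
```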
Now that we know how we want to get something out, how do we put it in? In our case of a complete tree, there isn’t even a jagged bottom edge. Instead of holes, we have a clean end after the last node. Given the number of nodes, which is the length of the array, we can pre-calculate guards: the node index from which on there are no left, respectively no right, children. Even depth-first traversal becomes an easy iteration. Let’s transform the first tree above into the second. The whole tree must be given in one go. To avoid reallocation, we take an array `nodes` of `(0, weight, T)`, where the `0` is any initial placeholder, and turn each entry in place into `(lower, upper, T)`, spread across `min ..= max`.
```rust
// LEN is the number of nodes, i.e. the array length
const NO_LEFT_CHILD: usize = LEN / 2;
const NO_RIGHT_CHILD: usize = if LEN % 2 == 1 { LEN / 2 } else { LEN / 2 - 1 };

/// For any node index, descend to the left child as many times as possible,
/// maybe 0 times, i.e. the same index is returned.
const fn deep_left(mut i: usize) -> usize {
    while i < NO_LEFT_CHILD {
        i = left_child(i);
    }
    i
}

/// For any index (modified in place), go to the next index in depth-first order.
/// Return false after the last branch.
/// To traverse the whole tree, start at the deepest left-most node.
/// The strategy is: if there is a right child, go to it and then as far down left as possible.
/// Else go up as many parents as this is the right child of, and then one parent of which it is now the left child.
/// If at the end there was no such parent, it must be the root after the last branch.
const fn traverse(i: &mut usize) -> bool {
    if *i < NO_RIGHT_CHILD {
        // descend one rightwards and as many as possible leftwards
        *i = deep_left(right_child(*i))
    } else {
        // rise as many as possible leftwards and, if not at root, one rightwards
        let mut new = *i;
        while new > 0 && new % 2 == 0 {
            new = parent(new)
        }
        if new > 0 {
            *i = parent(new);
        } else {
            // end of traversal
            return false;
        }
    }
    true
}

// start at the smallest possible number, i.e. <type>::MIN, or 0.0 for floats
let mut acc = min;
// start at bottom left
let mut i = deep_left(0);
loop {
    nodes[i].0 = acc;
    acc += nodes[i].1 - 1; // in case of float, don’t subtract 1
    nodes[i].1 = acc;
    if !traverse(&mut i) {
        break;
    }
    acc += 1; // in case of float, don’t add 1
}
```
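Putting the pieces together, here is a self-contained sketch that builds the example tree at compile time. It is specialised to `u8` weights and `char` values to sidestep the generics trouble discussed above. It is a simplification, not the crate’s code:
- The guards become functions of the length.
- `next` (a hypothetical name) returns an `Option` instead of mutating through `&mut usize`. This avoids relying on mutable references in `const fn`, which are stable only since Rust 1.83.

A wrong total makes the `assert!` fail, which in a `const` item means a compile error.

```rust
const fn parent(i: usize) -> usize {
    (i - 1) / 2
}
const fn left_child(i: usize) -> usize {
    2 * i + 1
}
const fn right_child(i: usize) -> usize {
    2 * i + 2
}

/// Guards for a tree of `len` > 0 nodes: first index without a left / right child.
const fn no_left_child(len: usize) -> usize {
    len / 2
}
const fn no_right_child(len: usize) -> usize {
    if len % 2 == 1 { len / 2 } else { len / 2 - 1 }
}

/// Descends from `i` to the left child as many times as possible.
const fn deep_left(mut i: usize, len: usize) -> usize {
    while i < no_left_child(len) {
        i = left_child(i);
    }
    i
}

/// Next index in depth-first, left-to-right order, or `None` after the last node.
const fn next(i: usize, len: usize) -> Option<usize> {
    if i < no_right_child(len) {
        // one step right, then as far down left as possible
        return Some(deep_left(right_child(i), len));
    }
    // rise while we are a right child (even index) ...
    let mut j = i;
    while j > 0 && j % 2 == 0 {
        j = parent(j);
    }
    // ... then j's parent comes next, unless j is the root: traversal is over
    if j > 0 { Some(parent(j)) } else { None }
}

/// Turns (placeholder, weight, value) into (lower, inclusive upper, value)
/// spread across 0 ..= 255. The weights must add up to exactly 256.
const fn build_u8<const N: usize>(mut nodes: [(u8, u8, char); N]) -> [(u8, u8, char); N] {
    let mut acc = u8::MIN;
    let mut i = deep_left(0, N);
    loop {
        nodes[i].0 = acc;
        acc += nodes[i].1 - 1;
        nodes[i].1 = acc;
        match next(i, N) {
            Some(j) => i = j,
            None => break,
        }
        acc += 1;
    }
    assert!(acc == u8::MAX, "weights did not add up to 256");
    nodes
}

/// The example, evaluated at compile time (the first field is a placeholder).
const NODES: [(u8, u8, char); 10] = build_u8([
    (0, 49, 'A'), (0, 45, 'B'), (0, 35, 'C'), (0, 31, 'D'), (0, 27, 'E'),
    (0, 21, 'F'), (0, 17, 'G'), (0, 14, 'H'), (0, 10, 'I'), (0, 7, 'J'),
]);

/// Run-time lookup, like `Weights::get`.
fn get(nodes: &[(u8, u8, char)], matching: u8) -> char {
    let mut i = 0;
    loop {
        let (lower, upper, value) = nodes[i];
        if matching < lower {
            i = left_child(i);
        } else if matching > upper {
            i = right_child(i);
        } else {
            return value;
        }
    }
}
```

Looking up every possible `u8` in `NODES` hits each letter exactly as often as its weight, e.g. 49 times for 'A' and 7 times for 'J'.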
I had wanted the type to be inferable late – at the latest, when the caller gets a value by matching with a typeful (usually random) weight. But, with all those obstacles, that currently seems impossible. It might be best for now to do something with explicit types, i.e. have one (macro-generated) constructor per type. But even if the user-facing macro takes the explicit type, there is no way for a `macro_rules!` macro to concatenate a function name like `weights_u8`. You may have wondered why people get away with `let str = "This works?"`. Type names live in a separate namespace from values such as variables and functions. So we can, somewhat weirdly, name our constructors `u8(…)`, etc. Wrapping the previous `loop` snippet for each primitive type looks like:
```rust
impl<V, const LEN: usize> Weights<u8, V, LEN> {
    pub const fn u8(mut nodes: [(u8, u8, V); LEN], mut min: u8, max: u8) -> Self {
        …
    }
}
```
That’s a bunch of functions, most of which would get generated for nothing. And that’s without going into how to deal with `f16` and `f128`. Those should be available if the user has opted in, but how to check that… Worse, this architecture would lock down the fact that the type name, aka function name, must be given explicitly. However, I’m still hoping to infer it in the future.
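For illustration, here is how `macro_rules!` can stamp out constructors named after their type, exploiting the separate namespaces. `Limits` and `limits_fn!` are hypothetical and merely return each type’s `(MIN, MAX)`. In `$t::MIN` the first path segment is resolved as a type, so inside the body `u8` still means the primitive, not the function of the same name.

```rust
/// Hypothetical holder type, standing in for `Weights`.
pub struct Limits;

/// Generates one associated const fn per type, named after that type.
macro_rules! limits_fn {
    ($($t:ident),+) => {
        impl Limits {
            $(
                /// (MIN, MAX) of the type this function is named after.
                pub const fn $t() -> ($t, $t) {
                    ($t::MIN, $t::MAX)
                }
            )+
        }
    };
}

limits_fn!(u8, i16, u64);
```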
Therefore I prefer to put type-dependent code into the constructor macro. That also gives better error message locations. The downside is that internal methods must now be public. At least I hide them from the generated documentation (though if you drill down with rust-analyzer, some docs are there). They could be factored out into a builder.
The macro starts verbosely, because the language lacks an alternatives operator `$|`. The last three branches could have been unified with `$(…)?`, but there is no easy way to provide default values for when that part doesn’t match. In the middle we have the real `loop` snippet from above. A seemingly not so well-known jewel is `$crate`: while the expanded code ends up within the caller’s, `$crate` refers back to the crate where the macro was defined.
```rust
macro_rules! weights {
    // delta, min & max helpers
    (@ f16) => { (0.0, 0.0, 1.0) };
    (@ f32) => { (0.0, 0.0, 1.0) };
    (@ f64) => { (0.0, 0.0, 1.0) };
    (@ f128) => { (0.0, 0.0, 1.0) };
    (@ $type:ty) => { (1, <$type>::MIN, <$type>::MAX) };
    (type $type:ident($min:expr, $max:expr); $($weight:expr => $value:expr),+ $(,)?) => {
        {
            let mut acc: $type = $min;
            let mut this = $crate::Weights::newish([$(($weight, $weight, $value)),+]);
            let mut i = this.deep_left(0);
            loop {
                let node = this.mut_node(i);
                node.0 = acc;
                acc += node.1 - $crate::weights!(@ $type).0;
                node.1 = acc;
                if !this.traverse(&mut i) {
                    break;
                }
                acc += $crate::weights!(@ $type).0;
            }
            assert!(acc == $max, "Weights did not add up to max + 1.");
            this
        }
    };
    (type $type:ident; $($rest:tt)+) => {
        $crate::weights! { type $type($crate::weights!(@ $type).1, $crate::weights!(@ $type).2); $($rest)+ }
    };
    ($($rest:tt)+) => {
        $crate::weights! { type f32(0.0, 1.0); $($rest)+ }
    };
}
```
And now you see how I was having you on at the start: without a type declaration (last branch), it defaults to `f32`. So – still thinking about whether that’s an elegant way to specify the type – I should have started with:
```rust
weights! {
    type u8;
    49 => 'A',
    45 => 'B',
    …
}
```
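The branch delegation can be tried in isolation with a hypothetical, stripped-down `limits!` macro. It only returns the `(delta, min, max)` triple, mirroring the `@` helpers and the `type` and default branches of `weights!`.

One detail makes the `@ f32`-style branches work at all. A fragment captured as `ident` can still be matched against a literal token like `f64` when it is forwarded to another macro. That isn’t true for most other fragment kinds, e.g. `ty`.

```rust
macro_rules! limits {
    // helpers: (delta, min, max)
    (@ f32) => { (0.0_f32, 0.0_f32, 1.0_f32) };
    (@ f64) => { (0.0_f64, 0.0_f64, 1.0_f64) };
    (@ $type:ident) => { (1 as $type, $type::MIN, $type::MAX) };
    // explicit type, forwarded as an ident so the literal branches above can match it
    (type $type:ident) => { limits!(@ $type) };
    // no type given: default to f32
    () => { limits!(@ f32) };
}
```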
§Conclusion and Outlook
Use the optimal data structure and algorithm for your task! For code that might end up in a hot loop, take off as much load as you sensibly can.
As for `const` Rust, it’s pretty much still a sapling. Many of Rust’s wonderful features are not available, much more so than I had feared. Yet, with a lot of perseverance, it’s somewhat usable, even without a proc-macro. And there are good benefits: moving expensive operations to compile time makes your trees grow sky high – at zero run-time cost. Happy new year 2025, and here’s hoping much of this will become easier this year!
Thoughts, improvement suggestions, and more serious applications (weighted sampling?) are welcome on Reddit.
The input data was so friendly as to add up exactly to the range of possible values for our random type `u8`. As you can see from the crooked values, it can be finicky to come up with such a clean spread. Next time, let’s see if we can spread arbitrary values on our caller’s behalf, to always achieve full coverage.
By pure chance, the input was sorted by decreasing frequency. This gives the fastest results for the most probable cases, overall pushing O(log n) towards O(1). To always ensure this, it would be good to sort the input data on our caller’s behalf. In the future, let’s see if this can be achieved in `const`.