Good programmers use good tools, and binary search is one of the most important tools in your toolbox. That's why you shouldn't accept anything less than the very finest binary search implementations.
You want a binary search that is absolutely reliable and returns the correct result in all cases (no overflow bugs!). You want one that works in all situations you'd use binary search, not just for searching through vectors. You want one that doesn't require consulting a manual to remember if it returns the first value that's greater or the last value that's lower. You want one with an API that invites you to think about edge cases so you don't cut yourself when they come up. And you want one that's simple enough that you can understand it, take it apart, and put it back together again if you have to.
If you're tired of using the first implementation of binary search that popped into some standard library author's head, if you're tired of settling for linear search because the binary search function you have handy doesn't work for your use case, and if you're tired of getting bit by edge cases, it's time to step up your game and use a lovingly-crafted binary search that has none of these issues.
So I'd like to show you [my very own binary search](https://github.com/dfinity/ic/commit/79aca8ede9fccd322e0e01183ee72c0bc3ed2e61):
```rust
pub fn search<T: Mid, G>(predicate: G, mut l: T, mut r: T) -> (Option<T>, Option<T>)
where
    G: Fn(&T) -> bool,
{
    // Check that the predicate is false for l and true for r
    if predicate(&l) {
        return (None, Some(l));
    }
    if !predicate(&r) {
        return (Some(r), None);
    }
    loop {
        // Sanity check: the predicate must be false for l and true for r,
        // otherwise the input function was not monotonic.
        // If you think you never make mistakes, you can remove these
        // checks and it will run slightly faster.
        if predicate(&l) {
            return (None, None);
        }
        if !predicate(&r) {
            return (None, None);
        }
        match l.mid(&r) {
            // Nothing left between l and r: the search is over
            None => return (Some(l), Some(r)),
            Some(m) => {
                if predicate(&m) {
                    r = m;
                } else {
                    l = m;
                }
            }
        }
    }
}
```
---
### Evangelism
This implementation of binary search is:
1. **Correct**: As long as the `Mid` trait is implemented correctly (I provide a correct implementation lower in this post), it does not give incorrect answers. This is more than many binary search implementations can claim, as it is very common for overflow to lead to bugs.
2. **Intuitive**: No more needing to remember "does this return the highest index whose value is lower than the value I'm searching for, or the lowest index whose value is greater than the value I'm searching for?". This one returns both, so you can use whichever you need for your use case.
3. **Error-detecting**: All binary search implementations depend on certain assumptions about their inputs - usually that they be non-descending. This one also makes that assumption, but in some cases if that assumption is violated it returns an error rather than silently returning a meaningless result.
4. **Simple**: Just look at the code. How could it be simpler?
----
### Explanation
It's inspired by the classic [Competitive programming in Haskell: better binary search](https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/) post. However, I've made a few modifications of my own, to make it a little more opinionated and remove some footguns, so I think it merits its own blog post.
Here's the idea. You pass in a predicate and a range. The predicate takes a number and returns `true` or `false`. The invariant that you have to maintain is that if the predicate returns `true` for some input `x`, it needs to return `true` for all inputs between `x` and the top of the range. (In other words, the predicate must be monotonic.)
It then returns two numbers - the first is the highest value in your range where the predicate is false, and the second is the lowest value where the predicate is true.
If that sounds a bit abstract, it's easier once you see it in action. For my predicate, I'll choose $x^3 \geq 512$. Then I'll search for integers between 0 and 20 to figure out when this becomes true.
```rust
let predicate = |x: &u64| x.pow(3) >= 512;
let result = search(predicate, 0, 20);
// (Some(7), Some(8))
```
Ok, it returned `(Some(7), Some(8))`. So it looks like for 7 this predicate is false, and for 8 it is true. As it happens, $7^3 = 343$ and $8^3 = 512$. So 7 is indeed the largest integer for which this is false, and 8 is the smallest integer for which this is true.
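The same approach carries over to the classic case of a sorted slice if you search over indices instead of values. Here's a minimal sketch, assuming a slice of `i32` sorted in ascending order. `neighbors` and `demo` are hypothetical names, and returning `(None, None)` for an empty slice is an assumption made for this example. `search` is repeated from the top of the post (with the two sanity checks folded into one `if`), and the `Mid` impl is a condensed, `usize`-only stand-in, so that the snippet compiles on its own.

```rust
pub trait Mid: Sized {
    fn mid(&self, other: &Self) -> Option<Self>;
}

// Condensed stand-in for this example: only covers `usize`.
impl Mid for usize {
    fn mid(&self, other: &Self) -> Option<Self> {
        let (small, large) = if self <= other { (*self, *other) } else { (*other, *self) };
        // Subtracting first keeps the intermediate value within range.
        let mid = small + (large - small) / 2;
        if mid == small || mid == large { None } else { Some(mid) }
    }
}

pub fn search<T: Mid, G>(predicate: G, mut l: T, mut r: T) -> (Option<T>, Option<T>)
where
    G: Fn(&T) -> bool,
{
    if predicate(&l) {
        return (None, Some(l));
    }
    if !predicate(&r) {
        return (Some(r), None);
    }
    loop {
        if predicate(&l) || !predicate(&r) {
            return (None, None);
        }
        match l.mid(&r) {
            None => return (Some(l), Some(r)),
            Some(m) => {
                if predicate(&m) {
                    r = m;
                } else {
                    l = m;
                }
            }
        }
    }
}

/// Hypothetical helper: index of the last element below `target` and index
/// of the first element at or above it, for a slice sorted in ascending order.
pub fn neighbors(xs: &[i32], target: i32) -> (Option<usize>, Option<usize>) {
    if xs.is_empty() {
        // Assumption for this sketch: an empty slice has no neighbors at all.
        return (None, None);
    }
    search(|&i: &usize| xs[i] >= target, 0, xs.len() - 1)
}

pub fn demo() {
    let xs = [1, 3, 5, 7, 9];
    assert_eq!(neighbors(&xs, 6), (Some(2), Some(3))); // 5 < 6 <= 7
    assert_eq!(neighbors(&xs, 0), (None, Some(0))); // every element is >= 0
    assert_eq!(neighbors(&xs, 10), (Some(4), None)); // no element is >= 10
}
```

Because both neighbors come back, the caller can pick "insertion point" or "last smaller element" without a second search.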
Why does it return `Some(7)` and `Some(8)` rather than just `7` and `8`? The reason is that the predicate might be always true or never true. If the predicate is always true, then there's no greatest value for which it is false, so the only sensible thing to return is `None`. And if it's always false, there's no smallest value for which it is true, so the only sensible thing to return is `None`.
So in these cases it returns `(None, Some(x))` or `(Some(x), None)`, respectively. The only case where it returns `(None, None)` is when one of the sanity checks inside the loop fails. Those checks re-evaluate the predicate at endpoints it has already been called on, so they catch a predicate that contradicts one of its own earlier answers, but they won't catch every non-monotonic predicate. So if you know your predicate is monotonic, feel free to hit this case with a `panic!("should never happen")`.
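To make the four possible shapes of the result explicit at a call site, one option is to map the tuple onto an enum before acting on it. This is only a sketch: `Outcome`, its variants, and `classify` are hypothetical names that are not part of the search function above, and the variant docs assume the range was passed with the lower end first.

```rust
/// Hypothetical wrapper around the `(Option<T>, Option<T>)` result.
#[derive(Debug, PartialEq)]
pub enum Outcome<T> {
    /// The predicate flips from false to true between these two values.
    Boundary { last_false: T, first_true: T },
    /// The predicate was already true at the lower end of the range.
    AlwaysTrue { lowest: T },
    /// The predicate was still false at the upper end of the range.
    AlwaysFalse { highest: T },
    /// One of the sanity checks inside the search loop failed.
    Inconsistent,
}

pub fn classify<T>(result: (Option<T>, Option<T>)) -> Outcome<T> {
    match result {
        (Some(last_false), Some(first_true)) => Outcome::Boundary { last_false, first_true },
        (None, Some(lowest)) => Outcome::AlwaysTrue { lowest },
        (Some(highest), None) => Outcome::AlwaysFalse { highest },
        (None, None) => Outcome::Inconsistent,
    }
}
```

For the cube example, `classify((Some(7u64), Some(8u64)))` gives `Outcome::Boundary { last_false: 7, first_true: 8 }`. Matching on the enum forces every edge case to be handled (or explicitly ignored) by name.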
----
### Reasonable-seeming `Mid` trait
This binary search also requires that you implement `Mid` for the argument to your predicate. (However this trait is anything but mid.) This trait tells the function how to find the midpoint of two values.
```rust
use std::ops::{Add, Div, Sub};

pub trait Mid: Sized {
    /// Returns the value halfway between `self` and `other`, rounded toward
    /// negative infinity.
    /// If there are no values between `self` and `other`, returns `None`.
    fn mid(&self, other: &Self) -> Option<Self>;
}

// Not the best implementation, feel free to change
impl<T> Mid for T
where
    T: Add<Output = T> + Sub<Output = T> + Div<Output = T> + Ord + Copy + From<u8>,
{
    fn mid(&self, other: &Self) -> Option<Self> {
        // Order the arguments so the subtraction below never goes negative
        let (small, large) = if self <= other {
            (self, other)
        } else {
            (other, self)
        };
        // small + (large - small) / 2 avoids computing small + large directly
        let difference = *large - *small;
        let two = T::from(2u8);
        let mid = *small + (difference / two);
        // No value strictly between the two: signal that the search is over
        if mid == *small || mid == *large {
            None
        } else {
            Some(mid)
        }
    }
}
```
I don't think this is the best implementation. I actually know that there are faster ones available. Instead, I wrote it with a focus on correctness, because implementing a midpoint function that is always correct even for values at the extreme ends of what can be represented is not totally straightforward. However, I'm worried it's subtly wrong for e.g. floating point numbers or something, so maybe look over it for yourself or use the [midpoint](https://docs.rs/midpoint/latest/midpoint/) crate.[^1]
[^1]: Just make sure whatever midpoint you use always rounds in the same direction (not towards `0`), and returns `None` if there are no values in between the provided ones (as this is how the search function knows the search is over).
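One case that may be worth double-checking with any subtraction-based midpoint is signed integers: `i64::MAX - i64::MIN` does not fit in an `i64`, so the difference can overflow when the two ends of the range are very far apart. As a minimal sketch, assuming you are willing to give `i64` its own implementation, widening to `i128` sidesteps that, and `div_euclid` rounds toward negative infinity instead of toward zero. The trait is repeated (without its doc comments) so the snippet compiles on its own. This is one possible approach, not the implementation used in the linked commit.

```rust
pub trait Mid: Sized {
    fn mid(&self, other: &Self) -> Option<Self>;
}

impl Mid for i64 {
    fn mid(&self, other: &Self) -> Option<Self> {
        let (small, large) = if self <= other { (*self, *other) } else { (*other, *self) };
        // The sum of two i64 values always fits in an i128.
        let sum = small as i128 + large as i128;
        // With a positive divisor, div_euclid rounds toward negative infinity.
        // The result lies between small and large, so the cast back is lossless.
        let mid = sum.div_euclid(2) as i64;
        if mid == small || mid == large { None } else { Some(mid) }
    }
}
```

With this, `i64::MIN.mid(&i64::MAX)` is `Some(-1)`, `(-3i64).mid(&0)` is `Some(-2)`, and `7i64.mid(&8)` is `None`. Note that a type-specific impl like this overlaps with a blanket `impl<T> Mid for T` that also covers `i64`, so it would have to replace such an impl, not sit next to it.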