Commit 430aa0d8 authored by Richard Barnes's avatar Richard Barnes Committed by Facebook GitHub Bot

Add `midpoint()` calculation to folly

Summary:
A number of libraries at Facebook include interval midpoint calculations; however, doing these in a mathematically precise way (without over/underflow) can be tricky. Doing them wrong can break binary searches over large datasets and give imprecise floating-point calculations.

This function provides an early opportunity to fix binary searches and other calculations which can later be updated to `std::midpoint()` when C++20 becomes available.

Reviewed By: yfeldblum

Differential Revision: D23997097

fbshipit-source-id: 373e0dc1d1ff071f697ee782be46fb0d49a2f8f7
parent 9db54edd
......@@ -23,6 +23,7 @@
#include <stdint.h>
#include <cmath>
#include <limits>
#include <type_traits>
......@@ -200,4 +201,58 @@ inline constexpr detail::IdivResultType<N, D> divRoundAway(N num, D denom) {
: detail::divRoundAwayBranchful<R>(num, denom));
}
// clang-format off
// Disabling clang-formatting for midpoint to retain 1:1 correlation
// with LLVM
// midpoint
//
// mimic: std::numeric::midpoint, C++20
// from:
// https://github.com/llvm/llvm-project/blob/llvmorg-11.0.0/libcxx/include/numeric,
// Apache 2.0 with LLVM exceptions
template <class _Tp>
constexpr std::enable_if_t<
std::is_integral<_Tp>::value && !std::is_same<bool, _Tp>::value &&
!std::is_null_pointer<_Tp>::value,
_Tp>
midpoint(_Tp __a, _Tp __b) noexcept {
using _Up = std::make_unsigned_t<_Tp>;
constexpr _Up __bitshift = std::numeric_limits<_Up>::digits - 1;
_Up __diff = _Up(__b) - _Up(__a);
_Up __sign_bit = __b < __a;
_Up __half_diff = (__diff / 2) + (__sign_bit << __bitshift) + (__sign_bit & __diff);
return __a + __half_diff;
}
template <class _TPtr>
constexpr std::enable_if_t<
std::is_pointer<_TPtr>::value &&
std::is_object<std::remove_pointer_t<_TPtr>>::value &&
!std::is_void<std::remove_pointer_t<_TPtr>>::value &&
(sizeof(std::remove_pointer_t<_TPtr>) > 0),
_TPtr>
midpoint(_TPtr __a, _TPtr __b) noexcept {
return __a + midpoint(std::ptrdiff_t(0), __b - __a);
}
template <class _Fp>
constexpr std::enable_if_t<std::is_floating_point<_Fp>::value, _Fp> midpoint(
_Fp __a,
_Fp __b) noexcept {
constexpr _Fp __lo = std::numeric_limits<_Fp>::min()*2;
constexpr _Fp __hi = std::numeric_limits<_Fp>::max()/2;
return std::abs(__a) <= __hi && std::abs(__b) <= __hi ? // typical case: overflow is impossible
(__a + __b)/2 : // always correctly rounded
std::abs(__a) < __lo ? __a + __b/2 : // not safe to halve a
std::abs(__b) < __lo ? __a/2 + __b : // not safe to halve b
__a/2 + __b/2; // otherwise correctly rounded
}
// clang-format on
} // namespace folly
......@@ -15,8 +15,10 @@
*/
#include <folly/Math.h>
#include <folly/functional/Invoke.h>
#include <algorithm>
#include <array>
#include <type_traits>
#include <utility>
#include <vector>
......@@ -244,3 +246,201 @@ TEST(Bits, divTestUint64) {
runDivTests<uint64_t, uint64_t, __int128>();
}
#endif
FOLLY_CREATE_FREE_INVOKER(midpoint_invoke, midpoint);
TEST(MidpointTest, MidpointTest) {
EXPECT_EQ(midpoint<int8_t>(2, 4), 3);
EXPECT_EQ(midpoint<int8_t>(3, 4), 3);
EXPECT_EQ(midpoint<int8_t>(-2, 2), 0);
EXPECT_EQ(midpoint<int8_t>(-4, -2), -3);
EXPECT_EQ(midpoint<int8_t>(102, 104), 103);
EXPECT_EQ(midpoint<int8_t>(126, 126), 126);
EXPECT_EQ(midpoint<int8_t>(-104, -102), -103);
// Perform some simple tests. Note that because these are small integers
// they can be represented exactly, so we do not have floating-point error.
EXPECT_EQ(midpoint(2.0, 4.0), 3.0);
EXPECT_EQ(midpoint(-2.0, 2.0), 0.0);
EXPECT_EQ(midpoint(-2.1, 2.1), 0.0);
EXPECT_EQ(midpoint(-4.0, -2.0), -3.0);
EXPECT_EQ(midpoint(102.0, 104.0), 103.0);
EXPECT_EQ(midpoint(-104.0, -102.0), -103.0);
// Double
EXPECT_EQ(midpoint(2.0, 4.0), 3.0);
EXPECT_EQ(midpoint(0.0, 0.4), 0.2);
EXPECT_EQ(midpoint(0.0, -0.0), 0.0);
EXPECT_EQ(midpoint(9e9, -9e9), 0.0);
EXPECT_EQ(
midpoint(
std::numeric_limits<double>::max(),
std::numeric_limits<double>::max()),
std::numeric_limits<double>::max());
EXPECT_TRUE(std::isnan(midpoint(
-std::numeric_limits<double>::infinity(),
std::numeric_limits<double>::infinity())));
// Float
EXPECT_EQ(midpoint(2.0f, 4.0f), 3.0f);
EXPECT_EQ(midpoint(0.0f, 0.4f), 0.2f);
EXPECT_EQ(midpoint(0.0f, -0.0f), 0.0f);
EXPECT_EQ(midpoint(9e9f, -9e9f), 0.0f);
EXPECT_EQ(
midpoint(
std::numeric_limits<float>::max(), std::numeric_limits<float>::max()),
std::numeric_limits<float>::max());
EXPECT_TRUE(std::isnan(midpoint(
-std::numeric_limits<float>::infinity(),
std::numeric_limits<float>::infinity())));
// Long double
EXPECT_EQ(midpoint(2.0l, 4.0l), 3.0l);
EXPECT_EQ(midpoint(0.0l, 0.4l), 0.2l);
EXPECT_EQ(midpoint(0.0l, -0.0l), 0.0l);
EXPECT_EQ(midpoint(9e9l, -9e9l), 0.0l);
EXPECT_EQ(
midpoint(
std::numeric_limits<long double>::max(),
std::numeric_limits<long double>::max()),
std::numeric_limits<long double>::max());
EXPECT_TRUE(std::isnan(midpoint(
-std::numeric_limits<long double>::infinity(),
std::numeric_limits<long double>::infinity())));
EXPECT_TRUE(noexcept(midpoint(1, 2)));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, bool>));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, const bool>));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, volatile int>));
constexpr auto MY_INT_MAX = std::numeric_limits<int>::max();
constexpr auto MY_INT_MIN = std::numeric_limits<int>::min();
constexpr auto MY_UINT_MAX = std::numeric_limits<unsigned int>::max();
constexpr auto MY_SHRT_MAX = std::numeric_limits<short>::max();
constexpr auto MY_SHRT_MIN = std::numeric_limits<short>::min();
constexpr auto MY_SCHAR_MAX = std::numeric_limits<signed char>::max();
constexpr auto MY_SCHAR_MIN = std::numeric_limits<signed char>::min();
EXPECT_EQ(midpoint(0, 0), 0);
EXPECT_EQ(midpoint(1, 1), 1);
EXPECT_EQ(midpoint(0, 1), 0);
EXPECT_EQ(midpoint(1, 0), 1);
EXPECT_EQ(midpoint(0, 2), 1);
EXPECT_EQ(midpoint(3, 2), 3);
EXPECT_EQ(midpoint(-5, 4), -1);
EXPECT_EQ(midpoint(5, -4), 1);
EXPECT_EQ(midpoint(-5, -4), -5);
EXPECT_EQ(midpoint(-4, -5), -4);
EXPECT_EQ(midpoint(MY_INT_MIN, MY_INT_MAX), -1);
EXPECT_EQ(midpoint(MY_INT_MAX, MY_INT_MIN), 0);
EXPECT_EQ(midpoint(MY_INT_MAX, MY_INT_MAX), MY_INT_MAX);
EXPECT_EQ(midpoint(MY_INT_MAX, MY_INT_MAX - 1), MY_INT_MAX);
EXPECT_EQ(midpoint(MY_INT_MAX - 1, MY_INT_MAX - 1), MY_INT_MAX - 1);
EXPECT_EQ(midpoint(MY_INT_MAX - 1, MY_INT_MAX), MY_INT_MAX - 1);
EXPECT_EQ(midpoint(MY_INT_MAX, MY_INT_MAX - 2), MY_INT_MAX - 1);
EXPECT_EQ(midpoint(0u, 0u), 0);
EXPECT_EQ(midpoint(0u, 1u), 0);
EXPECT_EQ(midpoint(1u, 0u), 1);
EXPECT_EQ(midpoint(0u, 2u), 1);
EXPECT_EQ(midpoint(3u, 2u), 3);
EXPECT_EQ(midpoint(0u, MY_UINT_MAX), MY_UINT_MAX / 2);
EXPECT_EQ(midpoint(MY_UINT_MAX, 0u), (MY_UINT_MAX / 2 + 1));
EXPECT_EQ(midpoint(MY_UINT_MAX, MY_UINT_MAX), MY_UINT_MAX);
EXPECT_EQ(midpoint(MY_UINT_MAX, MY_UINT_MAX - 1), MY_UINT_MAX);
EXPECT_EQ(midpoint(MY_UINT_MAX - 1, MY_UINT_MAX - 1), MY_UINT_MAX - 1);
EXPECT_EQ(midpoint(MY_UINT_MAX - 1, MY_UINT_MAX), MY_UINT_MAX - 1);
EXPECT_EQ(midpoint(MY_UINT_MAX, MY_UINT_MAX - 2), MY_UINT_MAX - 1);
EXPECT_EQ(midpoint<short>(0, 0), 0);
EXPECT_EQ(midpoint<short>(0, 1), 0);
EXPECT_EQ(midpoint<short>(1, 0), 1);
EXPECT_EQ(midpoint<short>(0, 2), 1);
EXPECT_EQ(midpoint<short>(3, 2), 3);
EXPECT_EQ(midpoint<short>(-5, 4), -1);
EXPECT_EQ(midpoint<short>(5, -4), 1);
EXPECT_EQ(midpoint<short>(-5, -4), -5);
EXPECT_EQ(midpoint<short>(-4, -5), -4);
EXPECT_EQ(midpoint<short>(MY_SHRT_MIN, MY_SHRT_MAX), -1);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX, MY_SHRT_MIN), 0);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX, MY_SHRT_MAX), MY_SHRT_MAX);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX, MY_SHRT_MAX - 1), MY_SHRT_MAX);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX - 1, MY_SHRT_MAX - 1), MY_SHRT_MAX - 1);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX - 1, MY_SHRT_MAX), MY_SHRT_MAX - 1);
EXPECT_EQ(midpoint<short>(MY_SHRT_MAX, MY_SHRT_MAX - 2), MY_SHRT_MAX - 1);
EXPECT_EQ(midpoint<signed char>(0, 0), 0);
EXPECT_EQ(midpoint<signed char>(1, 1), 1);
EXPECT_EQ(midpoint<signed char>(0, 1), 0);
EXPECT_EQ(midpoint<signed char>(1, 0), 1);
EXPECT_EQ(midpoint<signed char>(0, 2), 1);
EXPECT_EQ(midpoint<signed char>(3, 2), 3);
EXPECT_EQ(midpoint<signed char>(-5, 4), -1);
EXPECT_EQ(midpoint<signed char>(5, -4), 1);
EXPECT_EQ(midpoint<signed char>(-5, -4), -5);
EXPECT_EQ(midpoint<signed char>(-4, -5), -4);
EXPECT_EQ(midpoint<signed char>(MY_SCHAR_MIN, MY_SCHAR_MAX), -1);
EXPECT_EQ(midpoint<signed char>(MY_SCHAR_MAX, MY_SCHAR_MIN), 0);
EXPECT_EQ(midpoint<signed char>(MY_SCHAR_MAX, MY_SCHAR_MAX), MY_SCHAR_MAX);
EXPECT_EQ(
midpoint<signed char>(MY_SCHAR_MAX, MY_SCHAR_MAX - 1), MY_SCHAR_MAX);
constexpr auto MY_SIZE_T_MAX = std::numeric_limits<size_t>::max();
EXPECT_EQ(midpoint<size_t>(0, 0), 0);
EXPECT_EQ(midpoint<size_t>(1, 1), 1);
EXPECT_EQ(midpoint<size_t>(0, 1), 0);
EXPECT_EQ(midpoint<size_t>(1, 0), 1);
EXPECT_EQ(midpoint<size_t>(0, 2), 1);
EXPECT_EQ(midpoint<size_t>(3, 2), 3);
EXPECT_EQ(midpoint<size_t>((size_t)0, MY_SIZE_T_MAX), MY_SIZE_T_MAX / 2);
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX, (size_t)0), (MY_SIZE_T_MAX / 2 + 1));
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX, MY_SIZE_T_MAX), MY_SIZE_T_MAX);
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX, MY_SIZE_T_MAX - 1), MY_SIZE_T_MAX);
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX - 1, MY_SIZE_T_MAX - 1), MY_SIZE_T_MAX - 1);
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX - 1, MY_SIZE_T_MAX), MY_SIZE_T_MAX - 1);
EXPECT_EQ(midpoint<size_t>(MY_SIZE_T_MAX, MY_SIZE_T_MAX - 2), MY_SIZE_T_MAX - 1);
#if FOLLY_HAVE_INT128_T
const auto I128_MIN = std::numeric_limits<__int128_t>::min();
const auto I128_MAX = std::numeric_limits<__int128_t>::max();
EXPECT_EQ(midpoint<__int128_t>(0, 0), 0);
EXPECT_EQ(midpoint<__int128_t>(1, 1), 1);
EXPECT_EQ(midpoint<__int128_t>(0, 1), 0);
EXPECT_EQ(midpoint<__int128_t>(1, 0), 1);
EXPECT_EQ(midpoint<__int128_t>(0, 2), 1);
EXPECT_EQ(midpoint<__int128_t>(3, 2), 3);
EXPECT_EQ(midpoint<__int128_t>(-5, 4), -1);
EXPECT_EQ(midpoint<__int128_t>(5, -4), 1);
EXPECT_EQ(midpoint<__int128_t>(-5, -4), -5);
EXPECT_EQ(midpoint<__int128_t>(-4, -5), -4);
EXPECT_EQ(midpoint<__int128_t>(I128_MIN, I128_MAX), -1);
EXPECT_EQ(midpoint<__int128_t>(I128_MAX, I128_MIN), 0);
EXPECT_EQ(midpoint<__int128_t>(I128_MAX, I128_MAX), I128_MAX);
EXPECT_EQ(midpoint<__int128_t>(I128_MAX, I128_MAX - 1), I128_MAX);
#endif
// Test every possibility for signed char.
for (int a = MY_SCHAR_MIN; a <= MY_SCHAR_MAX; ++a)
for (int b = MY_SCHAR_MIN; b <= MY_SCHAR_MAX; ++b)
EXPECT_EQ(midpoint(a, b), midpoint<int>(a, b));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, void>));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, int()>));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, int&>));
EXPECT_FALSE((is_invocable_v<midpoint_invoke, struct Incomplete>));
constexpr std::array<int, 3> ca = {0, 1, 2};
EXPECT_EQ(midpoint(ca.data(), ca.data() + 3), ca.data() + 1);
constexpr std::array<int, 4> a = {0, 1, 2, 3};
EXPECT_EQ(midpoint(a.data(), a.data()), a.data());
EXPECT_EQ(midpoint(a.data(), a.data() + 1), a.data());
EXPECT_EQ(midpoint(a.data(), a.data() + 2), a.data() + 1);
EXPECT_EQ(midpoint(a.data(), a.data() + 3), a.data() + 1);
EXPECT_EQ(midpoint(a.data(), a.data() + 4), a.data() + 2);
EXPECT_EQ(midpoint(a.data() + 1, a.data()), a.data() + 1);
EXPECT_EQ(midpoint(a.data() + 2, a.data()), a.data() + 1);
EXPECT_EQ(midpoint(a.data() + 3, a.data()), a.data() + 2);
EXPECT_EQ(midpoint(a.data() + 4, a.data()), a.data() + 2);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment