从 std::ratio 看编译期运算
<ratio>
提供了 编译期 分数运算功能(还记得当年 OOP 课上要写一个分数类,不过是运行时的)。编译期呢,也就是说分母(numerator)和分子(denominator)都是在编译期时就是确定的,并且在编译期完成所需的各种运算。
先抄一个 ratio 使用的例子:
#include <iostream>
#include <ratio>
int main ()
{
typedef std::ratio<1,3> one_third;
typedef std::ratio<2,4> two_fourths;
typedef std::ratio_add<one_third,two_fourths> sum;
std::cout << "sum= " << sum::num << "/" << sum::den;
std::cout << " (which is: " << ( double(sum::num) / sum::den ) << ")" << std::endl;
return 0;
}
ratio_add
(还有 sub,mul 等等)是怎样在编译期做到的呢?这就要靠强大的 template 以及 constexpr 了(注意,这时候是一直在对类型做运算喔)。
先来看 ratio 是怎么定义的。
template<intmax_t _Num, intmax_t _Den = 1>
struct ratio
{
static_assert(_Den != 0, "denominator cannot be zero");
static_assert(_Num >= -__INTMAX_MAX__ && _Den >= -__INTMAX_MAX__,
"out of range");
// Note: sign(N) * abs(N) == N
static constexpr intmax_t num =
_Num * __static_sign<_Den>::value / __static_gcd<_Num, _Den>::value;
static constexpr intmax_t den =
__static_abs<_Den>::value / __static_gcd<_Num, _Den>::value;
typedef ratio<num, den> type;
};
ratio
是一个用分子和分母做模板参数的类模板。里面两个 constexpr
通过对模板参数运算,得到化简后的分母和分子。sign
,abs
,gcd
这些名字都一目了然了,大概做的事情就是把分母分子除 gcd
化简,然后把符号拿到分子上去。来看一下这几个模板是怎么做的。
template<intmax_t _Pn>
struct __static_sign
: integral_constant<intmax_t, (_Pn < 0) ? -1 : 1>
{ };
template<intmax_t _Pn>
struct __static_abs
: integral_constant<intmax_t, _Pn * __static_sign<_Pn>::value>
{ };
template<intmax_t _Pn, intmax_t _Qn>
struct __static_gcd
: __static_gcd<_Qn, (_Pn % _Qn)>
{ };
template<intmax_t _Pn>
struct __static_gcd<_Pn, 0>
: integral_constant<intmax_t, __static_abs<_Pn>::value>
{ };
template<intmax_t _Qn>
struct __static_gcd<0, _Qn>
: integral_constant<intmax_t, __static_abs<_Qn>::value>
{ };
过程简介明了。
接下来就是 ratio
的运算了,加减乘除。先来回忆一下加法是怎么做的,把两边分子分别乘对方的分母加到一起为分子,分母相乘结果为分母,之后再通分。
template<typename _R1, typename _R2>
struct __ratio_add
{
typedef typename __ratio_add_impl<_R1, _R2>::type type;
static constexpr intmax_t num = type::num;
static constexpr intmax_t den = type::den;
};
template<typename _R1, typename _R2>
constexpr intmax_t __ratio_add<_R1, _R2>::num;
template<typename _R1, typename _R2>
constexpr intmax_t __ratio_add<_R1, _R2>::den;
/// ratio_add
template<typename _R1, typename _R2>
using ratio_add = typename __ratio_add<_R1, _R2>::type;
ratio_add
的结果类型并不是 ratio
,不过只要有 num
和 den
这两个成员就足够了~ 而真的需要 ratio 类型的话,那个 type 应该就是 ratio 吧?
来看一下 __ratio_add_impl
的实现。
template<typename _R1, typename _R2,
bool = (_R1::num >= 0),
bool = (_R2::num >= 0),
bool = ratio_less<ratio<__static_abs<_R1::num>::value, _R1::den>,
ratio<__static_abs<_R2::num>::value, _R2::den> >::value>
struct __ratio_add_impl
{
private:
typedef typename __ratio_add_impl<
ratio<-_R1::num, _R1::den>,
ratio<-_R2::num, _R2::den> >::type __t;
public:
typedef ratio<-__t::num, __t::den> type;
};
template<typename _R1, typename _R2, bool __b>
struct __ratio_add_impl<_R1, _R2, true, true, __b>
template<typename _R1, typename _R2>
struct __ratio_add_impl<_R1, _R2, false, true, true>
: __ratio_add_impl<_R2, _R1>
{ };
template<typename _R1, typename _R2>
struct __ratio_add_impl<_R1, _R2, true, false, false>
比想象中的要复杂,模板参数分别是 _R1
,_R2
的正负号和绝对值关系。各种情况都被最后划归到最后两种情况:
-
_R1 > 0
,_R2 > 0
-
_R1 > 0
,_R2 < 0
,abs(_R1) > abs(_R2)
但是实际运算中并没有这么复杂啊,这么大张旗鼓,根据直觉估计是要对溢出做处理。
先看第一种情况。
template<typename _R1, typename _R2, bool __b>
struct __ratio_add_impl<_R1, _R2, true, true, __b>
{
private:
static constexpr uintmax_t __g = __static_gcd<_R1::den, _R2::den>::value;
// 计算两边分母 gcd,__g。
static constexpr uintmax_t __d2 = _R2::den / __g;
// __d2 也就是 _R1 分子分母需要扩大的倍数
typedef __big_mul<_R1::num, _R2::den / __g> __x;
typedef __big_mul<_R2::num, _R1::den / __g> __y;
// __x, __y 分别是 _R1 和 _R2 上分母通分后,分子扩大相应倍数的结果
typedef __big_add<__x::__hi, __x::__lo, __y::__hi, __y::__lo> __n;
// 对高位和低位分别做加法
static_assert(__n::__hi >= __x::__hi, "Internal library error");
// assert 相加结果
typedef __big_div<__n::__hi, __n::__lo, __g> __ng;
static constexpr uintmax_t __g2 = __static_gcd<__ng::__rem, __g>::value;
typedef __big_div<__n::__hi, __n::__lo, __g2> __n_final;
// __n_final = __n / __g2 = __n / gcd(__ng::rem, __g) = __n / gcd(__n % __g, __g) = __n / gcd(__g, __n)
// 把 __n 中可能带有 __g 的因子去掉得到最后的分子(__n 不会带有除 __g 以外 _R2 或者 _R1 分母的因子)
static_assert(__n_final::__rem == 0, "Internal library error");
static_assert(__n_final::__quot_hi == 0 &&
__n_final::__quot_lo <= __INTMAX_MAX__, "overflow in addition");
// assert 整除,溢出
typedef __big_mul<_R1::den / __g2, __d2> __d_final;
// 计算最后的分母,__d_final = _R1::end / gcd(__g, __n) * _R2::end / __g
static_assert(__d_final::__hi == 0 &&
__d_final::__lo <= __INTMAX_MAX__, "overflow in addition");
// assert 溢出
public:
typedef ratio<__n_final::__quot_lo, __d_final::__lo> type;
};
原来里面在做运算的时候,考虑了溢出,并且对高低位做了处理。难怪。相加的情况搞清楚了,相减也是同理咯。不过这几个 __big_div
什么的是怎么做的呢?
template<uintmax_t __hi1, uintmax_t __lo1, uintmax_t __hi2, uintmax_t __lo2>
struct __big_add
{
static constexpr uintmax_t __lo = __lo1 + __lo2;
static constexpr uintmax_t __hi = (__hi1 + __hi2 +
(__lo1 + __lo2 < __lo1)); // carry
};
// Subtract a number from a bigger one.
template<uintmax_t __hi1, uintmax_t __lo1, uintmax_t __hi2, uintmax_t __lo2>
struct __big_sub
{
static_assert(!__big_less<__hi1, __lo1, __hi2, __lo2>::value,
"Internal library error");
static constexpr uintmax_t __lo = __lo1 - __lo2;
static constexpr uintmax_t __hi = (__hi1 - __hi2 -
(__lo1 < __lo2)); // carry
};
加减的还好办,关键是乘除。
template<uintmax_t __x, uintmax_t __y>
struct __big_mul
{
private:
static constexpr uintmax_t __c = uintmax_t(1) << (sizeof(intmax_t) * 4);
static constexpr uintmax_t __x0 = __x % __c;
static constexpr uintmax_t __x1 = __x / __c;
static constexpr uintmax_t __y0 = __y % __c;
static constexpr uintmax_t __y1 = __y / __c;
static constexpr uintmax_t __x0y0 = __x0 * __y0;
static constexpr uintmax_t __x0y1 = __x0 * __y1;
static constexpr uintmax_t __x1y0 = __x1 * __y0;
static constexpr uintmax_t __x1y1 = __x1 * __y1;
static constexpr uintmax_t __mix = __x0y1 + __x1y0; // possible carry...
static constexpr uintmax_t __mix_lo = __mix * __c;
static constexpr uintmax_t __mix_hi
= __mix / __c + ((__mix < __x0y1) ? __c : 0); // ... added here
typedef __big_add<__mix_hi, __mix_lo, __x1y1, __x0y0> _Res;
public:
static constexpr uintmax_t __hi = _Res::__hi;
static constexpr uintmax_t __lo = _Res::__lo;
};
__c
的中间位为 1
,将 __x
,__y
用 __c
分解之后做乘法,然后合并。为了防止溢出,__big_mul
把原来的精度扩展了一倍。那如果是 __big_div
呢?很复杂,这里暂时先不分析了,留着坑以后可以挖。
__ratio_add
调用了 __big_add
,__big_mul
,__big_div
这些工具。ratio_sub
可以化归到 __ratio_add
上。相似的,__ratio_divide
也可以化归到 __ratio_multiply
上。来看 __ratio_multiply
。
template<typename _R1, typename _R2>
struct __ratio_multiply
{
private:
static const intmax_t __gcd1 =
__static_gcd<_R1::num, _R2::den>::value;
static const intmax_t __gcd2 =
__static_gcd<_R2::num, _R1::den>::value;
public:
typedef ratio<
__safe_multiply<(_R1::num / __gcd1),
(_R2::num / __gcd2)>::value,
__safe_multiply<(_R1::den / __gcd2),
(_R2::den / __gcd1)>::value> type;
static constexpr intmax_t num = type::num;
static constexpr intmax_t den = type::den;
};
这里并没有之前那个 __big_mul
,因为这里并不需要双精度,溢出直接挂就好了,所以 __safe_multiply
相对要简单的多。
template<intmax_t _Pn, intmax_t _Qn>
struct __safe_multiply
{
private:
static const uintmax_t __c = uintmax_t(1) << (sizeof(intmax_t) * 4);
static const uintmax_t __a0 = __static_abs<_Pn>::value % __c;
static const uintmax_t __a1 = __static_abs<_Pn>::value / __c;
static const uintmax_t __b0 = __static_abs<_Qn>::value % __c;
static const uintmax_t __b1 = __static_abs<_Qn>::value / __c;
static_assert(__a1 == 0 || __b1 == 0,
"overflow in multiplication");
static_assert(__a0 * __b1 + __b0 * __a1 < (__c >> 1),
"overflow in multiplication");
static_assert(__b0 * __a0 <= __INTMAX_MAX__,
"overflow in multiplication");
static_assert((__a0 * __b1 + __b0 * __a1) * __c
<= __INTMAX_MAX__ - __b0 * __a0,
"overflow in multiplication");
public:
static const intmax_t value = _Pn * _Qn;
};
恩,运算搞定了,还有大小关系。还记得 operator 关系定义的规则吧?所以这里只看 ratio_less
就够了。
template<typename _R1, typename _R2>
struct ratio_less
: __ratio_less_impl<_R1, _R2>::type
{ };
// 同为 + 号,直接给 __ratio_less_impl_1
template<typename _R1, typename _R2,
bool = (_R1::num == 0 || _R2::num == 0
|| (__static_sign<_R1::num>::value
!= __static_sign<_R2::num>::value)),
bool = (__static_sign<_R1::num>::value == -1
&& __static_sign<_R2::num>::value == -1)>
struct __ratio_less_impl
: __ratio_less_impl_1<_R1, _R2>::type
{ };
// 如果异号或者其中有分子为 0,直接做判断
template<typename _R1, typename _R2>
struct __ratio_less_impl<_R1, _R2, true, false>
: integral_constant<bool, _R1::num < _R2::num>
{ };
// 同为 - 号,加负号之后调换位置给 __ratio_less_impl_1
template<typename _R1, typename _R2>
struct __ratio_less_impl<_R1, _R2, false, true>
: __ratio_less_impl_1<ratio<-_R2::num, _R2::den>,
ratio<-_R1::num, _R1::den> >::type
{ };
继续看 __ratio_less_impl_1
。
template<typename _R1, typename _R2,
typename _Left = __big_mul<_R1::num,_R2::den>,
typename _Right = __big_mul<_R2::num,_R1::den> >
struct __ratio_less_impl_1
: integral_constant<bool, __big_less<_Left::__hi, _Left::__lo,
_Right::__hi, _Right::__lo>::value>
{ };
转给了 __big_less
,双精度的比较,其实 yy 一下就知道里面在怎么做的了。
template<uintmax_t __hi1, uintmax_t __lo1, uintmax_t __hi2, uintmax_t __lo2>
struct __big_less
: integral_constant<bool, (__hi1 < __hi2
|| (__hi1 == __hi2 && __lo1 < __lo2))>
{ };
最后看一下预定义的几个 ratio。
typedef ratio<1, 1000000000000000000> atto;
typedef ratio<1, 1000000000000000> femto;
typedef ratio<1, 1000000000000> pico;
typedef ratio<1, 1000000000> nano;
typedef ratio<1, 1000000> micro;
typedef ratio<1, 1000> milli;
typedef ratio<1, 100> centi;
typedef ratio<1, 10> deci;
typedef ratio< 10, 1> deca;
typedef ratio< 100, 1> hecto;
typedef ratio< 1000, 1> kilo;
typedef ratio< 1000000, 1> mega;
typedef ratio< 1000000000, 1> giga;
typedef ratio< 1000000000000, 1> tera;
typedef ratio< 1000000000000000, 1> peta;
typedef ratio< 1000000000000000000, 1> exa;
常见的数量级哈,在 std::chrono
里有用到。
总结一下~
- 编译期无敌的 C++ template,图灵完备不是盖的。
- 利用模板来进行编译期 selection (分子分母正负号,大小关系)。怎样选择 selection 的条件对逻辑思维是个锻炼,怎么把分支做的最简洁?
- single precision to double precision 乘法,和 double to single 的除法。(除法改天再挖坑)