《产生式元编程》 第七章 巧活用折叠表达式

教育   2024-08-27 13:40   美国  

C++ Generative Metaprogramming:

目录:


Introduction

模板是第一阶段元编程最核心的工具,中篇以两章五星难度的内容开头,深入纵览其核心技术与诸般妙诀。本章要讨论的元编程工具——Fold Expressions,依旧处于第一阶段,是 C++17 引入的一个跨越性产生式特性。

该特性从更高层面抽象了参数包拆解方式,消除了传统递归所带来的复杂性,是非常有用的编译期遍历方式。因此,这个特性是产生式元编程中的关键部分,这也是单独分配一章深度讨论的原因。

这个特性也并不完善,C++26/29 依旧会增加与其相关的特性,增加哪些?为什么加?痛点在哪儿?这些问题也将在本章给出答案。

Fold

Fold 这个概念来自函数式编程,本身指的是一类高阶函数,这些函数使用给定的组合操作来分析递归数据结构,并通过递归处理其组成部分的结果来重新组合,最终构建出一个返回值。简而言之,Fold 能够递归地遍历数据结构中的每个元素,并通过一个组合函数将这些元素的值合并为一个结果。这个组合函数通常定义了如何将两个值合并在一起,以及如何处理基础情况(如空数据结构)。

这种从递归数据结构中提取信息的数学概念称为 Catamorphism(源自古希腊语:κατά "向下" 和 μορφή "形式,形状"),表示从一个初始代数到其他代数的唯一同态映射。编程中,目标代数通常就是一个单一的值或结果。若从词的本意来理解,其实指的就是将一复杂数据结构向下拆解成简单数据结构的过程,这个过程递归地利用一个组合函数来完成。Catamorphism 通过抽象化递归数据结构的处理方式,提供了一种通用的模式来遍历和处理递归数据结构,使得代码更具可读性和可维护性,适用于各种场景,如求和、乘积、查找、过滤等。

Fold 只是 Catamorphism 的一个具体实现,主要是指对列表和序列的处理。例如,对列表 [1, 2, 3, 4] 求和,可以这样表示:

1foldl (+) 0 [1, 2, 3, 4] = (((0 + 1) + 2) + 3) + 4 = 10
2foldr (+) 0 [1, 2, 3, 4] = 1 + (2 + (3 + (4 + 0))) = 10

想必大家也不陌生,foldl 是从左向右应用函数的方式,称为左折叠,foldr与之相反,称为右折叠。若是操作类型满足交换律,不论选择哪种方式,结果都一样。由此结构,也可以看到 Fold 通常包含的三个参数:

  • 二元函数:+,定义了如何将列表的两个元素合并成一个值;

  • 初始值:0,折叠操作的开始值,它与列表的第一个元素一起传递给二元函数;

  • 列表:[1, 2, 3, 4],折叠操作的元素集合。

当然,倘是一元函数,则不需要初始值,例如,拼接字符串:

1foldl (++) ["a", "b", "c", "d"] = (((("a" ++ "b") ++ "c") ++ "d")) = "abcd"
2foldr (++) ["a", "b", "c", "d"] = "a" ++ ("b" ++ ("c" ++ "d")) = "abcd"

正是这种更高一级的抽象方式,Fold 这个概念才能够简化常规的递归方式,以一种更加易于人类理解的方式表达拆解逻辑,增加代码的可读性的同时,也简化了编码效率。

C++ Fold

std::accumulate 就是C++ 提供的一个 Fold 函数,包含前面所说的三个参数。基本形式如下:

1// left fold
2std::accumulate(begin, end, initval, func)
3// right fold
4std::accumulate(rbegin, rend, initval, func)

下面是一个例子:

1// fr. https://en.cppreference.com/w/cpp/algorithm/accumulate
2#include <functional>
3#include <iostream>
4#include <numeric>
5#include <string>
6#include <vector>
7
8int main()
9
{
10    std::vector<int> v{12345678910};
11
12    int sum = std::accumulate(v.begin(), v.end(), 0);
13    int product = std::accumulate(v.begin(), v.end(), 1std::multiplies<int>());
14
15    auto dash_fold = [](std::string a, int b)
16    {
17        return std::move(a) + '-' + std::to_string(b);
18    };
19
20    std::string s = std::accumulate(std::next(v.begin()), v.end(),
21                                    std::to_string(v[0]), // start with first element
22                                    dash_fold);
23
24    // Right fold using reverse iterators
25    std::string rs = std::accumulate(std::next(v.rbegin()), v.rend(),
26                                     std::to_string(v.back()), // start with last element
27                                     dash_fold);
28
29    std::cout << "sum: " << sum << '\n'
30              << "product: " << product << '\n'
31              << "dash-separated string: " << s << '\n'
32              << "dash-separated string (right-folded): " << rs << '\n';
33}

输出为:

1sum: 55
2product: 3628800
3dash-separated string: 1-2-3-4-5-6-7-8-9-10
4dash-separated string (right-folded): 10-9-8-7-6-5-4-3-2-1

函数刚好接受一个列表、一个初始值、一个二元定制函数作为参数,所以类似的需求皆可摆脱传统的遍历方式,表达起来更加简单。但是 std::accumulate 只支持二元函数,且表意不够广泛,因而 C++23 又增加了 Ranges fold 系列算法。基本形式如下:

1std::ranges::fold_left(range, initval, func)
2std::ranges::fold_right(range, initval, func)
3std::ranges::fold_left_first(range, func)
4std::ranges::fold_right_last(range, func)

新的系列算法更加顾名思义,同时支持一元函数和二元函数,进一步简化了折叠方式。

同样提供一个例子:

1// fold algorithms
2int xs[] = { 12345 };
3auto concatl = [](std::string s, int i) { return s + std::to_string(i); };
4auto concatr = [](int i, std::string s) { return s + std::to_string(i); };
5
6auto fold_left  = ranges::fold_left(xs, std::string(), concatl);
7fmt::print("fold left: {}\n", fold_left);
8
9auto fold_right = ranges::fold_right(xs, std::string(), concatr);
10fmt::print("fold right: {}\n", fold_right);
11
12// Output:
13// fold left: 12345
14// fold right: 54321

C++ Fold Expressions

Fold 函数主要适用于列表或其他线性数据结构,而 C++17 Fold Expressions 则适用于可变模板参数包。不同的是,前者是 Library 级别的特性,而后者却是 Language 级别的特性,可用性更强。

Fold Expressions 同样支持一元及二元的左折叠和右折叠,形式如下:

1( pack op ... )          // Unary right fold
2( ... op pack )          // Unary left fold
3( pack op ... op init )  // Binary right fold
4( init op ... op pack )  // Binary left fold

逻辑其实都一样,只是语法形式稍异而已。... 在参数包的左边,就属于左折叠,在右边,就属于右折叠。但是,在二元折叠中,不能同时包含参数包,例如:

1// fr. C++20 standard §7.5.6 (ISO/IEC 14882:2020)
2template<typename ...Args>
3bool f(Args ...args) {
4  return (true && ... && args); // OK
5}
6
7template<typename ...Args>
8bool f(Args ...args) {
9  return (args + ... + args);   // error: both operands contain unexpanded packs
10}

而在一元折叠中,参数包若为空,只有以下三个操作符具有合法的默认值:

Operator

Value when parameter pack is empty

&&

true

||

false

,

void()

这三个操作符也恰恰是各种高级技巧的基石。

Smart Tricks with Fold Expressions

Fold 是更加抽象化的遍历方式,Fold Expressions 的核心作用就是替代传统的递归遍历方式。但是,要精细化控制这种遍历方式的各种细节,例如条件、中断、中间值、下标等,便需要各种高级技巧了。

本节便分别展示各种小技巧,将它们分布在各个算法当中。

Conditions and Counting

根据一个 Predicate 函数,计算符合条件的元素个数。

all_of 计算是否所有元素都满足条件,any_of 计算是否任一元素满足条件,count_of 计算满足条件的个数。实现如下:

1// Check whether all elements matches a predicate.
2auto all_of(auto F, auto... args) -> bool {
3    return (F(args) && ...);
4}
5
6// Check whether any elements matches a predicate.
7auto any_of(auto F, auto... args) -> bool {
8    return (F(args) || ...);
9}
10
11// Count the elements matches a predicate.
12auto count_of(auto Pred, auto... args) -> int {
13    return (Pred(args) + ...);
14}

标准中存在类似的算法给容器使用,实现都采用 std::find_if 之类的算法,查找算法内部又都涉及传统的循环遍历。对于参数包,若是用传统的递归来完成此类操作,相较也会麻烦,而 Fold Expressions 这种更高一级的遍历方式则显得灵活而简洁。

本技巧主要利用了两个特性,一个是 &&|| 所具有的 short-circuit 评估能力,可以用来实现条件和中断,另一个是 boolint 之间所存在的隐式转换,可以用来计数。

Random Access

任意访问参数包某个索引指向的元素。

首先,若是参数包的元素属于同构类型,可以通过以下方式访问其首尾元素。

1// Find the first element.
2auto first_of(auto... args) -> std::common_type_t<decltype(args)...> {
3    std::common_type_t<decltype(args)...> result;
4    ((result = args, true) || ...);
5    return result;
6}
7
8// Find the last element.
9auto last_of(auto... args) -> std::common_type_t<decltype(args)...> {
10    std::common_type_t<decltype(args)...> result;
11    (result = (args, ...));
12    return result;
13}

本处技巧主要是利用 ||, 的特性,保存中间值的过程中,决定是否继续往下走。同时,还用到 = 让右边的表达式先计算,从而保存最终结果。

其次,若是参数包是异构类型,可以借助 std::tuple 的索引式访问算法来实现,返回值只能依靠自动推导。

1auto generic_first_of(auto... args) {
2    auto values = std::forward_as_tuple(args...);
3    return std::get<1>(values);
4}
5
6auto generic_last_of(auto... args) {
7    auto values = std::forward_as_tuple(args...);
8    return std::get<sizeof...(args)-1>(values);
9}

这种方式本质就是利用已有的 std::tuple 算法来达到目的,虽说复杂度降低,但却要构造一个额外的对象。

最后,若是参数包的元素属于同构类型,不借助 std::tuple,可以通过以下方式实现索引式访问。

1// Find the nth element.
2template <std::size_t I>
3auto nth_of(auto... args) -> std::common_type_t<decltype(args)...> {
4    std::common_type_t<decltype(args)...> result;
5    std::size_t n{};
6    ((n++ == I ? (result = args, true) : false) || ...);
7    return result;
8}

手法结合了前面几个技巧,再以三目运算符作为条件,分发逻辑,false 时继续往下遍历,true 时结束遍历。

Maximum and Minimum

取同构元素列表的最大最小值。

同样需要保存中间结果,但不需要中断遍历流程,实现如下:

1// Find the minimum element.
2auto min_of(auto... args) -> std::common_type_t<decltype(args)...> {
3    auto min = (args, ...);
4    ((min > args ? min = args : 0), ...);
5    return min;
6}
7
8// Find the maximum element.
9auto max_of(auto... args) -> std::common_type_t<decltype(args)...> {
10    auto max = (args, ...);
11    ((max < args ? max = args : 0), ...);
12    return max;
13}

不中断的情况下,使用 , 展开更加便捷。

Reverse Packs

逆转列表元素位置,返回转换后的列表。

因为需要返回一个列表,所以只得借助 std::tuple,再反转 std::tuple。实现如下:

1auto reverse_of(auto... args) {
2    auto tuple = std::make_tuple(args...);
3    return [&tuple]<auto... I>(std::index_sequence<I...>) {
4        return std::make_tuple(std::get<sizeof...(args)-1-I>(std::forward<decltype(tuple)>(tuple))...);
5    }(std::index_sequence_for<decltype(args)...>{});
6}

此处便出现了第六章介绍的 Compile-time for,这种技巧可以在编译期遍历并操作 std::tuple,本质就是创造一个索引参数包,再以 Fold Expressions 遍历。

Overload pattern

Fold Expressions 相关联的另一个技术称为 Overload pattern,这项技术通过展开参数包实现多继承,以在视觉层面模拟 Lambda 重载。代码只有两行:

1template<class... Tsstruct overloaded : Ts... { using Ts::operator()...; };
2template<class... Tsoverloaded(Ts...) -> overloaded<Ts...>;

短短两行代码中,除了使用参数包展开实现多继承,还借助 Using-declration 来绕开重载合并规则,避免重载歧义。这些特性和 Fold Expressions 类似,都能展开参数包,生成重复代码。

Lambda 重载是一种灵活的定制方式,以前展示过对象工厂和抽象工厂的应用,下面看一个与 Fold Expressions 一起使用的新例子。

Implement any_visit with Fold Expressions

std::any 采用类型擦除技术,实现了异构类型表示,可以和容器类型结合起来构成异构容器。std::variant也能够起到类似的作用,标准提供 std::visit 来访问元素,因此可以直接使用已有算法来遍历:

1using value_type = std::variant<intdoublestd::string>;
2std::vector<value_type> container;
3container.push_back(5);
4container.push_back(0.42);
5container.push_back("hello");
6
7// Iterate the heterogeneous container
8std::ranges::for_each(container, [](const value_type& value) {
9    std::visit([](const auto& x){ std::print("{} ", x); }, value);
10});

std::any 没有对应的 visit 访问函数,只能通过下面这种 type-switch 方式访问:

1for (const auto& a : container) {
2    if (a.type() == typeid(int)) {
3        const auto& value = std::any_cast<int>(a);
4    } else if (a.type() == typeid(const char*)) {
5        const auto& value = std::any_cast<const char*>(a);
6    } else if (a.type() == typeid(bool)) {
7        const auto& value = std::any_cast<bool>(a);
8    }
9}

重复、变化、繁琐……于是实现一个 any_visit 的想法顿时出现,而这个遍历就可以用 Fold Expressions 来抽象得更高一级。实现为:

1template <class... Ts>
2void any_visit(auto fconst std:
:any& a) {
3    ((std::type_index(a.type()) == std::type_index(typeid(Ts))
4        && (f(std::any_cast<Ts>(a)), true)) || ...);
5}

可以细品一下这个实现是如何借助 &&||, 消除 forif 的,技巧都是前面讲过的内容。有了这个工具,便可以像 std::visit 那样,借助算法来迭代 std::any 构成的异构容器。例子:

1std::vector<std::any> container { 50.42"hello"false };
2// Output: 5 0.42 hello boolean: false
3std::ranges::for_each(container, [](const auto& a) {
4    any_visit<intdoubleconst char*>([](const auto& x) { std::print("{} ", x); }, a);
5    any_visit<bool>([](const auto& x) { std::print("boolean: {} ", x); }, a);
6});

简单是简单,但由于 Fold Expressions 要借助参数包展开,模板参数的变化性依旧没有消除,而这些信息其实可以表现到 Lambda 参数之中,以 Overload pattern 封装这部分变化。即目标用法变成这样:

1std::ranges::for_each(container, [](const auto& a) {
2    any_visit(overloaded {
3            [](int x) { std::print("int: {} ", x); },
4            [](double x) { std::print("double: {} ", x); },
5            [](std::string_view x) { std::print("string: {} ", x); },
6            [](bool x) { std::print("bool: {} ", x); }
7        }, a);
8    });

如此一来,不仅可以精确处理每一个异构类型,还不用重复调用 any_visit,同时也消除了显式模板参数。overloaded 在上一节已然介绍,那么现在只剩下一个关键问题——如何获取 Lambda 的参数类型?解决了这个问题,实现便也水到渠成了。

Lambda 是一个可以携带状态的函数,其实现是一个含有 operator() 重载的匿名类,捕获的参数作为匿名类的数据成员直接初始化。Lambda 使用时调用的便是这个重载的 operator(),返回的类型就是匿名类的类型,称为 closure type。因此,问题进一步转化为如何获取成员函数 operator() 的参数类型,通过第五、第六章的高级模板内容,获取起来犹如探囊取物。

方法就是把想要的类型,通过模板参数,显式写出来:

1// For a function pointer
2template<typename R, typename Arg, typename... Rest>
3Arg extract_first_arg(R(*) (Arg, Rest...));
4
5// For a member function pointer without a qualifier
6template<typename R, typename F, typename Arg, typename... Rest>
7Arg extract_first_arg(R(F::*) (Arg, Rest...));
8
9// For a const-qualified member function pointer
10template<typename R, typename F, typename Arg, typename... Rest>
11Arg extract_first_arg(R(F::*) (Arg, Rest...) const);
12
13// ...

这里只写了支持函数指针和带基本修饰的 Lambda 函数,更多修饰可以接着往下写。这些函数都属于稻草人函数,不实际使用,只提取类型,故无需实现。接着,通过 decltype() 将函数模板的返回类型提取出来:

1template<typename L>
2using lambda_arg_t = decltype(extract_first_arg(&L::operator()));

至此,最复杂的问题便解决了。

最后,利用 Overload pattern 和 Fold expressions 访问 std::any 构成的异构容器。代码为:

1template<class... Tsstruct overloaded : Ts... { using Ts::operator()...; };
2template<class... Tsoverloaded(Ts...) -> overloaded<Ts...>;
3
4template<typename... Lam>
5void any_visit(const overloaded<Lam...>& f, auto&& any)
6
{
7    ((std::type_index(any.type()) == std::type_index(typeid(lambda_arg_t<Lam>))
8        && (f(std::any_cast<lambda_arg_t<Lam>>(any)), true)) || ...);
9}

寥寥数行代码,便解决了一个相对复杂的问题,这就是 Fold expressions 的威力。

Related Features and Discussions

Fold expressions 是专门针对参数包的 Fold 特性,好用是好用,但是前提得先有参数包,否则也是巧妇难为无米之炊。

比如我们要循环输出 5 次 hello fold expressions,那前提是先得有五个模板参数,在哪儿凭空产生固定个数的模板参数呢?这其实就是第六章所介绍的 Compile-time for 技术,采用该技术,可以这样实现:

1[]<auto... Is>(std::index_sequence<Is...>) {
2    ((std::println("hello fold expressions"), Is), ...);
3}(std::make_index_sequence<5>{});

这能达到预期效果,但显得复杂,更本质的问题在于 C++ 缺少直接创建参数包的能力,std::index_sequence 也只是扬汤止沸的产物,没有解决根本问题。而 Circle 便支持直接创建参数包,于是可以简洁地完成以下功能:

1struct obj_t {
2  int x;
3  double y;
4  std::string s;
5};
6
7int main() {
8  auto f = [](const char* name, auto... i) {
9    std::cout<< name<< ":\n";
10    std::cout<< "  "<< i<< "\n" ...;
11  };
12
13  // Just expand a pack into an argument list.
14  f("integers"int...(5) ...);
15
16  obj_t obj { 1003.14"A string" };
17  f("object", obj...);
18}
19
20// Compiler Explorer: https://godbolt.org/z/4K61vM9hf

int...(5) ... 可以直接创建一个 int 型参数包并展开,obj... 可以直接将结构体的成员转换为参数包。有了这种创建参数包的能力,不但能够简化代码,而且可以直接折叠结构体,极大增强操纵模板参数的能力。

不过,随着 C++26 Pack structure bindings 和 Pack indexing 的加入,在一定程度上能够改善这部分问题。那时 Compile-time for 便无需借助 Lambda 充当辅助函数了,可以直接这样写:

1auto [...Is] = std::make_index_sequence<5>{};
2((std::println("hello fold expressions"), Is), ...);

这才是最直接的方式,Fold expressions 如今主要就受限于参数包的特性不足。

Ternary Right Fold Expression

Fold expressions 在遍历时,还缺少一种处理错误的方式,看如下例子:

1template<std::size_t... Is>
2auto test_impl(std::size_t j, std::index_sequence<Is...>)
3
{
4    return ((j == Is ? (std::println("found"), true) : 0) || ...);
5}
6
7template<std::size_t N>
8auto test(std::size_t j)
9
{
10    return test_impl(j, std::make_index_sequence<N>{});
11}
12
13int main() {
14    test<5>(5);
15}

当遍历查找失败时,Fold expressions 无法处理错误,而 Ternary Right Fold Expression 就是解决这个问题的,代码变成:

1template<std::size_t... Is>
2auto test_impl(std::size_t j, std::index_sequence<Is...>)
3
{
4    return ((j == Is ? std::println("found")
5        : ... : throw std::range_error("Out of range"));
6}

可以看到,这个特性能够消除 ||,,进一步简化代码,并且在查找失败时,可以处理错误。

总结一下,它的语法格式是这样的:

1( C ? E : ... : D )

展开变为:

1( C(1) ? E(1) : ( ... ( C(N-1) ? E(N-1) : ( C(N) ? E(N) : D ) ) ) )

这个特性可能会入 C++26。

Conclusion

本章更为全面而深入地介绍了 Fold Expressions,这是元编程的编译期遍历方式,能够消除传统递归和循环,简化代码,减少重复。

Fold 这个概念从函数式编程而来,是一种抽象等级更高的遍历函数,因此不同于传统面向对象和面向过程范式中的递归和循环遍历,表达起来更加自然。示例代码依旧非常丰富,展示了各种高级技巧,这些技巧用于提供条件、中断、中间值等能力。

同时,Fold Expressions 也有不足之处,主要不足源于参数包特性和错误处理模块的缺失,这将在 C++26 进一步得到解决。




CppMore
Dive deep into the C++ core, and discover more!
 最新文章