0%

C++ Coroutine (1)

看到很多文章,總覺得從co_await開始講解實在很難清楚表達
寫了自己的版本當作筆記

The simplest coroutine

這程式碼什麼都沒幹,不過就是一個最小的coroutine了
一個coroutine至少要有co_returnco_awaitco_yield其中之一

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
using namespace std;
#if defined(__clang__)
#include <experimental/coroutine>
using namespace std::experimental;
#else
#include <coroutine>
#endif

struct Task {
struct promise_type {
auto initial_suspend() { return suspend_never{}; }
auto final_suspend() noexcept { return suspend_never{}; }
auto get_return_object() { return Task(coroutine_handle<promise_type>::from_promise(*this)); }
void return_void() {}
void unhandled_exception() {}
};
Task(coroutine_handle<promise_type> h) : handle(h) {}
~Task() {
if (handle)
handle.destroy();
}
coroutine_handle<promise_type> handle;
};

Task coroutineDemo()
{
co_return;
}

int main() {
auto task = coroutineDemo();
return 0;
}

Under the hood

Compilier做了很多事情
像這樣的Psuedo Code

1
2
3
4
5
template <typename TRet, typename … TArgs>
TRet func(TArgs args…)
{
body;
}

被Compilier處理之後大概變成這個樣子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
template <typename TRet, typename ... TArgs>
TRet func(TArgs args...)
{
using promise_t = typename coroutine_traits<TRet, TArgs...>::promise_type;

promise_t promise;
auto __return__ = promise.get_return_object();

co_await promise.initial_suspend();

try
{ // co_return expr; => promise.return_value(expr); goto final_suspend;
body; // co_return; => promise.return_void(); goto final_suspend;
} // co_yield expr; => co_await promise.yield_value(expr);
catch (...)
{
promise.unhandled_exception();
}

final_suspend:
co_await promise.final_suspend();
}

先忽略co_await的語句,之後補上
因此我們可以看到prmise_type裡面有initial_suspendfinal_suspend等function,這個promise_type是定義coroutine的行為模式
之後會對promise_type做更進一步的說明,這邊就此打住

How the compiler chooses the promise type

看到上面的Pseudo code

1
using promise_t = typename coroutine_traits<TRet, TArgs...>::promise_type;

然後看看coroutine_traits的定義

1
2
3
4
5
6
7
8
9
10
template <class _Ret, class = void>
struct _Coroutine_traits {};

template <class _Ret>
struct _Coroutine_traits<_Ret, void_t<typename _Ret::promise_type>> {
using promise_type = typename _Ret::promise_type;
};

template <class _Ret, class...>
struct coroutine_traits : _Coroutine_traits<_Ret> {};

就是看TRet裡面有沒有promise_type的struct definition了
者李又分成兩類

直接定義在class裡面

就是我們範例那個做法,簡單直接

將promise_type抽出

當你有一群Coroutine,然後這群Coroutine雖然有些許不同,但是對Coroutine的控制流程相同,就可以用這方案

1
2
3
4
5
6
7
8
9
10
11
12
13
template <typename T>
struct Promise {
auto initial_suspend() { return suspend_never{}; }
auto final_suspend() noexcept { return suspend_never{}; }
auto get_return_object() { return T(coroutine_handle<Promise>::from_promise(*this)); }
void return_void() {}
void unhandled_exception() {}
};

struct Task {
// Ignore
using promise_type = Promise<Task>;
};

著名的C++ coroutine都使用此方式

coroutine_handle

看到我們的Task當中有coroutine_handle了嗎

1
2
3
struct Task {
coroutine_handle<promise_type> handle;
};

這個才是coroutine的本體,負責讓暫停的coroutine繼續執行,或是判斷coroutine是否執行完畢
將原先的範例加強一下,來示範coroutine_handle該怎麼使用

A failure example

原本我想寫個像這樣的程式加深對coroutine_handle的使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include <iostream>
using namespace std;
#if defined(__clang__)
#include <experimental/coroutine>
using namespace std::experimental;
#else
#include <coroutine>
#endif

struct Task {
struct promise_type {
auto initial_suspend() { return suspend_always{}; }
auto final_suspend() noexcept { return suspend_never{}; }
auto get_return_object() { return Task{ coroutine_handle<promise_type>::from_promise(*this) }; }
void return_void() {}
void unhandled_exception() {}
};
coroutine_handle<promise_type> handle;
~Task() {
if (handle)
handle.destroy();
}
void resume() { handle.resume(); }
bool done() const { return handle.done(); }
};

Task coroutineDemo(int times)
{
for (size_t i = 0; i < times; i++) {
cout << "coroutineDemo\n";
co_await suspend_always{};
}
co_return;
}

int main() {
auto task = coroutineDemo(3);
while (!task.done()) {
task.resume();
}
std::cout << "Done\n";
return 0;
}

結果不如預期,發現問題出在

Once execution propagates outside of the coroutine body then the coroutine frame is destroyed. Destroying the coroutine frame involves a number of steps:

  1. Call the destructor of the promise object.
  2. Call the destructors of the function parameter copies.
  3. Call operator delete to free the memory used by the coroutine frame (optional)
  4. Transfer execution back to the caller/resumer.

問題就出在3這步

  • Visual C++不會在coroutine body結束時立刻delete coroutine frame
  • GCC和Clang會,當coroutine frame destroy之後,呼叫handle.done()handle.resume()是use-after-free

Solution

研究了一下, 只要將final_suspend改成

1
2
3
4
struct promise_type {
// ignored
auto final_suspend() noexcept { return suspend_always{}; }
};

讓它在coroutine結束之前停下來就可以了

Reference