背景

最近对 coroutine 比较感兴趣,看了几个 coroutine 的实现,发现逻辑上都是大同小异,原理也差不多,有些细节可能不同的实现做法不一样(比如 context switch 的逻辑)。

看来看去还是 cloudwu 的实现最干净漂亮,所以最后顺便简单分析一下他写的协程库。

ucontext 函数族

ucontext 函数有 4 个,如下所示:

1
2
3
4
5
6
7
8
9
#include <ucontext.h>

// 用户上下文的获取和设置
int getcontext(ucontext_t *ucp);
int setcontext(const ucontext_t *ucp);

// 操纵用户上下文
void makecontext(ucontext_t *ucp, void (*func)(void), int argc, ...);
int swapcontext(ucontext_t *oucp, const ucontext_t *ucp);

ucontext 函数用户进程内部的 context 控制,帮助用户更方便实现 coroutine,可视为更先进的 setjmp/longjmp。

4 个函数都依赖于 ucontext_t 类型,这个类型大致为:

1
2
3
4
5
6
7
typedef struct {
    ucontext_t *uc_link;
    sigset_t    uc_sigmask;
    stack_t     uc_stack;
    mcontext_t  uc_mcontext;
    ...
} ucontext_t;

其中:

  • uc_link:当前上下文结束时要恢复到的上下文,其中上下文由 makecontext() 创建;

  • uc_sigmask:上下文要阻塞的信号集合;

  • uc_stack:上下文所使用的 stack;

  • uc_mcontext:其中 mcontext_t

类型与机器相关的类型。这个字段是机器特定的保护上下文的表示,包括协程的机器寄存器;

这几个 API 的作用:

getcontext()

将当前的 context 保存在 ucp 中。成功返回 0,错误时返回 -1 并设置 errno;

setcontext()

恢复用户上下文为 ucp 所指向的上下文,成功调用不用返回。错误时返回 -1 并设置 errno。 ucp 所指向的上下文应该是 getcontext() 或者 makecontext() 产生。 如果上下文是由 getcontext() 产生,则切换到该上下文后,程序的执行在 getcontext() 后继续执行。比如下面这个例子每隔 1 秒将打印 1 个字符串:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
int main(void)
{
    ucontext_t context;

    getcontext(&context);
    printf("Hello world\n");
    sleep(1);
    setcontext(&context);
    return 0;
}

如果上下文是由 makecontext() 产生,切换到该上下文,程序的执行切换到 makecontext() 调用所指定的第二个参数的函数上。当函数返回后,如果 ucp.uc_link 为 NULL,则结束运行;反之跳转到对应的上下文。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
void foo(void)
{
    printf("foo\n");
}
    
int main(void)
{
    ucontext_t context;
    char stack[1024];
       
    getcontext(&context);
    context.uc_stack.ss_sp = stack;
    context.uc_stack.ss_size = sizeof(stack);
    context.uc_link = NULL;
    makecontext(&context, foo, 0);
       
    printf("Hello world\n");
    sleep(1);
    setcontext(&context);
    return 0;
}

以上输出 Hello world 之后会执行 foo(),然后由于 uc_link 为 NULL,将结束运行。

下面这个例子:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
void foo(void)
{
    printf("foo\n");
}
    
void bar(void)
{
    printf("bar\n");
}
    
int main(void)
{
    ucontext_t context1, context2;
    char stack1[1024];
    char stack2[1024];
       
    getcontext(&context1);
    context.uc_stack.ss_sp = stack1;
    context.uc_stack.ss_size = sizeof(stack1);
    context.uc_link = NULL;
    makecontext(&context1, foo, 0);
       
    getcontext(&context2);
    context.uc_stack.ss_sp = stack2;
    context.uc_stack.ss_size = sizeof(stack2);
    context.uc_link = &context1;
    makecontext(&context1, bar, 0);
        
    printf("Hello world\n");
    sleep(1);
    setcontext(&context2);
        
    return 0;
}

此时调用 makecontext() 后将切换到 context2 执行 bar(),然后再调用 context1foo()。由于 context1uc_linkNULL,程序停止。

makecontext()

修改 ucp 所指向的上下文;

swapcontext()

保存当前的上下文到 ocup,并且设置到 ucp 所指向的上下文。成功返回 0,失败返回 -1 并设置 errno。

如下面这个例子所示:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <stdio.h>
#include <ucontext.h>
    
static ucontext_t ctx[3];
    
static void
f1(void)
{
    printf("start f1\n");
    // 将当前 context 保存到 ctx[1],切换到 ctx[2]
    swapcontext(&ctx[1], &ctx[2]);
    printf("finish f1\n");
}
    
static void
f2(void)
{
    printf("start f2\n");
    // 将当前 context 保存到 ctx[2],切换到 ctx[1]
    swapcontext(&ctx[2], &ctx[1]);
    printf("finish f2\n");
}
    
int main(void)
{
    char stack1[8192];
    char stack2[8192];
    
    getcontext(&ctx[1]);
    ctx[1].uc_stack.ss_sp = stack1;
    ctx[1].uc_stack.ss_size = sizeof(stack1);
    ctx[1].uc_link = &ctx[0]; // 将执行 return 0
    makecontext(&ctx[1], f1, 0);
    
    getcontext(&ctx[2]);
    ctx[2].uc_stack.ss_sp = stack2;
    ctx[2].uc_stack.ss_size = sizeof(stack2);
    ctx[2].uc_link = &ctx[1];
    makecontext(&ctx[2], f2, 0);
    
    // 将当前 context 保存到 ctx[0],切换到 ctx[2]
    swapcontext(&ctx[0], &ctx[2]);
    return 0;
}      

此时将输出:

1
2
3
4
start f2
start f1
finish f2
finish f1

实现一个协程库

参考云风写的 coroutine,分析一下他写的库。

协程是一种共享堆,不共享栈,由用户主动调度的执行体(一般需要提供 yield 和 resume 语义)。

这个库实现基于多个协程共享栈的方式。但是每个 coroutine 都会从 heap 上分配内存来保存自己 stack 的内容,当前运行实只有一个 stack。

协程的状态

一个协程有 4 种状态:

1
2
3
4
#define COROUTINE_DEAD    0   // 结束
#define COROUTINE_READY   1   // 准备好调度
#define COROUTINE_RUNNING 2   // 运行中
#define COROUTINE_SUSPEND 3   // 挂起(即暂停中)

schedule 数据结构

schedule 对象用来管理所有 coroutine 的状态。所有 API 都必须传入 schedule 对象才可以操作具体的 coroutine。

一个 schedule 数据结构为:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
struct schedule {
    char stack[STACK_SIZE];
    ucontext_t main;  // 发生调度的上下文
    int nco;          // coroutine 的数量
    int cap;          // 可支持的 coroutine 数量
    int running;      // 当前正在执行的 coroutine id
    
    // 当前 schedule 下的 coroutine 列表,每个 coroutine 以 id 标识
    struct coroutine **co; 
};

创建一个新的 schedule 数据结构:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
// 使用 coroutine 库时的初始化语句就是创建一个 schedule
struct schedule *
coroutine_open(void) {
    struct schedule *S = malloc(sizeof(*S));
    S->nco = 0;
    S->cap = DEFAULT_COROUTINE; // 设置为 16
    S->running = -1;
    
    // 一个有 cap 个大小的 coroutine 指针数组
    S->co = malloc(sizeof(struct coroutine *) * S->cap);
    memset(S->co, 0, sizeof(struct coroutine *) * S->cap);
    
    return S;
}

coroutine 数据结构

一个典型的 coroutine 必须拥有自己的栈,此处默认设置为:

1
2
// 超过了 MMAP_THRESHOLD(128 KB),将使用 mmap() 创建匿名映射
#define STACK_SIZE (1024*1024) // 1MB

一个 coroutine 实例表示为:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
struct coroutine {
    coroutine_func func; // coroutine 将要指向的函数
    void *ud; // 指向用户数据的指针
    ucontext_t ctx;
    struct schedule *sch; // 全局的 schedule 对象
    ptrdiff_t cap; // coroutine 栈的大小
    ptrdiff_t size; // coroutine 栈的当前大小
    int status; // 当前的状态
    char *stack; // coroutine 保存 S->stack 中内容的栈
};

其中 func 定义为:

1
typedef void (*coroutine_func)(struct schedule *, void *ud);

创建一个新的 coroutine 对象使用的内部函数为:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
struct coroutine *
_co_new(struct schedule *S, coroutine_func func, void *ud) {
    struct coroutine *co = malloc(sizeof(*co));
    co->func = func;
    co->ud = ud;
    co->sch = S;
    co->cap = 0;
    co->size = 0;
    co->status = COROUTINE_READY;
    co->stack = NULL;
    return co;
}

删除一个 coroutine 对象(这里应该加上一些必要的保护才对):

1
2
3
4
5
6
7
void
_coroutine_delete(struct coroutine *co) {
    // 释放 coroutine 自己分配的 stack(C->stack)
    free(co->stack);
    // 释放 coroutine 对象
    free(co);
}

栈的管理

每个 coroutine 运行时都共享使用 S->stack (即大小为 1MB),当发生 yield 动作时,coroutine 会调用 _save_stackS->stack 的内容 copy 到自己的 C->stack 上。当下一次获取到 CPU 时(发生 resume 动作时),则将 C->stack 上的内容 memcpy 到 S->stack 上,然后开始执行(swapcontext()

其中 _save_stack() 的函数如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
static void
_save_stack(struct coroutine *C, char *top) {
    // dummy 将在 S->stack 上进行分配
    // top 指向 S->stack 的栈顶
    char dummy = 0;
    // 如果 C->stack 不够放下 coroutine 在 S->stack 上的内容
    // 重新进行分配
    if (C->cap < top - &dummy) {
        free(C->stack); // 如果是 NULL,free 没什么影响
        C->cap = top-&dummy;
        C->stack = malloc(C->cap);
    }
    C->size = top - &dummy;
    // 将以 dummy 为开始的 size 大小的数据保存到 C->stack 上
    // C->stack 是在 heap 上
    memcpy(C->stack, &dummy, C->size);
}

对协程的操作

当使用 coroutine_open() 创建了全局的 schdule 对象后,我们可以使用 coroutine_new() 来创建新的协程:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int
coroutine_new(struct schedule *S, coroutine_func func, void *ud) {
    // 创建一个 coroutine 对象
    struct coroutine *co = _co_new(S, func, ud);
    
    // 如果 schedule 中的 coroutine 对象数量已经超过限定值
    // 扩容 2 倍
    if (S->nco >= S->cap) {
        int id = S->cap;
        
        // 调用 realloc() 扩容 2 倍 S->cap
        S->co = realloc(S->co, S->cap * 2 * sizeof(struct coroutine *));
        memset(S->co + S->cap, 0, sizeof(struct coroutine *) * S->cap);
        S->co[S->cap] = co;
        S->cap *= 2;
        ++S->nco;
        return id;
    } else {
        int i;
        // 遍历 coroutine 列表,找到一个空闲位置
        // 实际应该不需要遍历
        for (i = 0; i < S->cap; i++) {
            int id = (i + S->nco) % S->cap;
            if (S->co[id] == NULL) {
                S->co[id] = co;
                ++S->nco;
                return id;
            }
        }
    }
    
}

让 schedule 中某个 id 的 coroutine 启动执行:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
void
coroutine_resume(struct schedule *S, int id) {
    // 获取对应 id 的 coroutine 对象
    struct coroutine *C = S->co[id];
    if (C == NUUL)
        return;
    int status = C->status;
    
    // 根据 coroutine 的状态做分支
    switch (status) {
    // 如果是从来没有执行过的 coroutine
    case COROUTINE_READY:
        getcontext(&C->ctx);
        C->ctx.uc_stack.ss_sp = S->stack;  // 将一直使用这个栈
        C->ctx.uc_stack.ss_size = STACK_SIZE;
        C->ctx.uc_link = &S->main; // 回到主函数中
        S->running = id;
        C->status = COROUTINE_RUNNING; // 将状态标记为 运行中
        uintptr_t ptr = (uintptr_t)S;
        // 将 C->ctx 指向 mainfunc 函数,并把 schedule 指针地址传递过去
        // mainfunc 是用来执行 coroutine 的函数
        makecontext(&C->ctx, (void (*)(void)) mainfunc, 2, (uint32_t)ptr, (uint32_t)(ptr>>32));
        // 将内存保存在 S->main 中,切换到 C->ctx
        swapcontext(&S->main, &C->ctx);
    case COROUTINE_SUSPEND:
        // 如果是之前运行过的,就把 C->stack 的 C->size 内容复制到 S->stack + STACK_SIZE - C->size 上
        memcpy(S->stack + STACK_SIZE - C->Size, C->stack, C->size);
        S->running = id;
        C->status = COROUTINE_RUNNING;
        swapcontext(&S->main, &C->ctx);
        break;
    default:
        assert(0);
    }
}

由于 coroutine 需要主动让出 CPU,所以必须实现 yield 语义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
void
coroutine_yield(struct schedule *S) {
    // 获取当前正在执行的 coroutine id
    int id = S->running;
    
    struct coroutine *C = S->co[id];
    // coroutine 自己会在 heap 上分配一个 stack,让 coroutine 把在 S->stack 上的内容
    // memcpy 到 C->stack 上
    _save_stack(C, S->stack + STACK_SIZE);
    C->status = COROUTINE_SUSPEND;
    S->running = -1;
    // 保存 coroutine 的 context,切到 main
    swapcontext(&C->ctx, &S->main);
}

记住:coroutine 是由用户层来进行调度的(yield 和 resume),所以不存在调度算法。用户想让谁执行就让谁执行。

参考文档

  1. setcontext
  2. ucontext 簇函数学习