C++实现Mark & Copy 算法

发布时间 2023-04-03 20:25:57作者: fallen_leaves

Mark & Copy 算法

1、引言

Garbage Collector(GC)广泛存在于现代流行的编程语言之中,比如Java,C#,Python等。笔者认为结合源代码学习可以更加有效地了解一个算法,但是在网上找到更多的是Mark&Sweep算法,如http://journal.stuffwithstuff.com/2013/12/08/babys-first-garbage-collector/,而对于Mark&Copy的文章较少,所以本文使用C++实现一个简单的基于Mark&Copy的GC,以实际的代码理解此算法。


2、算法介绍

Mark&Copy算法需要两片内存,当我们进行垃圾回收的时候,将其中一片内存上所有的可达对象复制到另一片内存上,然后清空第一片内存。接下来我们详细介绍一下算法流程。

一般我们的数据类型可以分为两种,一种是不会引用其他对象的类型,比如数字,字符串等类型,一种是会引用其他对象的类型,比如数组。在本文中,我们使用整数来代表第一种类型,使用列表来代表第二种类型。

假设我们有两片内存,heap1和heap2,一组根节点roots,在heap1中一共有4个对象,分别为1,2,list1, list2。其中,list1中两个元素,一个是1,一个是list2,而list2中有一个元素,即list1。roots中有三个引用,分别引用了1,list1,list2。heap2是一片空的内存,这里先不展示。

图1

接下来我们需要对roots中的引用进行遍历,将所有的可达对象Copy到heap2中,如下所示:

图2

注意,这个时候我们仅仅是将对象从heap1拷贝到heap2,但是对于list1和list2来说,它们的引用依旧指向heap1上的对象,我们后续还要对于heap2上的对象进行遍历,更新它们的引用。更新完之后如下图所示。

图3

最后,我们更新一下roots中的所有引用,让他们指向heap2中的对象即可。

图4

到此为止,我们的算法就结束了。


3、C++实现

3.1、定义内存块

class memory_chunk
{
    constexpr static std::size_t Alignment = 16; 

    constexpr static std::size_t DefaultSize = 1024; 

public:

    // [m_start, m_end)表示已经分配的内存呢,[m_end, m_end_of_capacity)表示未分配的内存
    std::byte* m_start;
    std::byte* m_end;
    std::byte* m_end_of_capacity;

    // 已经分配的对象地址
    std::vector<struct object*> m_objects;

public:

    memory_chunk(std::size_t size = DefaultSize)
    {
        // 设置对齐
        m_start = (std::byte*)::operator new(size, std::align_val_t(Alignment));
        m_end = m_start;
        m_end_of_capacity = m_start + size; 
    }

    ~memory_chunk()
    {
        ::operator delete(m_start, std::align_val_t(Alignment));
    }

    // 释放对象,清空内存
    void clear()
    {
        std::ranges::destroy(m_objects);
        m_objects.clear();  
        m_end = m_start;
    }

    // 判断obj是否位于该内存块上
    bool has_object(void* obj) const
    {
        return m_start <= obj && obj < m_end_of_capacity;
    }

    // 分配内存
    void* allocate(std::size_t size)
    {
        size = (size + 7) & -8;
        auto addr = m_end;
        m_end += size;
        m_objects.emplace_back((object*)addr);
        return addr;
    }
};

3.2、定义数据类型

using list = std::vector<struct object*>;
struct object
{
    std::variant<int, list> m_value;

    // 如果当前对象已经被拷贝到另一片内存上,那么m_forward表示拷贝之后的地址,否在为0
    object* m_forward = nullptr;

    // 定义构造函数
    object(int i) : m_value(std::in_place_index<0>, i) { }

    object(list ls) : m_value(std::in_place_index<1>, std::move(ls)) { }

    // 一些辅助函数
    bool is_int() const
    {
        return std::holds_alternative<int>(m_value);
    }

    int& as_int() 
    {
        return std::get<int>(m_value);
    }

    list& as_list()
    {
        return std::get<list>(m_value);
    }
};

3.3、定义变量

我们额外定义一些变量。其中,roots包含了所有根节点(比如虚拟机栈中的元素),凡是从roots中出发可以遍历到的对象都是可达的。references表示这些可达对象中的数组对象,因为我们后续需要更新数组对象的引用,我们在这里需要将它们保存下来。heap1和heap2代表内存。

std::vector<struct object*> roots;
std::vector<struct object*> references;
MemoryChunk heap1;
MemoryChunk heap2;

3.4、算法实现

接下来我们定义mark函数,对于整数类型,我们只需要简单标记一下即可,对于数组这种会引用其他对象的类型,我们需要递归的对其所引用的对象进行标记。

void mark()
{
    if (m_forward)
        return; // 当前对象已经标记过了,直接退出

    auto new_address = (object*)heap2.allocate(sizeof(object));

    // 标记
    m_forward = new_address;

    std::cout << "Copy object from " 
        << std::hex << this << " to "
        << std::hex << new_address << ' ';

    // 将对象拷贝到heap2上
    if (is_int())
    {
        std::cout << "mark int\n";
        std::construct_at(new_address, std::get<int>(m_value));
    }
    else
    {
        std::cout << "mark list\n";
        // 我们需要将当前的对象copy到heap2中, C++需要注意一下深浅拷贝以及构造析构
        // 从算法流程中不难看出我们这里需要采用浅拷贝
        std::construct_at(new_address, std::move(std::get<list>(m_value)));

        // 用于后续更新引用对象
        references.emplace_back(new_address);

        // 递归地标记引用对象
        std::ranges::for_each(
            new_address->as_list(), 
            [](auto obj) { obj->mark(); }
        );
    }
}

更新引用的函数很简单,只需要将数组里面所有的地址更新一下即可,这个过程并不需要递归的进行。

void update_reference()
{
    if (is_int())
        return;
    std::ranges::for_each(as_list(), [](auto& o) { o = (object*)o->m_mark; });
}

我们按照图1所示构建对象和引用。

// 构建对象
auto i1 = (object*)heap1.allocate(sizeof(object));
auto i2 = (object*)heap1.allocate(sizeof(object));
auto list1 = (object*)heap1.allocate(sizeof(object));
auto list2 = (object*)heap1.allocate(sizeof(object));

std::construct_at(i1, 1);
std::construct_at(i2, 2);
std::construct_at(list1, list());
std::construct_at(list2, list());

// 添加引用
roots.emplace_back(i1);
roots.emplace_back(list1);
roots.emplace_back(list2);

list1->as_list().emplace_back(i1);
list1->as_list().emplace_back(list2);

list2->as_list().emplace_back(list1);

使用Mark&Copy将所有可达对象转移到新的内存上:

// mark & copy
for (auto obj : roots)
{
    obj->mark();
}

// 更新引用
for (auto obj : references)
{
    obj->update_reference();
}

// 更新roots
for (auto& obj : roots)
{
    obj = (object*)obj->m_mark;;
}

heap1.clear();

references换成heap2.m_objects也是一样的,但是对于number或者string这样的类型,无需更新,而且在程序中这些类型一般远多于数组类型,即references里面元素的数量应该是远小于heap2.m_objects里面的元素数量。

到此位置,一次GC过程就结束了,更新完之后所有的对象都转移到了heap2上。运行后可能的得到输出:

Copy object from 0x224608f24a0 to 0x224608f28c0 mark int 
Copy object from 0x224608f24f0 to 0x224608f28e8 mark list
Copy object from 0x224608f2518 to 0x224608f2910 mark list
OK

可以看出,所有可达对象都被移动到了新的地址,而不可达对象被留在了heap1上,最终被清理。


3.5、其他

我们在memory_chunk使用了一个m_objects去记录该内存上的所有对象而没有使用如下方式:

for (auto addr = m_start; addr != m_end; addr += sizeof(object))
{
    std::destroy_at((object*)addr);
}

因为有时候不同的object大小可能是不同的:

struct object { };
struct integer : object { int value; };
struct string : object { std::string value; };

4、代码整理

#include <cstddef>
#include <algorithm>
#include <vector>
#include <iostream>
#include <variant>
#include <memory>
#include <assert.h>

class memory_chunk
{
    constexpr static std::size_t Alignment = 16; 

    constexpr static std::size_t DefaultSize = 1024; 

public:

    // [m_start, m_end)表示已经分配的内存呢,[m_end, m_end_of_capacity)表示未分配的内存
    std::byte* m_start;
    std::byte* m_end;
    std::byte* m_end_of_capacity;

    // 已经分配的对象地址
    std::vector<struct object*> m_objects;

public:

    memory_chunk(std::size_t size = DefaultSize)
    {
        // 设置对齐
        m_start = (std::byte*)::operator new(size, std::align_val_t(Alignment));
        m_end = m_start;
        m_end_of_capacity = m_start + size; 
    }

    ~memory_chunk()
    {
        ::operator delete(m_start, std::align_val_t(Alignment));
    }

    // 释放对象,清空内存
    void clear()
    {
        std::ranges::destroy(m_objects);
        m_objects.clear();  
        m_end = m_start;
    }

    // 判断obj是否位于该内存块上
    bool has_object(void* obj) const
    {
        return m_start <= obj && obj < m_end_of_capacity;
    }

    // 分配内存
    void* allocate(std::size_t size)
    {
        size = (size + 7) & -8;
        auto addr = m_end;
        m_end += size;
        m_objects.emplace_back((object*)addr);
        return addr;
    }
};

std::vector<struct object*> roots;
std::vector<struct object*> references;
memory_chunk heap1;
memory_chunk heap2;

using list = std::vector<struct object*>;

struct object
{
    std::variant<int, list> m_value;

    // 如果当前对象已经被拷贝到另一片内存上,那么m_forward表示拷贝之后的地址,否在为0
    object* m_forward = nullptr;

    // 定义构造函数
    object(int i) : m_value(std::in_place_index<0>, i) { }

    object(list ls) : m_value(std::in_place_index<1>, std::move(ls)) { }

    void mark()
    {
        if (m_forward)
            return; // 当前对象已经标记过了,直接退出

        auto new_address = (object*)heap2.allocate(sizeof(object));

        // 标记
        m_forward = new_address;

        std::cout << "Copy object from " 
            << std::hex << this << " to "
            << std::hex << new_address << ' ';

        // 将对象拷贝到heap2上
        if (is_int())
        {
            std::cout << "mark int\n";
            std::construct_at(new_address, std::get<int>(m_value));
        }
        else
        {
            std::cout << "mark list\n";
            // 我们需要将当前的对象copy到heap2中, C++需要注意一下深浅拷贝以及构造析构
            // 从算法流程中不难看出我们这里需要采用浅拷贝
            std::construct_at(new_address, std::move(std::get<list>(m_value)));

            // 用于后续更新引用对象
            references.emplace_back(new_address);

            // 递归地标记引用对象
            std::ranges::for_each(
                new_address->as_list(), 
                [](auto obj) { obj->mark(); }
            );
        }
    }

    void update_reference()
    {
        if (is_int())
            return;
        std::ranges::for_each(as_list(), [](auto& o) { o = (object*)o->m_forward; });
    }

    // 一些辅助函数
    bool is_int() const
    {
        return std::holds_alternative<int>(m_value);
    }

    int& as_int() 
    {
        return std::get<int>(m_value);
    }

    list& as_list()
    {
        return std::get<list>(m_value);
    }
};

int main(int argc, char const *argv[])
{
    // 构建对象
    auto i1 = (object*)heap1.allocate(sizeof(object));
    auto i2 = (object*)heap1.allocate(sizeof(object));
    auto list1 = (object*)heap1.allocate(sizeof(object));
    auto list2 = (object*)heap1.allocate(sizeof(object));

    std::construct_at(i1, 1);
    std::construct_at(i2, 2);
    std::construct_at(list1, list());
    std::construct_at(list2, list());

    // 添加引用
    roots.emplace_back(i1);
    roots.emplace_back(list1);
    roots.emplace_back(list2);

    list1->as_list().emplace_back(i1);
    list1->as_list().emplace_back(list2);

    list2->as_list().emplace_back(list1);

    // 检测对象地址
    for (auto obj : roots)
    {
        assert(heap1.has_object(obj));
    }

    assert(roots[0]->as_int() == 1);

    assert(roots[1]->as_list().size() == 2);
    assert(roots[1]->as_list()[0]  == roots[0]);
    assert(roots[1]->as_list()[1]  == roots[2]);

    assert(roots[2]->as_list().size() == 1);
    assert(roots[2]->as_list()[0]  == roots[1]);

    // mark & copy
    for (auto obj : roots)
    {
        obj->mark();
    }

    // 更新引用
    assert(references.size() == 2);
    assert(heap2.m_objects.size() == 3);
    for (auto obj : references)
    {
        obj->update_reference();
    }

    // 更新roots
    for (auto& obj : roots)
    {
        obj = (object*)obj->m_forward;;
    }

    heap1.clear();

    // 检测对象地址
    for (auto obj : roots)
    {
        assert(heap2.has_object(obj));
    }

    assert(roots[0]->as_int() == 1);

    assert(roots[1]->as_list().size() == 2);
    assert(roots[1]->as_list()[0]  == roots[0]);
    assert(roots[1]->as_list()[1]  == roots[2]);

    assert(roots[2]->as_list().size() == 1);
    assert(roots[2]->as_list()[0]  == roots[1]);

    std::cout << "OK\n";
    return 0;
}