RANSAC方法例子

发布时间 2023-06-25 19:59:11 作者: 兜尼完

RANSAC是一种常用的剔除数据中异常点(外点)的方法。本文以拟合圆为例展示RANSAC的工作方式。假设我们有一组点,其中任取一个点为内点的概率是p。若要使RANSAC的成功率至少达到${ \eta=99.9\% }$,至少需要重复抽取多少次样本?确定一个圆需要3个点,即每次随机选取3个点,因此一次抽取的3个点中至少包含一个外点的概率是${ 1-p^{3} }$(这里近似认为各点是否为内点相互独立)。那么N次抽取中每次都包含外点的概率是${ \left( 1-p^{3} \right)^{N} }$。我们需要N次抽取中至少有一次全是内点的概率不低于${ \eta }$,即:

$${ 1-\left( 1-p^{3} \right)^{N} \geqslant \eta }$$

解之得(由于${ \ln\left( 1-p^{3} \right) < 0 }$,两边同除以它时不等号方向反转):

$${ N \geqslant \frac{ \ln\left( 1-\eta \right) }{ \ln\left( 1-p^{3} \right) } }$$

下面给出代码。代码中默认置信度η=0.995,内点概率p取0.7:

// Generic RANSAC driver. Subclasses supply the model fit (fitModel) and the
// per-point inlier test (evalModel); execute() repeats random minimal samples
// and keeps the partition with the most inliers.
class Ransac
{
public:
    // ifitNumber:  number of points in one minimal sample (e.g. 3 for a circle).
    // confidence:  desired probability eta that at least one sample is all inliers.
    // inlierRatio: assumed probability p that a randomly chosen point is an inlier.
    Ransac(int ifitNumber, float confidence = 0.995f, float inlierRatio = 0.7f);
    virtual ~Ransac() = default;
    // Reorders `points` so the inliers of the best model come first and returns
    // how many there are.
    int execute(vector<Point2f>& points);

protected:
    // Fits a model from the first fitNumber entries of `selects` into `result`
    // (execute() places the random sample at the front of the vector).
    virtual void fitModel(const vector<Point2f>& selects, Mat& result) = 0;
    // Scores `pt` against `model`; execute() treats a score >= 0.5 as an inlier.
    virtual float evalModel(const Mat& model, const Point2f& pt) = 0;

private:
    int fitNumber;  // minimal sample size
    int loopCount;  // number of random samples drawn by execute()
};

Ransac::Ransac(int ifitNumber, float confidence, float inlierRatio)
{
    // Safety cap for when the formula diverges (confidence >= 1, vanishing inlier
    // ratio); keeps the float -> int conversion below well-defined.
    const int maxLoopCount = 1000000;
    fitNumber = ifitNumber;
    // N >= ln(1 - eta) / ln(1 - p^m). Round UP: truncating would draw fewer
    // samples than needed to reach the requested confidence.
    const float n = ceilf(logf(1 - confidence) / logf(1 - powf(inlierRatio, (float)fitNumber)));
    if (!(n >= 1.0f))  // also catches NaN and non-positive results
        loopCount = 1;
    else if (n >= (float)maxLoopCount)
        loopCount = maxLoopCount;
    else
        loopCount = (int)n;
}

// Runs loopCount RANSAC iterations over `points`.
// Each iteration moves fitNumber random distinct points to the front of the
// vector (partial Fisher-Yates shuffle), fits a model from them, and splits all
// points into inliers (evalModel score >= 0.5) and outliers.
// On return, the first `bestInner` entries of `points` are the inliers of the
// best model and the rest are its outliers; the return value is `bestInner`.
// Returns 0 when there are fewer than fitNumber points (input left untouched)
// or when no model collected any inlier (input kept, possibly reordered).
// The RNG is seeded from std::random_device, so results are not reproducible.
int Ransac::execute(vector<Point2f>& points)
{
    const int count = (int)points.size();
    // A minimal sample cannot be drawn: uniform_int_distribution(j, count - 1)
    // would get j > count - 1 (undefined behaviour) and points[j] would overrun.
    if (count < fitNumber)
    {
        return 0;
    }

    std::random_device rd;
    std::mt19937 e(rd());
    int bestInner = 0;
    vector<Point2f> innerDots, outerDots;  // partition of the best model so far
    vector<Point2f> dots, outs;            // scratch partition, reused every iteration
    innerDots.reserve(count);
    outerDots.reserve(count);
    dots.reserve(count);
    outs.reserve(count);
    for (int i = 0; i < loopCount; i++)
    {
        for (int j = 0; j < fitNumber; j++)
        {
            std::uniform_int_distribution<int> dist(j, count - 1);
            std::swap(points[j], points[dist(e)]);
        }
        Mat result;
        fitModel(points, result);
        dots.clear();
        outs.clear();
        for (const auto& item : points)
        {
            if (evalModel(result, item) >= 0.5f)
                dots.push_back(item);
            else
                outs.push_back(item);
        }
        const int inner = (int)dots.size();
        if (inner > bestInner)
        {
            bestInner = inner;
            // Swap instead of copy; the old best buffers become the new scratch space.
            std::swap(innerDots, dots);
            std::swap(outerDots, outs);
        }
    }
    // No model attracted a single inlier (e.g. every sample was degenerate):
    // keep the points rather than overwriting them with an empty vector.
    if (bestInner == 0)
    {
        return 0;
    }
    innerDots.insert(innerDots.end(), outerDots.begin(), outerDots.end());
    points = std::move(innerDots);
    return bestInner;
}

// RANSAC specialisation for circles. The model is a 3x1 float Mat holding
// [centre x, centre y, radius], fitted from a minimal sample of 3 points.
class CircleRansac : public Ransac
{
private:
    // Points whose distance to the circle outline is below this value are inliers.
    float error;

public:
    // ierror: inlier tolerance; confidence: forwarded to Ransac together with a
    // minimal sample size of 3.
    CircleRansac(float ierror, float confidence = 0.995f);

protected:
    // Circle through the first three entries of `selects`.
    void fitModel(const vector<Point2f>& selects, Mat& result) override;
    // 1 if `pt` is within `error` of the circle in `model`, else 0.
    float evalModel(const Mat& model, const Point2f& pt) override;
};

// A circle needs 3 points, so the base class is configured with a sample size of 3.
CircleRansac::CircleRansac(float ierror, float confidence)
    : Ransac(3, confidence), error(ierror)
{
}

//---------------------------------------------------------------------------------------
// Circle through three points: writes [cx, cy, r] of the circle through
// selects[0..2] into `result` (3x1, CV_32FC1). Collinear or coincident samples
// have no finite circle; for them all three values are NaN, which makes
// evalModel reject every point.
//---------------------------------------------------------------------------------------
void CircleRansac::fitModel(const vector<Point2f>& selects, Mat& result)
{
    result.create(3, 1, CV_32FC1);
    // Solve relative to selects[0]. Squaring absolute float coordinates
    // (x0^2 - x1^2 + ...) loses most significant digits for pixel-scale input,
    // whereas the offsets below stay small.
    const float ux = selects[1].x - selects[0].x;
    const float uy = selects[1].y - selects[0].y;
    const float vx = selects[2].x - selects[0].x;
    const float vy = selects[2].y - selects[0].y;
    const float det = ux * vy - uy * vx;
    if (det == 0.0f)
    {
        result.at<float>(0) = NAN;
        result.at<float>(1) = NAN;
        result.at<float>(2) = NAN;
        return;
    }
    // Centre offset q satisfies u.q = |u|^2 / 2 and v.q = |v|^2 / 2 (Cramer's rule).
    const float su = (ux * ux + uy * uy) * 0.5f;
    const float sv = (vx * vx + vy * vy) * 0.5f;
    const float qx = (su * vy - uy * sv) / det;
    const float qy = (ux * sv - su * vx) / det;
    result.at<float>(0) = selects[0].x + qx;
    result.at<float>(1) = selects[0].y + qy;
    result.at<float>(2) = sqrtf(qx * qx + qy * qy);  // distance from centre to selects[0]
}

float CircleRansac::evalModel(const Mat& model, const Point2f& pt)
{
    float cx = model.at<float>(0);
    float cy = model.at<float>(1);
    float r = model.at<float>(2);
    float dist = fabs(sqrtf((cx - pt.x) * (cx - pt.x) + (cy - pt.y) * (cy - pt.y)) - r);
    return dist < error ? 1 : 0;
}

使用方法如下:

int main()
{
    vector<Point2f> points = { ... };
    CircleRansac ransac(2.0f);  // inlier tolerance of 2.0 around the fitted circle
    const int innerCount = ransac.execute(points);
    // 4 = the 3-point minimal sample plus at least one further supporting point.
    if (innerCount >= 4)
    {
        // execute() placed the inliers first; drop the trailing outliers.
        points.resize(innerCount);
        // points now contains only inliers
        return 0;
    }
    return -1;
}