调教AI给我写了一个KD树的算法

慈云数据 2024-05-09 技术支持 37 0

我不擅长C++,但是目前需要用C++写一个KD树的算法。首先我有一份点云数据,需要找给定坐标范围0.1mm内的所有点。

于是我开始问AI,他一开始给的答案,完全是错误的,但是我一步步给出反馈,告诉他的问题,他居然给出的正确的答案(我和Python代码做了对比,python现成程的库实现,可作为结果进行参考)

下面,是我和AI的对话过程。

这看起来像是一系列三维空间中的点的坐标数据,每行表示一个点,包含了该点的 x、y、z 坐标值。如果你想要使用 KD 树对这些点进行空间检索或者其他操作,你可以使用 C++ 来读取这些数据,然后构建 KD 树。以下是一个简单的示例代码,可以帮助你读取这个文件并构建 KD 树:

#include 
#include 
#include 
#include 
#include 
// Define a structure to represent a point in 3D space
struct Point3D {
    double x, y, z;
};
// Define a node structure for the KD tree
struct KDNode {
    Point3D point;
    KDNode* left;
    KDNode* right;
    int axis; // 0 for x, 1 for y, 2 for z
    KDNode(const Point3D& p) : point(p), left(nullptr), right(nullptr), axis(0) {}
};
// Function to recursively build a KD tree from a list of points
KDNode* buildKDTree(std::vector& points, int depth = 0) {
    if (points.empty()) return nullptr;
    // Determine current axis
    int axis = depth % 3;
    // Sort points based on current axis
    if (axis == 0)
        std::sort(points.begin(), points.end(), [](const Point3D& a, const Point3D& b) { return a.x axis = axis;
    // Recursively build left and right subtrees
    std::vector leftPoints(points.begin(), points.begin() + medianIdx);
    std::vector rightPoints(points.begin() + medianIdx + 1, points.end());
    node->left = buildKDTree(leftPoints, depth + 1);
    node->right = buildKDTree(rightPoints, depth + 1);
    return node;
}
// Function to print a KD tree (for debugging)
void printKDTree(KDNode* root) {
    if (root) {
        std::cout > y >> comma >> z) {
            points.push_back({x, y, z});
        } else {
            std::cerr right = buildKDTree(rightPoints, depth + 1);
    return node;
}
// Function to calculate Euclidean distance between two points
double distance(const Point3D& a, const Point3D& b) {
    return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) + std::pow(a.z - b.z, 2));
}
// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist point);
    }
    // Recursively search left and right subtrees if they might contain points within the radius
    if (root->left && target.axis[root->axis] - radius point.axis[root->axis]) {
        rangeSearch(root->left, target, radius, result);
    }
    if (root->right && target.axis[root->axis] + radius >= root->point.axis[root->axis]) {
        rangeSearch(root->right, target, radius, result);
    }
}
int main() {
    std::ifstream file("point_cloud.txt");
    if (!file.is_open()) {
        std::cerr  x >> comma >> y >> comma >> z) {
            points.push_back({x, y, z});
        } else {
            std::cerr axis;
    // Recursively search left and right subtrees if they might contain points within the radius
    if (axis == 0) {
        if (root->left && (target.x - radius) point.x) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.x + radius) >= root->point.x) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else if (axis == 1) {
        if (root->left && (target.y - radius) point.y) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.y + radius) >= root->point.y) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else {
        if (root->left && (target.z - radius) point.z) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.z + radius) >= root->point.z) {
            rangeSearch(root->right, target, radius, result);
        }
    }
}

在这个修正版本中,我添加了对当前节点的坐标与目标点的距离的检查,以确定是否需要继续向下搜索。这样就可以确保每个节点只在满足条件时才向下搜索,从而避免了重复添加点的问题。现在应该能够正确地找到给定点附近半径为0.1mm范围内的所有点。

抱歉,打印结果的小数位数不够精确是因为我没有设置输出精度。让我修正这个问题,确保打印结果的小数位数和原始数据一致。

#include  // Add this include for setprecision
// Function to perform range search in KD tree
void rangeSearch(KDNode* root, const Point3D& target, double radius, std::vector& result) {
    if (!root) return;
    // Calculate distance to current point
    double dist = distance(root->point, target);
    // If the current point is within the radius, add it to the result
    if (dist point);
    }
    // Determine current axis
    int axis = root->axis;
    // Recursively search left and right subtrees if they might contain points within the radius
    if (axis == 0) {
        if (root->left && (target.x - radius) point.x) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.x + radius) >= root->point.x) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else if (axis == 1) {
        if (root->left && (target.y - radius) point.y) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.y + radius) >= root->point.y) {
            rangeSearch(root->right, target, radius, result);
        }
    }
    else {
        if (root->left && (target.z - radius) point.z) {
            rangeSearch(root->left, target, radius, result);
        }
        if (root->right && (target.z + radius) >= root->point.z) {
            rangeSearch(root->right, target, radius, result);
        }
    }
}
int main() {
    // Build KD tree and other code...
    // Print the points within the range
    std::cout 
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon