kd树和knn算法的c语言实现

2018-06-18 04:11:22来源:未知 阅读 ()

新老客户大回馈,云服务器低至5折

  基于kd树的knn的实现原理可以参考文末的链接,都是一些好文章。

  这里参考了别人的代码。用c语言写的包括kd树的构建与查找k近邻的程序。

 

  code:

  1 #include<stdio.h>
  2 #include<stdlib.h>
  3 #include<math.h>
  4 #include<time.h>
  5 
  6 typedef struct{//数据维度
  7     double x;
  8     double y;
  9 }data_struct;
 10 
 11 typedef struct kd_node{
 12     data_struct split_data;//数据结点
 13     int split;//分裂维
 14     struct kd_node *left;//由位于该结点分割超面左子空间内所有数据点构成的kd-tree
 15     struct kd_node *right;//由位于该结点分割超面右子空间内所有数据点构成的kd-tree
 16 }kd_struct;
 17 
 18 //用于排序
 19 int cmp1( const void *a , const void *b )
 20 {
 21     return (*(data_struct *)a).x > (*(data_struct *)b).x ? 1:-1;
 22 }
 23 //用于排序
 24 int cmp2( const void *a , const void *b )
 25 {
 26     return (*(data_struct *)a).y > (*(data_struct *)b).y ? 1:-1;
 27 }
 28 //计算分裂维和分裂结点
 29 void choose_split(data_struct data_set[],int size,int dimension,int *split,data_struct *split_data)
 30 {
 31     int i;
 32     data_struct *data_temp;
 33     data_temp=(data_struct *)malloc(size*sizeof(data_struct));
 34     for(i=0;i<size;i++)
 35         data_temp[i]=data_set[i];
 36     static int count=0;//设为静态
 37     *split=(count++)%dimension;//分裂维
 38     if((*split)==0) qsort(data_temp,size,sizeof(data_temp[0]),cmp1);
 39     else qsort(data_temp,size,sizeof(data_temp[0]),cmp2);
 40     *split_data=data_temp[(size-1)/2];//分裂结点排在中位
 41 }
 42 //判断两个数据点是否相等
 43 int equal(data_struct a,data_struct b){
 44     if(a.x==b.x && a.y==b.y)    return 1;
 45     else    return 0;
 46 }
 47 //建立KD树
 48 kd_struct *build_kdtree(data_struct data_set[],int size,int dimension,kd_struct *T)
 49 {
 50     if(size==0) return NULL;//递归出口
 51     else{
 52         int sizeleft=0,sizeright=0;
 53         int i,split;
 54         data_struct split_data;
 55         choose_split(data_set,size,dimension,&split,&split_data);
 56         data_struct data_right[size];
 57         data_struct data_left[size];
 58 
 59         if (split==0){//x维
 60             for(i=0;i<size;++i){
 61                 if(!equal(data_set[i],split_data) && data_set[i].x <= split_data.x){//比分裂结点小
 62                     data_left[sizeleft].x=data_set[i].x;
 63                     data_left[sizeleft].y=data_set[i].y;
 64                     sizeleft++;//位于分裂结点的左子空间的结点数
 65                 }
 66                 else if(!equal(data_set[i],split_data) && data_set[i].x > split_data.x){//比分裂结点大
 67                     data_right[sizeright].x=data_set[i].x;
 68                     data_right[sizeright].y=data_set[i].y;
 69                     sizeright++;//位于分裂结点的右子空间的结点数
 70                 }
 71             }
 72         }
 73         else{//y维
 74             for(i=0;i<size;++i){
 75                 if(!equal(data_set[i],split_data) && data_set[i].y <= split_data.y){
 76                     data_left[sizeleft].x=data_set[i].x;
 77                     data_left[sizeleft].y=data_set[i].y;
 78                     sizeleft++;
 79                 }
 80                 else if (!equal(data_set[i],split_data) && data_set[i].y > split_data.y){
 81                     data_right[sizeright].x = data_set[i].x;
 82                     data_right[sizeright].y = data_set[i].y;
 83                     sizeright++;
 84                 }
 85             }
 86         }
 87         T=(kd_struct *)malloc(sizeof(kd_struct));
 88         T->split_data.x=split_data.x;
 89         T->split_data.y=split_data.y;
 90         T->split=split;
 91         T->left=build_kdtree(data_left,sizeleft,dimension,T->left);//左子空间
 92         T->right=build_kdtree(data_right,sizeright,dimension,T->right);//右子空间
 93         return T;//返回指针
 94     }
 95 }
 96 //计算欧氏距离
 97 double compute_distance(data_struct a,data_struct b){
 98     double tmp=pow(a.x-b.x,2.0)+pow(a.y-b.y,2.0);
 99     return sqrt(tmp);
100 }
101 //搜索1近邻
102 void search_nearest(kd_struct *T,int size,data_struct test,data_struct *nearest_point,double *distance)
103 {
104     int path_size;//搜索路径内的指针数目
105     kd_struct *search_path[size];//搜索路径保存各结点的指针
106     kd_struct* psearch=T;
107     data_struct nearest;//最近邻的结点
108     double dist;//查询结点与最近邻结点的距离
109     search_path[0]=psearch;//初始化搜索路径
110     path_size=1;
111     while(psearch->left!=NULL || psearch->right!=NULL){
112         if (psearch->split==0){
113             if(test.x <= psearch->split_data.x)//如果小于就进入左子树
114                 psearch=psearch->left;
115             else
116                 psearch=psearch->right;
117         }
118         else{
119             if(test.y <= psearch->split_data.y)//如果小于就进入右子树
120                 psearch=psearch->left;
121             else
122                 psearch=psearch->right;
123         }
124         search_path[path_size++]=psearch;//将经过的分裂结点保存在搜索路径中
125     }
126     //取出search_path最后一个元素,即叶子结点赋给nearest
127     nearest.x=search_path[path_size-1]->split_data.x;
128     nearest.y=search_path[path_size-1]->split_data.y;
129     path_size--;//search_path的指针数减一
130     dist=compute_distance(nearest,test);//计算与该叶子结点的距离作为初始距离
131 
132     //回溯搜索路径
133     kd_struct* pback;
134     while(path_size!=0){
135         pback=search_path[path_size-1];//取出search_path最后一个结点赋给pback
136         path_size--;//search_path的指针数减一
137 
138         if(pback->left==NULL && pback->right==NULL){//如果pback为叶子结点
139             if(dist>compute_distance(pback->split_data,test)){
140                 nearest=pback->split_data;
141                 dist=compute_distance(pback->split_data,test);
142             }
143         }
144         else{//如果pback为分裂结点
145             int s=pback->split;
146             if(s==0){//x维
147                 if(fabs(pback->split_data.x-test.x)<dist){//若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
148                     if(dist>compute_distance(pback->split_data,test)){
149                         nearest=pback->split_data;
150                         dist=compute_distance(pback->split_data, test);
151                     }
152                     if(test.x<=pback->split_data.x)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
153                         psearch=pback->right;
154                     else
155                         psearch=pback->left;//若以查询点位于pback的右子空间,那么就要跳到左子空间去搜索
156                     if(psearch!=NULL)
157                         search_path[path_size++]=psearch;//psearch加入到search_path中
158                 }
159             }
160             else {//y维
161                 if(fabs(pback->split_data.y-test.y)<dist){//若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
162                     if(dist>compute_distance(pback->split_data,test)){
163                         nearest=pback->split_data;
164                         dist=compute_distance(pback->split_data,test);
165                     }
166                     if(test.y<=pback->split_data.y)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
167                         psearch=pback->right;
168                     else
169                         psearch=pback->left;//若查询点位于pback的的右子空间,那么就要跳到左子空间去搜索
170                     if(psearch!=NULL)
171                         search_path[path_size++]=psearch;//psearch加入到search_path中
172                 }
173             }
174         }
175     }
176 
177     (*nearest_point).x=nearest.x;//最近邻
178     (*nearest_point).y=nearest.y;
179     *distance=dist;//距离
180 }
181 
182 int main()
183 {
184     int n=6;//数据个数
185     data_struct nearest_point;
186     double distance;
187     kd_struct *root=NULL;
188     data_struct data_set[6]={{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};//数据集
189     data_struct test={7.1,2.1};//查询点
190     root=build_kdtree(data_set,n,2,root);
191 
192     search_nearest(root,n,test,&nearest_point,&distance);
193     printf("nearest neighbor:(%.2f,%.2f)\ndistance:%.2f \n",nearest_point.x,nearest_point.y,distance);
194     return 0;
195 }
196 /*                    x          5,4
197                                 /    \
198                       y       2,3    7.2
199                                 \    /  \
200                       x        4,7  8.1 9.6
201 */

 

参考:

  https://www.joinquant.com/post/2627?f=study&m=math

  https://www.joinquant.com/post/2843?f=study&m=math

  http://blog.csdn.net/zhl30041839/article/details/9277807

 

标签:

版权申明:本站文章部分自网络,如有侵权,请联系:west999com@outlook.com
特别注意:本站所有转载文章言论不代表本站观点,本站所提供的摄影照片,插画,设计作品,如需使用,请与原作者联系,版权归原作者所有

上一篇:[算法]——快速排序(Quick Sort)

下一篇:C语言计算2个数的最小公倍数