前言
关于kd树的实现原理,我在之前的一篇博客(kd树优化的k近邻算法)中已经介绍过,这里不再赘述。
参考文章:wenffe:python实现KD树
1. kd树的构造
import
numpy
as
np
class Node(object):
    """A single kd-tree node.

    Attributes:
        val: the sample point stored in this node.
        label: the class label of that sample.
        dim: the dimension this node splits on.
        left: root of the left subtree (None when absent).
        right: root of the right subtree (None when absent).
        parent: the parent node (None for the tree root).
    """

    def __init__(self, val=None, label=None, dim=None,
                 left=None, right=None, parent=None):
        # Payload of the node.
        self.val, self.label, self.dim = val, label, dim
        # Links to the neighbouring nodes.
        self.left, self.right, self.parent = left, right, parent
class kdTree(object):
    """kd-tree built from a labelled training set.

    Attributes:
        dataNum: number of samples handed to the outermost buildKdTree call
            (stays 0 when the training set is empty).
        root: root node of the constructed kd-tree, or None if it is empty.
    """

    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        """Recursively build a kd-tree and return its root node.

        Args:
            dataSet: 2-D array-like, one sample per row.
            labelList: one label per sample; reshaped to (number of samples, 1).
            parentNode: node to record as the parent of the returned root.
                Pay attention to this value: every recursive call passes the
                node it has just created.

        Returns:
            The root Node of the (sub)tree, or None when dataSet is empty.
        """
        data = np.array(dataSet)
        # The emptiness check must come first: np.array([]) has shape (0,),
        # which cannot be unpacked into (rows, columns).
        if data.size == 0:
            return None
        dataNum, _ = data.shape
        label = np.array(labelList).reshape(dataNum, 1)

        # Split on the dimension with the largest variance.
        varList = self.getVar(data)
        maxVarDimIndex = varList.index(max(varList))

        # Order the samples along that dimension; the one in the middle
        # position becomes the root of this subtree.
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        mid = dataNum // 2
        midDataIndex = sortedDataIndex[mid]
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex,
                    parent=parentNode)

        # Slice by the position `mid` in the sorted order, not by
        # midDataIndex (which is a row number of the unsorted data).
        lowerIndex = sortedDataIndex[:mid]
        upperIndex = sortedDataIndex[mid + 1:]
        root.left = self.buildKdTree(data[lowerIndex], label[lowerIndex],
                                     parentNode=root)
        root.right = self.buildKdTree(data[upperIndex], label[upperIndex],
                                      parentNode=root)

        # Recursive calls overwrite this with subtree sizes; the outermost
        # call finishes last, so the final value is the full sample count.
        self.dataNum = dataNum
        return root

    # NOTE(review): instances assign the attribute `self.root` in __init__,
    # which shadows this method on every instance; it is only reachable as
    # kdTree.root. Kept so the class attribute does not disappear.
    def root(self):
        return self.root

    def getVar(self, data):
        """Return a list with the variance of every column of `data`."""
        _, colLen = data.shape
        return [np.var(data[:, i]) for i in range(colLen)]
2. kd树转换成list和dict
2.1 转换成list
""" list中的每一个元素都是字典,字典的键分别是: 节点的值、节点的维度、节点的类别、节点的左右子树以及节点的父节点。 每一个字典,都表示一个节点。 """
def transferTreeToList(self, root, rootList=None):
    """Flatten the subtree under `root` into a list of dicts (pre-order).

    Each dict describes one node with the keys "data", "left", "right",
    "parent", "label" and "dim"; "left", "right" and "parent" hold the
    neighbouring node's value, or None when that neighbour is absent.

    Args:
        root: root node of the subtree to convert.
        rootList: list to append to. A fresh list is created when omitted,
            so results of separate calls never share state.

    Returns:
        The list that was filled, or None when `root` is None.
    """
    if root is None:
        return None
    if rootList is None:
        rootList = []

    # Explicit stack instead of recursion; the right child is pushed first
    # so the left subtree is emitted first (node, left, right).
    pending = [root]
    while pending:
        node = pending.pop()
        rootList.append({
            "data": node.val,
            "left": node.left.val if node.left else None,
            "right": node.right.val if node.right else None,
            "parent": node.parent.val if node.parent else None,
            "label": node.label[0],
            "dim": node.dim,
        })
        if node.right:
            pending.append(node.right)
        if node.left:
            pending.append(node.left)
    return rootList
2.2 转换成字典
def transferTreeToDict(self, root):
    """Convert the subtree under `root` into nested dictionaries.

    The result has a single key, the node value as a tuple (dictionary keys
    must be immutable, so an array or list cannot be used). Its value holds
    "label", "dim", "parent", and the recursively converted "left" and
    "right" subtrees.

    Returns:
        The nested dict, or None when `root` is None.
    """
    if root is None:
        return None

    key = tuple(root.val)
    treeDict = {key: {}}
    entry = treeDict[key]
    # root.label is an array-like; index 0 extracts the plain label value.
    entry["label"] = root.label[0]
    entry["dim"] = root.dim
    if root.parent:
        entry["parent"] = root.parent.val
    else:
        entry["parent"] = None
    entry["left"] = self.transferTreeToDict(root.left)
    entry["right"] = self.transferTreeToDict(root.right)
    return treeDict
3. kd树搜索
3.1 搜索包含目标点的x的叶节点
def findtheNearestLeafNode(self, root, x):
    """Descend from `root` towards the region that contains point `x`.

    At every node the coordinate of `x` in the node's split dimension is
    compared with the node's value: smaller goes left, otherwise right.
    The walk stops at the first node whose wanted child is missing, so the
    result is either a leaf or a node that lacks that one subtree.

    Returns:
        The node where the descent stopped, or None when `root` is None.
    """
    if root is None:
        return None

    current = root
    while True:
        axis = current.dim
        goesLeft = x[axis] < current.val[axis]
        nextNode = current.left if goesLeft else current.right
        if not nextNode:
            # Nothing further in the wanted direction (covers leaves too).
            return current
        current = nextNode
3.2 搜索k个近邻点
""" 这里搜索了k个近邻点,和最近邻算法的唯一不同是,需要一个数组保存,当前的前k个近邻点, 而且判定条件,不是最近距离了,而是第K小的距离(结果的守门员), 只有当结果中的节点数不超过K或节点与输入实例的距离小于第K小的距离时才能进入结果数组 """
def knnSearch(self, x, k):
    """Classify point `x` by majority vote among its k nearest neighbours.

    When the training set holds no more than k samples, every sample is a
    neighbour and the vote is taken over the whole tree. Otherwise the tree
    is searched depth-first: the side of each split that contains `x` is
    visited first, and the opposite side only when the splitting plane is
    closer than the current k-th smallest distance. Unlike a climb that
    starts at the node where the descent stopped, this also examines the
    remaining child of a node that has only one subtree.

    Args:
        x: query point (array-like, same dimensionality as the samples).
        k: number of neighbours to vote with; must be at least 1.

    Returns:
        The winning label (on a tie, the label encountered first), or None
        when the tree is empty.

    Raises:
        ValueError: if k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    if self.root is None:
        return None

    if self.dataNum <= k:
        # Pass a fresh list explicitly so repeated calls can never share
        # an accumulator.
        labels = [element["label"]
                  for element in self.transferTreeToList(self.root, [])]
    else:
        query = np.array(x)
        # Entries are [distance, point as tuple, label], kept sorted.
        nodeList = []

        def kthDistance():
            # The "gatekeeper": k-th smallest distance found so far, or
            # infinity while fewer than k candidates are known.
            if len(nodeList) < k:
                return float("inf")
            return nodeList[k - 1][0]

        def visit(node):
            if node is None:
                return
            gap = query[node.dim] - node.val[node.dim]
            if gap < 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            visit(near)
            distance = np.sqrt(sum((query - node.val) ** 2))
            if distance < kthDistance():
                nodeList.append([distance, tuple(node.val), node.label[0]])
                nodeList.sort()
            # Every point on the far side is at least |gap| away.
            if abs(gap) < kthDistance():
                visit(far)

        visit(self.root)
        labels = [element[2] for element in nodeList[:k]]

    # Majority vote; max() returns the first label with the highest count.
    labelDict = {}
    for label in labels:
        labelDict[label] = labelDict.get(label, 0) + 1
    return max(labelDict, key=labelDict.get)
def search(self, nodeList, root, x, k):
    """Collect k-nearest-neighbour candidates from the subtree at `root`.

    `nodeList` is updated in place. Its entries are
    [distance, point as tuple, label] and it is kept sorted by distance.
    A node is added when fewer than k candidates are known or when it is
    closer than the current k-th smallest distance; the list is not
    truncated, so callers take the first k entries.

    The subtree is walked depth-first: the side of each split containing
    `x` first, the opposite side only when the splitting plane is closer
    than the current k-th smallest distance.

    Args:
        nodeList: candidate list to extend (may be empty).
        root: root node of the subtree to search; None is a no-op.
        x: query point (array-like).
        k: number of neighbours wanted; values below 1 add nothing.

    Returns:
        `nodeList` itself.
    """
    if root is None or k < 1:
        return nodeList

    query = np.array(x)
    nodeList.sort()

    def kthDistance():
        # k-th smallest distance so far, or infinity while the list holds
        # fewer than k candidates.
        if len(nodeList) < k:
            return float("inf")
        return nodeList[k - 1][0]

    def visit(node):
        if node is None:
            return
        # Signed offset of the query from the splitting plane. The plane is
        # indexed by node.dim (indexing val with val itself was a bug).
        gap = query[node.dim] - node.val[node.dim]
        if gap < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        visit(near)
        distance = np.sqrt(sum((query - node.val) ** 2))
        if distance < kthDistance():
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
        # Every point on the far side is at least |gap| away.
        if abs(gap) < kthDistance():
            visit(far)

    visit(root)
    return nodeList
4. 举例
if __name__ == "__main__":
    # Six labelled 2-D sample points.
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]

    kd = kdTree(dataArray, label)
    # The constructor has already built the tree; reuse its root node
    # instead of building an identical second tree.
    Tree = kd.root

    # Names chosen so the builtins `list` and `dict` are not shadowed.
    treeList = kd.transferTreeToList(Tree, [])
    treeDict = kd.transferTreeToDict(Tree)
    node = kd.findtheNearestLeafNode(Tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)

    print(treeList)
    print(result)
    # Expected output:
    # [{'data': array([7, 2]), 'left': array([5, 4]), 'right': array([9, 6]),
    #   'parent': None, 'label': 0, 'dim': 0},
    #  {'data': array([5, 4]), 'left': array([2, 3]), 'right': array([4, 7]),
    #   'parent': array([7, 2]), 'label': 1, 'dim': 1},
    #  {'data': array([2, 3]), 'left': None, 'right': None,
    #   'parent': array([5, 4]), 'label': 0, 'dim': 0},
    #  {'data': array([4, 7]), 'left': None, 'right': None,
    #   'parent': array([5, 4]), 'label': 1, 'dim': 0},
    #  {'data': array([9, 6]), 'left': array([8, 1]), 'right': None,
    #   'parent': array([7, 2]), 'label': 1, 'dim': 1},
    #  {'data': array([8, 1]), 'left': None, 'right': None,
    #   'parent': array([9, 6]), 'label': 1, 'dim': 0}]
    # Predicted class: 1
5. 完整代码
```python
import
numpy
as
np
class Node(object):
    """Container for one kd-tree node: a sample, its label and its links."""

    def __init__(self, val=None, label=None, dim=None,
                 left=None, right=None, parent=None):
        # Links to the surrounding nodes (None when absent).
        self.parent = parent
        self.left = left
        self.right = right
        # Stored sample, its class label and the split dimension.
        self.dim = dim
        self.label = label
        self.val = val
class kdTree(object):
    """kd-tree over labelled samples with a k-nearest-neighbour classifier.

    Attributes:
        dataNum: number of samples handed to the outermost buildKdTree call
            (stays 0 when the training set is empty).
        root: root node of the constructed kd-tree, or None if it is empty.
    """

    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)

    def buildKdTree(self, dataSet, labelList, parentNode=None):
        """Recursively build a kd-tree and return its root node.

        Args:
            dataSet: 2-D array-like, one sample per row.
            labelList: one label per sample; reshaped to (number of samples, 1).
            parentNode: node to record as the parent of the returned root;
                every recursive call passes the node it has just created.

        Returns:
            The root Node of the (sub)tree, or None when dataSet is empty.
        """
        data = np.array(dataSet)
        # Check for emptiness before unpacking the shape: np.array([]) has
        # shape (0,), which cannot be unpacked into (rows, columns).
        if data.size == 0:
            return None
        dataNum, _ = data.shape
        label = np.array(labelList).reshape(dataNum, 1)

        # Split on the dimension with the largest variance.
        varList = self.getVar(data)
        maxVarDimIndex = varList.index(max(varList))

        # The sample in the middle position along that dimension becomes
        # the root of this subtree.
        sortedDataIndex = data[:, maxVarDimIndex].argsort()
        mid = dataNum // 2
        midDataIndex = sortedDataIndex[mid]
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex,
                    parent=parentNode)

        # Slice by the position `mid` in the sorted order, not by
        # midDataIndex (a row number of the unsorted data).
        lowerIndex = sortedDataIndex[:mid]
        upperIndex = sortedDataIndex[mid + 1:]
        root.left = self.buildKdTree(data[lowerIndex], label[lowerIndex],
                                     parentNode=root)
        root.right = self.buildKdTree(data[upperIndex], label[upperIndex],
                                      parentNode=root)

        # Recursive calls overwrite this with subtree sizes; the outermost
        # call finishes last, so the final value is the full sample count.
        self.dataNum = dataNum
        return root

    # NOTE(review): instances assign the attribute `self.root` in __init__,
    # which shadows this method on every instance; it is only reachable as
    # kdTree.root. Kept so the class attribute does not disappear.
    def root(self):
        return self.root

    def transferTreeToDict(self, root):
        """Convert the subtree under `root` into nested dictionaries.

        The result has one key, the node value as a tuple (dictionary keys
        must be immutable). Its value holds "label", "dim", "parent" and the
        recursively converted "left" and "right" subtrees.

        Returns:
            The nested dict, or None when `root` is None.
        """
        if root is None:
            return None
        return {
            tuple(root.val): {
                # root.label is an array; index 0 extracts the plain value.
                "label": root.label[0],
                "dim": root.dim,
                "parent": root.parent.val if root.parent else None,
                "left": self.transferTreeToDict(root.left),
                "right": self.transferTreeToDict(root.right),
            }
        }

    def transferTreeToList(self, root, rootList=None):
        """Flatten the subtree under `root` into a list of dicts (pre-order).

        Each dict has the keys "data", "left", "right", "parent", "label"
        and "dim"; "left", "right" and "parent" hold the neighbouring node's
        value, or None when that neighbour is absent.

        Args:
            root: root node of the subtree to convert.
            rootList: list to append to. A fresh list is created when
                omitted, so separate calls never share state.

        Returns:
            The list that was filled, or None when `root` is None.
        """
        if root is None:
            return None
        if rootList is None:
            rootList = []

        # Explicit stack; the right child is pushed first so the left
        # subtree is emitted first (node, left, right).
        pending = [root]
        while pending:
            node = pending.pop()
            rootList.append({
                "data": node.val,
                "left": node.left.val if node.left else None,
                "right": node.right.val if node.right else None,
                "parent": node.parent.val if node.parent else None,
                "label": node.label[0],
                "dim": node.dim,
            })
            if node.right:
                pending.append(node.right)
            if node.left:
                pending.append(node.left)
        return rootList

    def getVar(self, data):
        """Return a list with the variance of every column of `data`."""
        _, colLen = data.shape
        return [np.var(data[:, i]) for i in range(colLen)]

    def findtheNearestLeafNode(self, root, x):
        """Descend from `root` towards the region that contains point `x`.

        At each node the coordinate of `x` in the node's split dimension is
        compared with the node's value: smaller goes left, otherwise right.
        The walk stops at the first node whose wanted child is missing, so
        the result is a leaf or a node lacking that one subtree.

        Returns:
            The node where the descent stopped, or None when `root` is None.
        """
        if root is None:
            return None
        node = root
        while True:
            if x[node.dim] < node.val[node.dim]:
                child = node.left
            else:
                child = node.right
            if not child:
                return node
            node = child

    def knnSearch(self, x, k):
        """Classify point `x` by majority vote among its k nearest neighbours.

        When the training set holds no more than k samples, every sample is
        a neighbour and the vote is taken over the whole tree; otherwise the
        candidates come from `search`.

        Args:
            x: query point (array-like).
            k: number of neighbours to vote with; must be at least 1.

        Returns:
            The winning label (on a tie, the label encountered first), or
            None when the tree is empty.

        Raises:
            ValueError: if k is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer")
        if self.root is None:
            return None

        if self.dataNum <= k:
            labels = [element["label"]
                      for element in self.transferTreeToList(self.root)]
        else:
            nodeList = self.search([], self.root, x, k)
            labels = [element[2] for element in nodeList[:k]]

        # Majority vote; max() returns the first label with the top count.
        labelDict = {}
        for label in labels:
            labelDict[label] = labelDict.get(label, 0) + 1
        return max(labelDict, key=labelDict.get)

    def search(self, nodeList, root, x, k):
        """Collect k-nearest-neighbour candidates from the subtree at `root`.

        `nodeList` is updated in place. Its entries are
        [distance, point as tuple, label] and it is kept sorted by distance.
        A node is added when fewer than k candidates are known or when it is
        closer than the current k-th smallest distance; the list is not
        truncated, so callers take the first k entries.

        The subtree is walked depth-first: the side of each split containing
        `x` first, the opposite side only when the splitting plane is closer
        than the current k-th smallest distance.

        Args:
            nodeList: candidate list to extend (may be empty).
            root: root node of the subtree to search; None is a no-op.
            x: query point (array-like).
            k: number of neighbours wanted; values below 1 add nothing.

        Returns:
            `nodeList` itself.
        """
        if root is None or k < 1:
            return nodeList

        query = np.array(x)
        nodeList.sort()

        def kthDistance():
            # k-th smallest distance so far, or infinity while the list
            # holds fewer than k candidates.
            if len(nodeList) < k:
                return float("inf")
            return nodeList[k - 1][0]

        def visit(node):
            if node is None:
                return
            # Signed offset of the query from the splitting plane, indexed
            # by node.dim.
            gap = query[node.dim] - node.val[node.dim]
            if gap < 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            visit(near)
            distance = np.sqrt(sum((query - node.val) ** 2))
            if distance < kthDistance():
                nodeList.append([distance, tuple(node.val), node.label[0]])
                nodeList.sort()
            # Every point on the far side is at least |gap| away.
            if abs(gap) < kthDistance():
                visit(far)

        visit(root)
        return nodeList
if __name__ == "__main__":
    # Six labelled 2-D sample points.
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]

    kd = kdTree(dataArray, label)
    # The constructor has already built the tree; reuse its root node
    # instead of building an identical second tree.
    Tree = kd.root

    # Names chosen so the builtins `list` and `dict` are not shadowed.
    treeList = kd.transferTreeToList(Tree, [])
    treeDict = kd.transferTreeToDict(Tree)
    node = kd.findtheNearestLeafNode(Tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)

    print(treeList)
    print(treeDict)
    print(result)
    print(node.val)