kd-Tree Optimization for the k-Nearest Neighbors Algorithm (Building and Searching a kd-Tree), in Python


Preface

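For the principles behind the kd-tree, see my earlier blog post on the kd-tree-optimized k-nearest neighbors algorithm.
Reference article: wenffe: python实现KD树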

1. Building the kd-tree

            
              
import numpy as np


class Node(object):
    """
    A node of the kd-tree:
    val: the instance point stored in the node
    label: the class of the instance stored in the node
    dim: the splitting dimension of the node
    left: the node's left subtree
    right: the node's right subtree
    parent: the node's parent
    """
    def __init__(self, val=None, label=None, dim=None, left=None, right=None, parent=None):
        self.val = val
        self.label = label
        self.dim = dim
        self.left = left
        self.right = right
        self.parent = parent
	

              
class kdTree(object):
    """
    The tree class:
    dataNum: number of samples in the training set
    root: root node of the constructed kd-tree
    """
    def __init__(self, dataSet, labelList):
        self.dataNum = 0
        self.root = self.buildKdTree(dataSet, labelList)

    # Note how the parent node is passed down through the recursion.
    def buildKdTree(self, dataSet, labelList, parentNode=None):
        data = np.array(dataSet)
        if data.size == 0:  # an empty training set yields no node
            return None
        dataNum, dimNum = data.shape  # number of samples, dimensionality of one sample
        label = np.array(labelList).reshape(dataNum, 1)
        varList = self.getVar(data)  # variance of each dimension
        mid = dataNum // 2  # position of the median in sorted order
        maxVarDimIndex = varList.index(max(varList))  # dimension with the largest variance
        sortedDataIndex = data[:, maxVarDimIndex].argsort()  # sort along that dimension
        midDataIndex = sortedDataIndex[mid]  # the middle sample along that dimension becomes the root
        if dataNum == 1:  # with a single sample, return it directly as a leaf
            self.dataNum = dataNum
            return Node(val=data[midDataIndex], label=label[midDataIndex],
                        dim=maxVarDimIndex, left=None, right=None, parent=parentNode)
        root = Node(data[midDataIndex], label[midDataIndex], maxVarDimIndex, parent=parentNode)
        # Split the remaining samples into left and right subsets, then recurse.
        leftDataSet = data[sortedDataIndex[:mid]]  # slice with mid (a position), not midDataIndex (a row index)
        leftLabel = label[sortedDataIndex[:mid]]
        rightDataSet = data[sortedDataIndex[mid + 1:]]
        rightLabel = label[sortedDataIndex[mid + 1:]]
        root.left = self.buildKdTree(leftDataSet, leftLabel, parentNode=root)
        root.right = self.buildKdTree(rightDataSet, rightLabel, parentNode=root)
        self.dataNum = dataNum  # record the number of training samples
        return root

    def getRoot(self):
        # Accessor for the root node. A method named root would be hidden by
        # the self.root attribute assigned in __init__.
        return self.root

    def getVar(self, data):  # variance of each column
        rowLen, colLen = data.shape
        varList = []
        for i in range(colLen):
            varList.append(np.var(data[:, i]))
        return varList
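A single split step from the construction above can be sketched in isolation: pick the dimension with the largest variance, sort along it, and take the sample at position n // 2 as the pivot. The six 2-D points are made-up sample data, and `choose_split` is a hypothetical helper written only for this illustration, not part of the class.

```python
import numpy as np

def choose_split(points):
    """Return (split_dim, pivot_index, left_indices, right_indices) for one split."""
    data = np.asarray(points, dtype=float)
    variances = data.var(axis=0)             # per-dimension variance
    split_dim = int(np.argmax(variances))    # dimension with the largest variance
    order = data[:, split_dim].argsort()     # row indices sorted along that dimension
    mid = len(data) // 2                     # position of the pivot in sorted order
    return split_dim, int(order[mid]), order[:mid].tolist(), order[mid + 1:].tolist()

# Made-up sample data for illustration.
points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
split_dim, pivot, left, right = choose_split(points)
print(split_dim, points[pivot], left, right)
```

Here the first coordinate has the larger variance, so the split is on dimension 0 and the pivot is (7, 2). The left and right index lists are what the recursion would receive as its two subsets.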

            
          

2. Converting the kd-tree to a list and a dict

2.1 Converting to a list

    # Every element of the list is a dict whose keys are the node's value,
    # dimension, label, left and right children, and parent.
    # Each dict represents one node.
    def transferTreeToList(self, root, rootList=None):
        if rootList is None:  # start a fresh list on each top-level call
            rootList = []
        if root is None:
            return None
        tempDict = {}
        tempDict["data"] = root.val
        tempDict["left"] = root.left.val if root.left else None
        tempDict["right"] = root.right.val if root.right else None
        tempDict["parent"] = root.parent.val if root.parent else None
        tempDict["label"] = root.label[0]
        tempDict["dim"] = root.dim
        rootList.append(tempDict)
        self.transferTreeToList(root.left, rootList)
        self.transferTreeToList(root.right, rootList)
        return rootList
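A traversal that accumulates into a list argument, as `transferTreeToList` does, is sensitive to how Python handles default arguments: a default of `[]` is evaluated once at function definition and then shared by every call that omits the argument. The two functions below are hypothetical stand-ins written only to show the difference; they are not part of the kd-tree class.

```python
def collect_shared(value, bucket=[]):
    # The default list is created once, when the function is defined,
    # so every call that omits bucket appends to the same object.
    bucket.append(value)
    return bucket

def collect_fresh(value, bucket=None):
    # Using None as the default and creating the list inside the call
    # gives each top-level call its own list.
    if bucket is None:
        bucket = []
    bucket.append(value)
    return bucket

first_shared = list(collect_shared("a"))   # copy, to record the state after call 1
second_shared = list(collect_shared("b"))  # the earlier "a" is still there
first_fresh = collect_fresh("a")
second_fresh = collect_fresh("b")

print(second_shared)  # ['a', 'b']
print(second_fresh)   # ['b']
```

For the tree conversion this matters when the method is called more than once on the same or different trees: with a shared default, nodes from earlier calls would reappear in later results.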

            
          

2.2 Converting to a dict

            
              
    def transferTreeToDict(self, root):
        if root is None:
            return None
        # Note: dict keys must be immutable (hashable), so arrays and lists
        # cannot be used as keys; a tuple is used instead.
        key = tuple(root.val)
        treeDict = {}  # named treeDict so the built-in dict is not shadowed
        treeDict[key] = {}
        treeDict[key]["label"] = root.label[0]  # root.label is a NumPy array; index it to get the value
        treeDict[key]["dim"] = root.dim
        treeDict[key]["parent"] = root.parent.val if root.parent else None
        treeDict[key]["left"] = self.transferTreeToDict(root.left)
        treeDict[key]["right"] = self.transferTreeToDict(root.right)
        return treeDict
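The reason for wrapping `root.val` in `tuple(...)` can be checked directly: a NumPy array is not hashable, so using it as a dict key raises `TypeError`, while the tuple built from it works. The point value and the names `tree_dict` and `array_key_ok` below are made up for this small sketch.

```python
import numpy as np

point = np.array([7.0, 2.0])  # made-up node value
tree_dict = {}

try:
    tree_dict[point] = {"dim": 0}   # ndarray is unhashable
    array_key_ok = True
except TypeError:
    array_key_ok = False

tree_dict[tuple(point)] = {"dim": 0}  # a tuple of the coordinates is hashable

print(array_key_ok)
print((7.0, 2.0) in tree_dict)
```

Looking a node up later therefore needs the same conversion, for example `tree_dict[tuple(x)]`.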
              
            
          

3. Searching the kd-tree

3.1 Finding the leaf node that contains the target point x

            
              
    def findtheNearestLeafNode(self, root, x):
        if root is None:  # alternatively, check whether self.dataNum == 0
            return None
        if root.left is None and root.right is None:
            return root
        node = root
        while True:  # descend until a leaf, or a node that lacks the needed child
            curDim = node.dim
            if x[curDim] < node.val[curDim]:
                if not node.left:
                    return node
                node = node.left
            else:
                if not node.right:
                    return node
                node = node.right
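The descent above returns the node whose region contains x, which is only a starting candidate: it is not guaranteed to be the nearest neighbor, and that is why the search has to backtrack afterwards. A minimal standalone sketch, assuming a nested-dict tree instead of the Node class (`build` and `descend` are hypothetical helpers, and the points are made-up sample data):

```python
import numpy as np

def build(points):
    """Build a nested-dict kd-tree, splitting on the dimension with the largest variance."""
    data = np.asarray(points, dtype=float)
    if len(data) == 0:
        return None
    dim = int(np.argmax(data.var(axis=0)))
    order = data[:, dim].argsort()
    mid = len(data) // 2
    return {
        "val": data[order[mid]],
        "dim": dim,
        "left": build(data[order[:mid]]),
        "right": build(data[order[mid + 1:]]),
    }

def descend(node, x):
    """Walk down from node, comparing x with each node on that node's split dimension."""
    while True:
        go_left = x[node["dim"]] < node["val"][node["dim"]]
        child = node["left"] if go_left else node["right"]
        if child is None:  # leaf, or the needed child is missing
            return node
        node = child

points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = build(points)
x = np.array([2.0, 4.5])

leaf = descend(tree, x)
# Brute-force nearest neighbor, for comparison only.
nearest = min(points, key=lambda p: np.linalg.norm(np.array(p) - x))
print(leaf["val"], nearest)
```

For this query the descent ends at (4, 7), while the true nearest point is (2, 3), which sits in the sibling subtree.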

            
          

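3.2 Searching for the k nearest neighbors

    # This searches for the k nearest neighbors. The only difference from the
    # nearest-neighbor search is that an array holds the current k best candidates,
    # and the admission test uses the k-th smallest distance (the "gatekeeper" of
    # the result) instead of the smallest one: a node enters the result array only
    # if the array holds fewer than k nodes, or if the node's distance to the input
    # instance is smaller than the k-th smallest distance.
    def knnSearch(self, x, k):
        """
        When the whole training set has at most k samples, every sample is a neighbor.
        Count the labels in a dict and decide by majority vote.
        """
        if self.dataNum <= k:
            labelDict = {}
            for element in self.transferTreeToList(self.root):
                if element["label"] not in labelDict:
                    labelDict[element["label"]] = 0
                labelDict[element["label"]] += 1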
              
              
            sortedLabelList 
              
                =
              
              
                sorted
              
              
                (
              
              labelDict
              
                .
              
              items
              
                (
              
              
                )
              
              
                ,
              
               key
              
                =
              
              
                lambda
              
               item
              
                :
              
              item
              
                [
              
              
                1
              
              
                ]
              
              
                ,
              
              reverse
              
                =
              
              
                True
              
              
                )
              
              
                # 对字典排序返回的是由元祖组成的一个列表。
              
              
                return
              
               sortedLabelList
              
                [
              
              
                0
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
                """ 先找到最近的叶子节点,然后递归的向上寻找 """
              
              
        node 
              
                =
              
               self
              
                .
              
              findtheNearestLeafNode
              
                (
              
              self
              
                .
              
              root
              
                ,
              
              x
              
                )
              
              
        nodeList 
              
                =
              
              
                [
              
              
                ]
              
              
                if
              
               node 
              
                ==
              
              
                None
              
              
                :
              
              
                # 如果是空树,直接返回None
              
              
                return
              
              
                None
              
              
        x 
              
                =
              
               np
              
                .
              
              array
              
                (
              
              x
              
                )
              
              
        distance 
              
                =
              
               np
              
                .
              
              sqrt
              
                (
              
              
                sum
              
              
                (
              
              
                (
              
              x
              
                -
              
              node
              
                .
              
              val
              
                )
              
              
                **
              
              
                2
              
              
                )
              
              
                )
              
              
                # 计算最近叶子节点和输入实例的距离
              
              
        nodeList
              
                .
              
              append
              
                (
              
              
                [
              
              distance
              
                ,
              
              
                tuple
              
              
                (
              
              node
              
                .
              
              val
              
                )
              
              
                ,
              
               node
              
                .
              
              label
              
                [
              
              
                0
              
              
                ]
              
              
                ]
              
              
                )
              
              
                # 将距离,节点实例和类别作为一个数组加入结果中。
              
              
                while
              
              
                True
              
              
                :
              
              
                # 循环
              
              
                if
              
               node 
              
                ==
              
               self
              
                .
              
              root
              
                :
              
              
                # 当循环到根节点时,停止循环
              
              
                break
              
              
            parentNode 
              
                =
              
                node
              
                .
              
              parent 
              
                # 找到当前节点的父节点
              
              
            parentDis 
              
                =
              
               np
              
                .
              
              sqrt
              
                (
              
              
                sum
              
              
                (
              
              
                (
              
              x
              
                -
              
              parentNode
              
                .
              
              val
              
                )
              
              
                **
              
              
                2
              
              
                )
              
              
                )
              
              
                # 计算输入实例x和父节点的距离
              
              
                if
              
               k 
              
                >
              
              
                len
              
              
                (
              
              nodeList
              
                )
              
              
                or
              
               distance 
              
                >
              
               parentDis
              
                :
              
              
                # 如果当前的结果中不足K个节点或与父节点的距离小于当前列表中距离x最大的距离,
              
              
                nodeList
              
                .
              
              append
              
                (
              
              
                [
              
              parentDis
              
                ,
              
              
                tuple
              
              
                (
              
              parentNode
              
                .
              
              val
              
                )
              
              
                ,
              
              parentNode
              
                .
              
              label
              
                [
              
              
                0
              
              
                ]
              
              
                ]
              
              
                )
              
              
                # 压入结果列表
              
              
                nodeList
              
                .
              
              sort
              
                (
              
              
                )
              
              
                # 排序
              
              
                distance 
              
                =
              
               nodeList
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
                if
              
               k 
              
                >
              
              
                len
              
              
                (
              
              nodeList
              
                )
              
              
                else
              
               nodeList
              
                [
              
              k
              
                -
              
              
                1
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
                # 更新dis为入队节点中第K小的距离或者直接就是距离最大的距离
              
              
                if
              
               k 
              
                >
              
              
                len
              
              
                (
              
              nodeList
              
                )
              
              
                or
              
              
                abs
              
              
                (
              
              x
              
                [
              
              parentNode
              
                .
              
              dim
              
                ]
              
              
                -
              
               parentNode
              
                .
              
              val
              
                [
              
              parentNode
              
                .
              
              dim
              
                ]
              
              
                )
              
              
                <
              
               distance
              
                :
              
              
                # 判断另一子节点区域有没有距离更近的节点
              
              
                if
              
               x
              
                [
              
              parentNode
              
                .
              
              dim
              
                ]
              
              
                <
              
               parentNode
              
                .
              
              val
              
                [
              
              parentNode
              
                .
              
              dim
              
                ]
              
              
                :
              
              
                    otherChild 
              
                =
              
               parentNode
              
                .
              
              right 
                    
              
                # 如果x当前维度的值小于父节点的值
              
              
                # 说明x在父节点的左子树上,往右节点寻找
              
              
                    self
              
                .
              
              search
              
                (
              
              nodeList
              
                ,
              
              otherChild
              
                ,
              
              x
              
                ,
              
              k
              
                )
              
              
                # 递归的进行近邻点的寻找
              
              
                else
              
              
                :
              
              
                # 否则,往左子节点寻找
              
              
                    otherChild 
              
                =
              
               parentNode
              
                .
              
              left
                    self
              
                .
              
              search
              
                (
              
              nodeList
              
                ,
              
               otherChild
              
                ,
              
               x
              
                ,
              
               k
              
                )
              
              
            node 
              
                =
              
               node
              
                .
              
              parent
        labelDict 
              
                =
              
              
                {
              
              
                }
              
              
                # 统计类别,并判断实例点的类别
              
              
        nodeList 
              
                =
              
               nodeList
              
                [
              
              
                :
              
              k
              
                ]
              
              
                if
              
               k 
              
                <=
              
              
                len
              
              
                (
              
              nodeList
              
                )
              
              
                else
              
               nodeList
        
              
                for
              
               element 
              
                in
              
               nodeList
              
                :
              
              
                if
              
               element
              
                [
              
              
                2
              
              
                ]
              
              
                not
              
              
                in
              
               labelDict
              
                :
              
              
                labelDict
              
                [
              
              element
              
                [
              
              
                2
              
              
                ]
              
              
                ]
              
              
                =
              
              
                0
              
              
            labelDict
              
                [
              
              element
              
                [
              
              
                2
              
              
                ]
              
              
                ]
              
              
                +=
              
              
                1
              
              
        sortedLabel 
              
                =
              
              
                sorted
              
              
                (
              
              labelDict
              
                .
              
              items
              
                (
              
              
                )
              
              
                ,
              
              key
              
                =
              
              
                lambda
              
               x
              
                :
              
              x
              
                [
              
              
                1
              
              
                ]
              
              
                ,
              
              reverse
              
                =
              
              
                True
              
              
                )
              
              
                return
              
               sortedLabel
              
                [
              
              
                0
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
    def search(self, nodeList, root, x, k):
        # Recursive k-nearest-neighbor search over the subtree rooted at `root`.
        # Almost identical to knnSearch, minus the label counting and voting.
        if root is None:
            return nodeList
        nodeList.sort()
        # Gatekeeper: k-th smallest distance, or the largest while fewer than k are held
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        x = np.array(x)
        node = self.findtheNearestLeafNode(root, x)
        distance = np.sqrt(sum((x - node.val) ** 2))
        if k > len(nodeList) or distance < dis:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        while True:
            if node == root:
                break
            parentNode = node.parent
            parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < dis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
                # Compare on the parent's split dimension to pick the sibling subtree.
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right
                    self.search(nodeList, otherChild, x, k)
                else:
                    otherChild = parentNode.left
                    self.search(nodeList, otherChild, x, k)
            node = node.parent
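The "gatekeeper" rule used by knnSearch and search (a candidate is admitted only while fewer than k results are held, or when its distance is below the current k-th smallest) can also be maintained with a bounded max-heap instead of re-sorting the list after every insert. The sketch below is only an illustration under stated assumptions, not part of the original class: `offer` and `majority_label` are hypothetical helpers, the labels are made up, and the loop scans every point by brute force, so it shows the bookkeeping and the vote but none of the kd-tree pruning.

```python
import heapq
from collections import Counter

import numpy as np


def offer(heap, k, dist, point, label):
    """Keep at most k candidates; heap[0] holds the current k-th smallest distance."""
    item = (-dist, point, label)          # negate so heapq's min-heap acts as a max-heap
    if len(heap) < k:
        heapq.heappush(heap, item)
    elif dist < -heap[0][0]:              # closer than the gatekeeper: replace it
        heapq.heapreplace(heap, item)


def majority_label(heap):
    """Majority vote over the labels of the kept candidates."""
    return Counter(label for _, _, label in heap).most_common(1)[0][0]


points = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]
labels = [0, 1, 0, 1, 1, 1]               # made-up labels for illustration
x = np.array([3, 4.5])

heap = []
for p, lab in zip(points, labels):
    d = float(np.sqrt(np.sum((x - np.array(p)) ** 2)))
    offer(heap, 3, d, p, lab)

neighbors = sorted((-nd, p) for nd, p, _ in heap)
print([p for _, p in neighbors])  # [(2, 3), (5, 4), (4, 7)]
print(majority_label(heap))       # 1
```

A brute-force scan like this is also a convenient reference for checking the neighbors returned by a kd-tree search on small data sets.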

            
          

4. Example

            
              
if __name__ == "__main__":
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(dataArray, label)
    Tree = kd.buildKdTree(dataArray, label)  # Tree is the root node
    # Named treeList / treeDict so the built-in list and dict are not shadowed.
    treeList = kd.transferTreeToList(Tree, [])
    treeDict = kd.transferTreeToDict(Tree)
    node = kd.findtheNearestLeafNode(Tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(treeList)
    print(result)
    """
    Output:
    [{'data': array([7, 2]), 'left': array([5, 4]), 'right': array([9, 6]), 'parent': None, 'label': 0, 'dim': 0},
     {'data': array([5, 4]), 'left': array([2, 3]), 'right': array([4, 7]), 'parent': array([7, 2]), 'label': 1, 'dim': 1},
     {'data': array([2, 3]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 0, 'dim': 0},
     {'data': array([4, 7]), 'left': None, 'right': None, 'parent': array([5, 4]), 'label': 1, 'dim': 0},
     {'data': array([9, 6]), 'left': array([8, 1]), 'right': None, 'parent': array([7, 2]), 'label': 1, 'dim': 1},
     {'data': array([8, 1]), 'left': None, 'right': None, 'parent': array([9, 6]), 'label': 1, 'dim': 0}]
    """
    # Predicted class: 1
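
A linear scan over the same six points is a handy cross-check for the kd-tree result. The sketch below is an illustration only: `brute_force_knn` is a hypothetical helper, not part of the kdTree class, and majority voting over the k nearest points is assumed as the classification rule.

```python
import numpy as np
from collections import Counter

points = np.array([[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]])
labels = np.array([0, 1, 0, 1, 1, 1])
query = np.array([6, 3])


def brute_force_knn(points, labels, query, k):
    """Classify query by majority vote over its k nearest points (linear scan)."""
    sq_dists = np.sum((points - query) ** 2, axis=1)
    # A stable sort keeps the original row order among equal distances.
    nearest = np.argsort(sq_dists, kind="stable")[:k]
    votes = Counter(labels[nearest].tolist())
    return votes.most_common(1)[0][0], sq_dists


pred, sq_dists = brute_force_knn(points, labels, query, k=3)
print(sq_dists.tolist())  # [2, 2, 16, 20, 18, 8]
print(pred)               # 1
```

The squared distances show that [7, 2] (label 0) and [5, 4] (label 1) are equally close to the query [6, 3], so for k = 1 the predicted class depends on how ties are broken. This scan would return 0 for k = 1, while a tree traversal that meets [5, 4] first can return 1. With k = 3 the vote is unambiguous.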
              
            
          

5. Complete code
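
Before reading the full listing, the split rule used by buildKdTree (pick the dimension with the largest variance, then take the median point along it) can be tried on its own. This is a minimal sketch under that reading of the code. `choose_split` is a hypothetical name, and `np.argmax` stands in for the listing's `varList.index(max(varList))`; both return the first maximum.

```python
import numpy as np

data = np.array([[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]])


def choose_split(data):
    """Return (split dimension, row index of the median point on it)."""
    variances = np.var(data, axis=0)   # one variance per column
    dim = int(np.argmax(variances))    # first index of the largest variance
    order = data[:, dim].argsort()     # row indices sorted along that dimension
    mid = len(data) // 2               # position of the median in sorted order
    return dim, int(order[mid])


dim, idx = choose_split(data)
print(dim, data[idx].tolist())  # 0 [7, 2]
```

The result matches the root node in the example output above: the point [7, 2], split on dimension 0. Note that `mid` is a position in the sorted order, whereas `order[mid]` is a row index into the unsorted data, which is the distinction the listing's comment about mid and midDataIndex refers to.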

            
              
```python

              
                import
              
               numpy 
              
                as
              
               np

              
                class
              
              
                Node
              
              
                (
              
              
                object
              
              
                )
              
              
                :
              
              
                def
              
              
                __init__
              
              
                (
              
              self
              
                ,
              
              val
              
                =
              
              
                None
              
              
                ,
              
              label
              
                =
              
              
                None
              
              
                ,
              
              dim
              
                =
              
              
                None
              
              
                ,
              
              left
              
                =
              
              
                None
              
              
                ,
              
              right
              
                =
              
              
                None
              
              
                ,
              
              parent
              
                =
              
              
                None
              
              
                )
              
              
                :
              
              
        self
              
                .
              
              val 
              
                =
              
               val
        self
              
                .
              
              label 
              
                =
              
               label
        self
              
                .
              
              dim 
              
                =
              
               dim
        self
              
                .
              
              left 
              
                =
              
               left
        self
              
                .
              
              right 
              
                =
              
               right
        self
              
                .
              
              parent 
              
                =
              
               parent

              
                class
              
              
                kdTree
              
              
                (
              
              
                object
              
              
                )
              
              
                :
              
              
                def
              
              
                __init__
              
              
                (
              
              self
              
                ,
              
              dataSet
              
                ,
              
              labelList
              
                )
              
              
                :
              
              
        self
              
                .
              
              dataNum 
              
                =
              
              
                0
              
              
        self
              
                .
              
              root 
              
                =
              
               self
              
                .
              
              buildKdTree
              
                (
              
              dataSet
              
                ,
              
              labelList
              
                )
              
              
                ## 注意父节点的传值。
              
              
                def
              
              
                buildKdTree
              
              
                (
              
              self
              
                ,
              
              dataSet
              
                ,
              
               labelList
              
                ,
              
               parentNode
              
                =
              
              
                None
              
              
                )
              
              
                :
              
              
        data 
              
                =
              
               np
              
                .
              
              array
              
                (
              
              dataSet
              
                )
              
              
        dataNum
              
                ,
              
               dimNum 
              
                =
              
               data
              
                .
              
              shape
        label 
              
                =
              
               np
              
                .
              
              array
              
                (
              
              labelList
              
                )
              
              
                .
              
              reshape
              
                (
              
              dataNum
              
                ,
              
              
                1
              
              
                )
              
              
                if
              
               dataNum 
              
                ==
              
              
                0
              
              
                :
              
              
                return
              
              
                None
              
              
        varList 
              
                =
              
               self
              
                .
              
              getVar
              
                (
              
              data
              
                )
              
              
        mid 
              
                =
              
               dataNum 
              
                //
              
              
                2
              
              
        maxVarDimIndex 
              
                =
              
               varList
              
                .
              
              index
              
                (
              
              
                max
              
              
                (
              
              varList
              
                )
              
              
                )
              
              
        sortedDataIndex 
              
                =
              
               data
              
                [
              
              
                :
              
              
                ,
              
              maxVarDimIndex
              
                ]
              
              
                .
              
              argsort
              
                (
              
              
                )
              
              
        midDataIndex 
              
                =
              
               sortedDataIndex
              
                [
              
              mid
              
                ]
              
              
                if
              
               dataNum 
              
                ==
              
              
                1
              
              
                :
              
              
            self
              
                .
              
              dataNum 
              
                =
              
               dataNum
            
              
                return
              
               Node
              
                (
              
              val 
              
                =
              
               data
              
                [
              
              midDataIndex
              
                ]
              
              
                ,
              
              label 
              
                =
              
               label
              
                [
              
              midDataIndex
              
                ]
              
              
                ,
              
              dim 
              
                =
              
               maxVarDimIndex
              
                ,
              
              left 
              
                =
              
              
                None
              
              
                ,
              
              right 
              
                =
              
              
                None
              
              
                ,
              
              parent 
              
                =
              
               parentNode
              
                )
              
              
        root 
              
                =
              
               Node
              
                (
              
              data
              
                [
              
              midDataIndex
              
                ]
              
              
                ,
              
              label
              
                [
              
              midDataIndex
              
                ]
              
              
                ,
              
              maxVarDimIndex
              
                ,
              
              parent 
              
                =
              
               parentNode
              
                ,
              
              
                )
              
              
        leftDataSet 
              
                =
              
                data
              
                [
              
              sortedDataIndex
              
                [
              
              
                :
              
              mid
              
                ]
              
              
                ]
              
              
                ##### 注意是mid不是midDataIndex
              
              
        leftLabel 
              
                =
              
               label
              
                [
              
              sortedDataIndex
              
                [
              
              
                :
              
              mid
              
                ]
              
              
                ]
              
              
        rightDataSet 
              
                =
              
               data
              
                [
              
              sortedDataIndex
              
                [
              
              mid
              
                +
              
              
                1
              
              
                :
              
              
                ]
              
              
                ]
              
              
        rightLabel 
              
                =
              
               label
              
                [
              
              sortedDataIndex
              
                [
              
              mid
              
                +
              
              
                1
              
              
                :
              
              
                ]
              
              
                ]
              
              
        root
              
                .
              
              left 
              
                =
              
               self
              
                .
              
              buildKdTree
              
                (
              
              leftDataSet
              
                ,
              
              leftLabel
              
                ,
              
              parentNode 
              
                =
              
               root
              
                )
              
              
        root
              
                .
              
              right 
              
                =
              
               self
              
                .
              
              buildKdTree
              
                (
              
              rightDataSet
              
                ,
              
               rightLabel
              
                ,
              
              parentNode 
              
                =
              
               root
              
                )
              
              
        self
              
                .
              
              dataNum 
              
                =
              
               dataNum
        
              
                return
              
               root

    
              
                def
              
              
                root
              
              
                (
              
              self
              
                )
              
              
                :
              
              
                return
              
               self
              
                .
              
              root

    
              
                def
              
              
                transferTreeToDict
              
              
                (
              
              self
              
                ,
              
              root
              
                )
              
              
                :
              
              
                if
              
               root 
              
                ==
              
              
                None
              
              
                :
              
              
                return
              
              
                None
              
              
                """ 字典的键必须是不可变的 """
              
              
                dict
              
              
                =
              
              
                {
              
              
                }
              
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                =
              
              
                {
              
              
                }
              
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                [
              
              
                "label"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              label
              
                [
              
              
                0
              
              
                ]
              
              
                # root.label是一个数组,要想返回值的话用下标即可。
              
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                [
              
              
                "dim"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              dim
        
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                [
              
              
                "parent"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              parent
              
                .
              
              val 
              
                if
              
               root
              
                .
              
              parent 
              
                else
              
              
                None
              
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                [
              
              
                "left"
              
              
                ]
              
              
                =
              
               self
              
                .
              
              transferTreeToDict
              
                (
              
              root
              
                .
              
              left
              
                )
              
              
                dict
              
              
                [
              
              
                tuple
              
              
                (
              
              root
              
                .
              
              val
              
                )
              
              
                ]
              
              
                [
              
              
                "right"
              
              
                ]
              
              
                =
              
               self
              
                .
              
              transferTreeToDict
              
                (
              
              root
              
                .
              
              right
              
                )
              
              
                return
              
              
                dict
              
              
                def
              
              
                transferTreeToList
              
              
                (
              
              self
              
                ,
              
              root
              
                ,
              
              rootList 
              
                =
              
              
                [
              
              
                ]
              
              
                )
              
              
                :
              
              
                if
              
               root 
              
                ==
              
              
                None
              
              
                :
              
              
                return
              
              
                None
              
              
        tempDict 
              
                =
              
              
                {
              
              
                }
              
              
        tempDict
              
                [
              
              
                "data"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              val
        tempDict
              
                [
              
              
                "left"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              left
              
                .
              
              val 
              
                if
              
               root
              
                .
              
              left 
              
                else
              
              
                None
              
              
        tempDict
              
                [
              
              
                "right"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              right
              
                .
              
              val 
              
                if
              
               root
              
                .
              
              right 
              
                else
              
              
                None
              
              
        tempDict
              
                [
              
              
                "parent"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              parent
              
                .
              
              val 
              
                if
              
               root
              
                .
              
              parent 
              
                else
              
              
                None
              
              
        tempDict
              
                [
              
              
                "label"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              label
              
                [
              
              
                0
              
              
                ]
              
              
        tempDict
              
                [
              
              
                "dim"
              
              
                ]
              
              
                =
              
               root
              
                .
              
              dim
        rootList
              
                .
              
              append
              
                (
              
              tempDict
              
                )
              
              
        self
              
                .
              
              transferTreeToList
              
                (
              
              root
              
                .
              
              left
              
                ,
              
              rootList
              
                )
              
              
        self
              
                .
              
              transferTreeToList
              
                (
              
              root
              
                .
              
              right
              
                ,
              
              rootList
              
                )
              
              
                return
              
               rootList

    
              
    def getVar(self, data):
        rowLen, colLen = data.shape
        varList = []
        for i in range(colLen):
            varList.append(np.var(data[:, i]))
        return varList

    
              
    def findtheNearestLeafNode(self, root, x):
        # Alternatively, checking whether self.dataNum equals 0 works as well.
        if root is None:
            return None
        if root.left is None and root.right is None:
            return root
        node = root
        # Descend along the splitting dimensions until the child on x's side is missing.
        while True:
            curDim = node.dim
            if x[curDim] < node.val[curDim]:
                if not node.left:
                    return node
                node = node.left
            else:
                if not node.right:
                    return node
                node = node.right

    
              
    def knnSearch(self, x, k):
        # With at most k training points, every point is a neighbor:
        # return the majority label of the whole tree.
        if self.dataNum <= k:
            labelDict = {}
            for element in self.transferTreeToList(self.root):
                if element["label"] not in labelDict:
                    labelDict[element["label"]] = 0
                labelDict[element["label"]] += 1
            # Sorting a dict's items returns a list of tuples.
            sortedLabelList = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
            return sortedLabelList[0][0]

        # Descend to a leaf near x and use it as the first candidate.
        node = self.findtheNearestLeafNode(self.root, x)
        nodeList = []
        if node is None:
            return None
        x = np.array(x)
        distance = np.sqrt(sum((x - node.val) ** 2))
        nodeList.append([distance, tuple(node.val), node.label[0]])

        # Backtrack from the leaf towards the root.
        while True:
            if node == self.root:
                break
            parentNode = node.parent
            parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
            # Keep the parent if fewer than k candidates are held,
            # or if it is closer than the current threshold distance.
            if k > len(nodeList) or distance > parentDis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                # Threshold: farthest candidate while fewer than k are held, else the k-th nearest.
                distance = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            # Search the other subtree if the splitting plane is closer than the threshold.
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < distance:
                if x[parentNode.dim] < parentNode.val[parentNode.dim]:
                    otherChild = parentNode.right
                    self.search(nodeList, otherChild, x, k)
                else:
                    otherChild = parentNode.left
                    self.search(nodeList, otherChild, x, k)
            node = node.parent

        # Majority vote among the (at most) k nearest candidates.
        labelDict = {}
        nodeList = nodeList[:k] if k <= len(nodeList) else nodeList
        for element in nodeList:
            if element[2] not in labelDict:
                labelDict[element[2]] = 0
            labelDict[element[2]] += 1
        sortedLabel = sorted(labelDict.items(), key=lambda item: item[1], reverse=True)
        return sortedLabel[0][0]

              
              
    def search(self, nodeList, root, x, k):
        if root is None:
            return nodeList
        nodeList.sort()
        # Threshold: farthest candidate while fewer than k are held, else the k-th nearest.
        dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        x = np.array(x)
        # Descend to a leaf of this subtree and test it as a candidate.
        node = self.findtheNearestLeafNode(root, x)
        distance = np.sqrt(sum((x - node.val) ** 2))
        if k > len(nodeList) or distance < dis:
            nodeList.append([distance, tuple(node.val), node.label[0]])
            nodeList.sort()
            dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
        # Backtrack from the leaf up to the root of this subtree.
        while True:
            if node == root:
                break
            parentNode = node.parent
            parentDis = np.sqrt(sum((x - parentNode.val) ** 2))
            if k > len(nodeList) or parentDis < dis:
                nodeList.append([parentDis, tuple(parentNode.val), parentNode.label[0]])
                nodeList.sort()
                dis = nodeList[-1][0] if k > len(nodeList) else nodeList[k - 1][0]
            if k > len(nodeList) or abs(x[parentNode.dim] - parentNode.val[parentNode.dim]) < dis:
              
              
                if
              
               x
              
                [
              
              parentNode
              
                .
              
              dim
              
                ]
              
              
                <
              
               parentNode
              
                .
              
              val
              
                [
              
              parentNode
              
                .
              
              val
              
                ]
              
              
                :
              
              
                    otherChild 
              
                =
              
               parentNode
              
                .
              
              right
                    self
              
                .
              
              search
              
                (
              
              nodeList
              
                ,
              
              otherChild
              
                ,
              
              x
              
                ,
              
              k
              
                )
              
              
                else
              
              
                :
              
              
                    otherChild 
              
                =
              
               parentNode
              
                .
              
              left
                    self
              
                .
              
              search
              
                (
              
              nodeList
              
                ,
              
               otherChild
              
                ,
              
               x
              
                ,
              
               k
              
                )
              
              
            node 
              
                =
              
               node
              
                .
              
              parent


              
if __name__ == "__main__":
    dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
    label = [[0], [1], [0], [1], [1], [1]]
    kd = kdTree(dataArray, label)
    Tree = kd.buildKdTree(dataArray, label)  # Tree is the root node
    # Named treeList / treeDict so the built-in list and dict are not shadowed.
    treeList = kd.transferTreeToList(Tree, [])
    treeDict = kd.transferTreeToDict(Tree)
    node = kd.findtheNearestLeafNode(Tree, [6, 3])
    result = kd.knnSearch([6, 3], 1)
    print(treeList)
    print(treeDict)
    print(result)
    print(node.val)
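
A quick way to sanity-check the kd-tree search is to compare it against an exhaustive scan over the same data. The following is only a minimal sketch: `brute_force_knn` is a hypothetical helper that does not appear in the original code, it assumes Euclidean distance, and it returns entries shaped like the `[distance, point, label]` lists appended to `nodeList` above.

```python
import numpy as np


def brute_force_knn(data, labels, x, k):
    """Return the k nearest [distance, point, label] entries by exhaustive search."""
    x = np.asarray(x, dtype=float)
    entries = []
    for point, lab in zip(data, labels):
        # Euclidean distance is an assumption made for this sketch.
        dist = float(np.linalg.norm(np.asarray(point, dtype=float) - x))
        entries.append([dist, tuple(point), lab[0]])
    # Lists sort by distance first; ties fall back to comparing the point tuples.
    entries.sort()
    return entries[:k]


dataArray = [[7, 2], [5, 4], [2, 3], [4, 7], [9, 6], [8, 1]]
label = [[0], [1], [0], [1], [1], [1]]
nearest = brute_force_knn(dataArray, label, [6, 3], 2)
print(nearest)
```

For the query point `[6, 3]`, the points `(5, 4)` and `(7, 2)` are both at distance √2, so a search with `k = 1` may legitimately return either one depending on how ties are broken. Comparing distances rather than points is the safer check.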
              
            
          
            
            
          
