引言
我们可能会有这样的一种需求,像是打车软件中呼叫附近的车来接送自己,或者是在qq中查看附近的人。我们都需要知道距离自己一定范围内的其它目标的集合。如果将上面举例的功能抽象出来,就是要实现以某个点为中心,以一定的距离为半径,在空间中查找其它点所构成的集合。诚然,当空间中点的数目较少时,我们可以采用遍历所有点的方式来计算出当前点与其它点之间的距离的方式来得到对应的结果集,但是空间中的点数目较多(假如达到千万级别),且存在多个点要计算出距离当前点一定范围内的点所构成的集合时,这个计算的时间复杂度便达到了O(/(mn/))级别,为此,我们需要改进该实现方式。
如何实现
我们先不考虑多维空间中的点,先考虑一维平面上的数,假如为1,92,8,11,18,91,7,47这几个数。如果我们想得到某个数一定距离范围内的那几个数,如数字11,想得到距离不超过为7的数(<=7)。那么我们可以将这些数按照从小到大的顺序进行排序,使其有序,然后找到对应的数字,以该数为中心,分别往左和往右遍历对应的点并计算距离,如果是在符合要求的范围内,则将其纳入结果集中。以上面例子为例,将数字进行排序,得到 1,7,8,11,18,47,91,92的数字序列。然后找到数11,分别往左计算各点的距离,得到符合条件的数7,8(11-1=10>7不符合条件)。往右计算各点的距离,得到符合条件的数18(47-11=36>7不符合条件)。为此,可以得到符合条件的结果集7,8,18。
除了将数进行排序这种方式,我们也可以采用二叉树这一数据结构,通过直接定位到节点7然后搜索查找以节点7为根的子树来缩小需要进行查找的数据范围。需要注意的是,如果需要得到一定距离范围内的数为新加入的数,那么可以采用先找到该数最临近的数的方式并以该数为中心点进行查找的方式,也可以采用将该数加入到集合中,并进行重新排序然后以该数位中心点的方式进行查找。
对于空间中的点,我们也可以采用类似的方式进行考虑,由于是多维的,我们就得考虑将其降维,使其成为一维,并使其在空间中相邻的点,在一维空间中也大体保持相邻。除了降维之外,我们也可考虑将空间根据某种方式划分为多个子空间,以减少需要进行搜索查找的数据范围。为此,我们有如下两种方式。
1.kd-tree数据结构方式
基本原理: 其基本原理是将空间按照各个维度依次对其进行分割成多个子空间,使得空间中的点集均匀的分布在按照对应维度分割的子空间中,从而达到在搜索目标点时减少在空间中搜索的数据范围的目的。
介绍: kd-tree是一种多维空间划分的的二叉树,其可以看成是将线段树扩充到多维的变体。二叉树上的每个节点都对应了一个k维的超矩形区域。对于将空间进行划分的维度的选择其可以有多种方式,一般就是按照某个维度中方差较大的那个维度或者轮流交替维度选择的方式。空间进行分割维度的选择方式会由于数据的分布情况不同而影响查找的效率。
kd-tree构建过程: 建树的过程为一个递归的过程,其非叶子节点不存储任何坐标值,只存储划分的维度以及对应的划分值,叶子节点保存对应的节点值,该方式便于后续的查找过程的实现。相关伪代码如下:
function buildTree(dataSet): if(len(dataSet) == 0): return none; //只剩下一个点坐标或者所有坐标值都相同,则其为叶子节点 if(len(dataSet) == 1 || getMaxVariance(dataSet) == 0): return buildNode(dataSet); //得到分隔的维度以及分隔维度所对应的值 maxVarianceSplitDimension <- getSplitDimension(dataSet); sort(dataSet,sortedBy maxVarianceSplitDimension); medianData <- len(dataSet)//2; splitValue <- dataSet[medianData].getValue(maxVarianceSplitDimension); //将点划分为两部分,一部分为小于划分值的,一部分为大于等于划分值的 for data in dataSet: if(data.getValue(maxVarianceSplitDimension) < medianData.getValue(maxVarianceSplitDimension)): leftDataSet.add(data); else: rightDataSet.add(data); root <- buildNode(splitValue,maxVarianceSplitDimension); left <- buildTree(leftDataSet); right <- buildTree(rightDataSet); root.left <- left; root.right <- right; left.parent <- root; right.parent <- root; return root;
kd-tree查找过程: kd-tree的查找过程是一个二分搜索加回溯的过程,每次二分搜索到叶子节点之后,会回溯回另一个未曾访问过的分支,去判断该分支下是否有更接近的点。相关伪代码如下:
// point为要查找临近点的坐标,m为要查找与点point最临近的点的数目 function search(kdTree,point,m): stack.push(kdTree.root); while !stack.isEmpty(): node <- stack.pop(); if(isLeaf(node)): distance <- getDistance(point,node); if(priorityQueue.size() < m): priorityQueue.push(saveNode(node,distance)); else if(priorityQueue.peek(sortedBy distance) > distance): priorityQueue.pop(sortedBy distance); priorityQueue.push(saveNode(node,distance)); continue; // 查找到的节点数目小于m个时,需要回溯该分支下的相关节点进行查找,如果节点数等于m个时,需要拿出当前查找到的点中距离最远的点进行判断,判断是否要进行回溯查找 currentMaxDistance <- priorityQueue.size() < m ? MAX_VALUE:priorityQueue.peek(sortedBy distance).distance; //当为根节点或者和当前节点的父节点判断发现存在超球面和超矩形相交时,进入对应的分支 if (node == root || Math.abs(point.getValue(node.parent.splitDimension) - node.parent.splitValue) < currentMaxDistance): while !isLeaf(node): if(point.getValue(node.splitDimension) < node.splitValue): stack.push(node.right); node <- node.left; else: stack.push(node.left); node <- node.right; stack.push(node); return priorityQueue;
实现:
package com.example.nearest; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import java.util.*; /** * @author 学徒 */ @Data public class KDTree { /** * 根节点 */ private Node root; public KDTree(List<double[]> dataSet) { if (Objects.isNull(dataSet) || dataSet.size() == 0) { return; } this.root = this.buildTree(dataSet); } public KDTree(double[][] dataSet) { if (Objects.isNull(dataSet) || dataSet.length == 0 || dataSet[0].length == 0) { return; } List<double[]> data = new ArrayList<>(dataSet.length); for (double[] d : dataSet) { data.add(d); } this.root = this.buildTree(data); } /** * 树节点 */ private class Node { /** * 分割维度 */ private int splitDimension; /** * 分割值 */ private double splitValue; /** * 点坐标值 */ private double[] value; private Node left; private Node right; private Node parent; public Node(int splitDimension, double splitValue, Node left, Node right) { this.splitDimension = splitDimension; this.splitValue = splitValue; this.left = left; this.right = right; } public Node(double[] value) { this.value = value; } public boolean isLeaf() { return Objects.isNull(left) && Objects.isNull(right); } } /** * 用于查找数据的辅助点 */ @Data @AllArgsConstructor @NoArgsConstructor private class SearchNode { /** * 中心点到目标点的距离 */ private double distance; /** * 目标点节点 */ private Node node; } /** * 用于根据数据集建立对应的kd树 * * @param dataSet 点构成的数据集 */ private Node buildTree(List<double[]> dataSet) { if (dataSet.size() == 1) { return new Node(dataSet.get(0)); } // 最大方差及其相关维度 int dimensions = dataSet.get(0).length; double maxVariance = Double.MIN_VALUE; int splitDimension = -1; for (int dimension = 0; dimension < dimensions; dimension++) { double tempVariance = Utils.getVariance(dataSet, dimension); if (maxVariance < tempVariance) { maxVariance = tempVariance; splitDimension = dimension; } } // 最大方差各个维度相同。且均为0,则表示各个点坐标相同 if (Objects.equals(maxVariance, 0)) { return new Node(dataSet.get(0)); } // 分割为左右两个子树 double splitValue = Utils.getMedian(dataSet, splitDimension); int size = (dataSet.size() + 1) / 2; List<double[]> left = new ArrayList<>(size); List<double[]> right = new ArrayList<>(size); for (double[] data : dataSet) { if (data[splitDimension] < splitValue) { left.add(data); } else { right.add(data); } } Node leftTree = buildTree(left); Node rightTree = buildTree(right); Node root = new Node(splitDimension, splitValue, leftTree, rightTree); leftTree.parent = root; rightTree.parent = root; return root; } /** * 用于查找该KDTree中距离指定点最近的m个点 * * @param point 需要进行查找的点 * @param number 进行查找的点的数量 * @return 距离最近的点的数量构成的列表 */ public List<double[]> search(double[] point, int number) { Queue<SearchNode> searchNodes = assistSearch(point, number); List<double[]> result = new ArrayList<>(searchNodes.size()); for (SearchNode node : searchNodes) { result.add(node.getNode().value); } return result; } /** * 用于得到某个点中的最接近点的坐标 * * @param point 查找点 * @return 最接近的点坐标 */ public double[] searchNearest(double[] point) { List<double[]> result = search(point, 1); if (result.size() == 1) { return result.get(0); } return null; } /** * 用于查找该KDTree中距离指定点最近的m个点 * * @param point 需要进行查找的点 * @param number 进行查找的点的数量 * @return 距离最近的点的数量构成的队列 */ private Queue<SearchNode> assistSearch(double[] point, int number) { Stack<Node> stack = new Stack<>(); Queue<SearchNode> result = new PriorityQueue<>(number, (o1, o2) -> Double.compare(o2.distance, o1.distance)); stack.push(this.root); while (!stack.isEmpty()) { Node node = stack.pop(); if (node.isLeaf()) { double distance = Utils.getDistance(point, node.value); if (result.size() < number) { result.add(new SearchNode(distance, node)); } else if (result.peek().distance > distance) { result.poll(); result.add(new SearchNode(distance, node)); } continue; } // 查找到的节点数目小于m个时,需要回溯该分支下的相关节点进行查找, // 如果节点数等于m个时,需要拿出当前查找到的点中距离最远的点进行判断, // 判断是否要进行回溯查找 double currentMaxDistance = result.size() < number ? Double.MAX_VALUE : result.peek().distance; //当为根节点或者和当前节点的父节点判断发现存在超球面和超矩形相交时,进入对应的分支 if (Objects.equals(node, root) || Math.abs(point[node.parent.splitDimension] - node.parent.splitValue) < currentMaxDistance) { while (!node.isLeaf()) { if (point[node.splitDimension] < node.splitValue) { stack.push(node.right); node = node.left; } else { stack.push(node.left); node = node.right; } } stack.push(node); } } return result; } } package com.example.nearest; import java.util.List; /** * @author 学徒 */ public class Utils { /** * 用于获取数据集中的某个数据维度的方差 * * @param data 数据集 * @param dimension 维度 */ public static double getVariance(List<double[]> data, int dimension) { double vsum = 0; double sum = 0; for (double[] d : data) { sum += d[dimension]; vsum += Math.pow(d[dimension], 2); } int n = data.size(); return vsum / n - Math.pow(sum / n, 2); } /** * 用于获取数据集中某个数据维度的中间值 * * @param data 数据集 * @param dimension 维度 * @return 对应维度的中间值 */ public static double getMedian(List<double[]> data, int dimension) { double[] numbers = new double[data.size()]; for (int i = 0; i < data.size(); i++) { numbers[i] = data.get(i)[dimension]; } return findPositionValue(numbers, 0, numbers.length - 1, numbers.length / 2); } /** * 三向切分的快排思想,用于得到对应位置的数据值 * * @param data 数据 * @param low 左边界 * @param high 右边界 * @param position 在边界内进行查找的位置 * @return 对应位置的值 */ private static double findPositionValue(double[] data, int low, int high, int position) { //第一个相等元素的下标,第一个未访问元素的下标,最后一个未访问元素的下标 int equalityFirstIndex = low, unvisitedFirstIndex = low + 1, unvisitedEndIndex = high; double number = data[low]; while (unvisitedFirstIndex <= unvisitedEndIndex) { if (data[unvisitedFirstIndex] < number) { exchange(data, unvisitedFirstIndex, equalityFirstIndex); unvisitedFirstIndex++; equalityFirstIndex++; } else if (data[unvisitedFirstIndex] == number) { unvisitedFirstIndex++; } else { exchange(data, unvisitedFirstIndex, unvisitedEndIndex); unvisitedEndIndex--; } } // low ~ equalityFirstIndex-1 为小于基准元素的部分 // equalityFirstIndex ~ unvisitedFirstIndex-1 为等于基准元素的部分 // unvisitedFirstIndex ~ high 为大于基准元素的部分 if (position <= unvisitedFirstIndex - 1 && position >= equalityFirstIndex) { return data[position]; } if (position <= equalityFirstIndex - 1 && position >= low) { return findPositionValue(data, low, equalityFirstIndex - 1, position); } return findPositionValue(data, unvisitedFirstIndex, high, position); } /** * 用于交换两个元素 * * @param data 数据集 * @param i 元素1下标 * @param j 元素2下标 */ private static void exchange(double[] data, int i, int j) { double temp = data[i]; data[i] = data[j]; data[j] = temp; } /** * 用于得到两个点之间的距离(此处计算欧式距离) * * @param point1 点1 * @param point2 点2 * @return 两点之间的距离 */ public static double getDistance(double[] point1, double[] point2) { if (point1.length != point2.length) { return Double.MAX_VALUE; } double sum = 0; for (int i = 0; i < point1.length; i++) { sum += Math.pow(point1[i] - point2[i], 2); } return Math.sqrt(sum); } }
测试代码:
以下为一点测试代码,用于供调试,测试使用
以下代码位于KDTree类中 /** * 打印树,测试时用 */ public void print() { printRec(root, 0); } private void printRec(KDTree.Node node, int lv) { if (!node.isLeaf()) { for (int i = 0; i < lv; i++) System.out.print("--"); System.out.println(node.splitDimension + ":" + node.splitValue); printRec(node.left, lv + 1); printRec(node.right, lv + 1); } else { for (int i = 0; i < lv; i++) System.out.print("--"); StringBuilder s = new StringBuilder(); s.append('('); for (int i = 0; i < node.value.length - 1; i++) { s.append(node.value[i]).append(','); } s.append(node.value[node.value.length - 1]).append(')'); System.out.println(s); } } 以下测试代码位于单独的测试类中 public static void correct(){ int count = 10000; while(count-->0){ int num = 1000; double[][] input = new double[num][2]; for(int i=0;i<num;i++){ input[i][0]=Math.random()*10; input[i][1]=Math.random()*10; } double[] query = new double[]{Math.random()*50,Math.random()*50}; KDTree tree=new KDTree(input); double[] result = tree.searchNearest(query); double[] result1 = nearest(query,input); if (result[0]!=result1[0]||result[1]!=result1[1]) { System.out.println(count); System.out.println("查找点:"+query[0]+"/t"+query[1]); System.out.println("tree找到的:"+result[0]+"/t"+result[1]); System.out.println("线性找到的"+result1[0]+"/t"+result1[1]); System.out.println("tree找到的距离:"+Utils.getDistance(result,query)); System.out.println("线性找到的距离:"+Utils.getDistance(result1,query)); tree.print(); break; } } } public static void performance(int iteration,int datasize){ int count = iteration; int num = datasize; double[][] input = new double[num][2]; for(int i=0;i<num;i++){ input[i][0]=Math.random()*num; input[i][1]=Math.random()*num; } KDTree tree=new KDTree(input); double[][] query = new double[iteration][2]; for(int i=0;i<iteration;i++){ query[i][0]= Math.random()*num*1.5; query[i][1]= Math.random()*num*1.5; } long start = System.nanoTime(); for(int i=0;i<iteration;i++){ double[] result = tree.searchNearest(query[i]); } long timekdtree = System.nanoTime()-start; start = System.nanoTime(); for(int i=0;i<iteration;i++){ double[] result = nearest(query[i],input); } long timelinear = System.nanoTime()-start; System.out.println("datasize:"+datasize+";iteration:"+iteration); System.out.println("kdtree:"+timekdtree); System.out.println("linear:"+timelinear); System.out.println("linear/kdtree:"+(timelinear*1.0/timekdtree)); }
分析: 构建一棵kd-tree的时间复杂度根据划分维度的选择方式不同,时间复杂度也不同。采用方差最大的那个维度作为空间划分的依据,其时间复杂度为O(nklogn),采用轮流交替维度作为空间划分依据的,其时间复杂度为O(nlogn)。建树的空间复杂度为O(n)。查找过程的时间复杂度为O(/(n^{1-/tfrac{1}{k}}/)+m),其中m为每次要进行查找的最近点个数,k为空间的维度。
适用情况: 该方式适用于点坐标变动较少甚至基本不变而且维度k不大的情况,其可以精确的获得所需的一定距离的点。对于每次空间中的点坐标发生变化的情况,其可以采用重新建立kd-tree数据结构的方式或者采用替罪羊方式实现的改进后的kd-tree来实现。对于维度k小于20时能够获得较高的效率,当空间维度较高时,其接近于线性扫描。假设数据集的维数为D,一般来说要求数据的规模N满足N>>/(2^D/),才能达到高效的搜索。
2.降维编码方式
在介绍降维编码方式之前,我们可以先看下这个视频让自己对如何降维多维空间有个直观的认识。
空间填充曲线
空间填充曲线从数学的角度上看,可以看成是一种把N维空间数据转换到1维连续空间上的映射函数。实际上,存储磁盘(常见的关系型数据库)是一维的存储设备,而空间数据大多都是多维数据,很少存在一维顺序。因此,为了使多维空间上邻近的元素映射也尽可能是一维直线上邻近的点,专家们提出了许多的映射方法。其中最常用的方法包括Z Ordering、Hilbert、Peano 曲线等等。
为此,对于需要查找某个多维空间中的临近点,我们可以先采用一定的方式对其进行降维,使其在空间中临近的点,在一维中也尽可能的临近。然后再在一维空间中查找临近值从而达到查找空间中临近点的目的。
对于空间填充曲线的实现方式,我们在此只考虑实现二维下的希尔伯特空间填充曲线。
实现:
package com.example.nearest; import com.google.common.collect.Lists; import lombok.Data; import java.util.*; /** * @author 学徒 */ public class HilbertCode { @Data private class HilbertNode { /** * 原始数据 */ private double[] value; /** * 编码值 */ private long codeValue; /** * 计算坐标点保留的精度 */ private final static int PRECISION = 10000; /** * 计算的数据维度 */ private final static int ALLOW_DIMENSION = 2; public HilbertNode(double[] value) { this.value = value; this.codeValue = this.encode(value); } /** * 用于将空间点坐标进行编码 * * @param value 空间点坐标 * @return 编码后的值 */ private long encode(double[] value) { if (Objects.isNull(value) || !Objects.equals(value.length, ALLOW_DIMENSION)) { throw new IllegalArgumentException("点参数存在错误"); } long x = new Double(value[0] * PRECISION).longValue(); long y = new Double(value[1] * PRECISION).longValue(); long n = 1 << 62; long rx, ry, s, d = 0; boolean rxb, ryb; for (s = n >> 1; s > 0; s >>= 1) { rxb = (x & s) > 0; ryb = (y & s) > 0; rx = rxb ? 1 : 0; ry = ryb ? 1 : 0; d += s * s * ((3 * rx) ^ ry); long[] xy = rot(s, x, y, rxb, ryb); x = xy[0]; y = xy[1]; } return d; } private long[] rot(long n, long x, long y, boolean rxb, boolean ryb) { if (ryb) { return new long[]{x, y}; } if (rxb) { x = n - 1 - x; y = n - 1 - y; } return new long[]{y, x}; } } @Data private class HilbertSearchNode { private HilbertNode hilbertNode; private double distance; public HilbertSearchNode(HilbertNode node, double[] point) { if (Objects.isNull(point) || !Objects.equals(point.length, HilbertNode.ALLOW_DIMENSION)) { throw new IllegalArgumentException("点参数存在错误"); } this.hilbertNode = node; this.distance = Utils.getDistance(point, node.value); } } /** * 维护希尔伯特编码序列 */ private TreeSet<HilbertNode> hilbertSequence; public HilbertCode() { // 得到按照编码值从小到大进行编码的序列 this.hilbertSequence = new TreeSet<>((o1, o2) -> { if (o1.codeValue > o2.codeValue) { return 1; } else if (o1.codeValue == o2.codeValue) { return 0; } return -1; }); } /** * 添加点坐标 * * @param value 点坐标 * @return 添加结果 */ public boolean add(double[] value) { return hilbertSequence.add(new HilbertNode(value)); } /** * 尽可能的获取当前点前后2m个元素,总共4m个,最少20个 * * @param point 要得到对应位置的点 * @param m 元素个数 * @return 结果列表 */ private List<HilbertNode> getNear2MNode(double[] point, int m) { SortedSet<HilbertNode> headNodes = hilbertSequence.headSet(new HilbertNode(point)); SortedSet<HilbertNode> tailNodes = hilbertSequence.tailSet(new HilbertNode(point)); List<HilbertNode> result = new ArrayList<>(m << 2); // 查找的点越多,精度越高 final int MINIMUM_POINTS = 100; int count = 4 * m < MINIMUM_POINTS ? MINIMUM_POINTS : 4 * m; while (result.size() < count) { if (headNodes.size() == 0 && tailNodes.size() == 0) { break; } if (headNodes.size() != 0) { HilbertNode node = headNodes.last(); headNodes.remove(node); result.add(node); } if (tailNodes.size() != 0) { HilbertNode node = tailNodes.first(); tailNodes.remove(node); result.add(node); } } // 结束之后,要将从维护队列中是删除的节点重新添加回去 hilbertSequence.addAll(result); return result; } /** * 查找某个点的临近点 * * @param point 查找点 * @param m 临近点的个数 * @return 临近点坐标集合 */ public List<double[]> search(double[] point, int m) { List<double[]> result = new ArrayList<>(m); Queue<HilbertSearchNode> queue = new PriorityQueue<>((o1, o2) -> { if (o2.getDistance() > o1.getDistance()) { return 1; } else if (o1.getDistance() == o2.getDistance()) { return 0; } return -1; }); List<HilbertNode> nearNode = getNear2MNode(point, m); for (HilbertNode node : nearNode) { queue.add(new HilbertSearchNode(node, point)); if (queue.size() > m) { queue.poll(); } } for (HilbertSearchNode node : queue) { result.add(node.getHilbertNode().getValue()); } return result; } /** * 用于得到某个点中的最接近点的坐标 * * @param point 查找点 * @return 最接近的点坐标 */ public double[] searchNearest(double[] point) { List<double[]> result = search(point, 1); if (result.size() == 1) { return result.get(0); } return null; } }
分析: 将空间点降维编码成一维数据,其时间复杂度为O(1)。维护有序数据结构,其时间复杂度为O(logn),查找时间复杂度为O(logn+m),其中m为查找的点的数目。
适用情况: 空间填充曲线适用于空间中点坐标经常变动甚至点数目经常变化的情况,但是其查找效果受填充曲线选择的影响,为此在查找一定距离的点的时候,可以考虑多查找几个点临近点,以获得更准确的结果。
总结
在生产中,我们一定要根据项目的实际情况来灵活的组合实现方式去高效的完成功能。当多维空间中点集数目不多且对该功能的调用频率较低时,由于没有建立和维护对应数据结构的成本,通过采用遍历所有点的方式并不一定比采用数据结构或者降维编码的方式实现的效率更低。为此我们可以根据实际的项目情况,在一定数据规模和功能使用频率作为指标来衡量采用哪种实现方式的方法,通过多种功能组合实现,可以达到更高效的使用情况。
像博主本人之前遇到过的一个问题,需要首先得到某个运行可能异常的机器人的最临近点,之后得到与最临近点之间的距离以此来断定机器人是否脱轨的功能。由于机器人脱轨执行的可能性很低,而且地图普遍较小且不轻易改动。为此,采用遍历判断的方式去得到最临近点效率会比维护各种数据结构等的实现更高效且开销更小。但为了兼容可能的大地图以及异常情况,为此设置了一个阈值,当满足情况的时候,会切换到采用维护KD-Tree的方式进行实现。