Using an RTree Index to Speed Up Range Queries in SparkSQL

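This post continues "Add Range Query on Spark DataSet". The previous post only added a basic Range operation. To make it faster, this post implements an RTree index on RDDs that accelerates both Range and kNN queries. The full code is on GitHub. With the index in place, usage looks like this: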

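import org.apache.spark.sql.SparkSession
import scala.util.Random

object Main {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .master("local[*]") // local mode for this example
      .getOrCreate()
    import sparkSession.implicits._

    // 3000 random points; each coordinate is an integer between -29 and 29
    val points = (0 until 3000).map { i =>
      PointData(Random.nextInt() % 30, Random.nextInt() % 30, Random.nextInt() % 30, "point: " + i)
    }
    val pointsList = points.toDS()
    // createIndex, range and knn are the Dataset extensions added in this project, not stock Spark APIs
    pointsList.createIndex("rtree", "RtreeForData", Array("x", "y"))
    // points inside the rectangle from (0, 0) to (10, 10) on columns x and y
    pointsList.range(Array("x", "y"), Array(0, 0), Array(10, 10)).show()
    // the 4 nearest neighbours of (1.0, 1.0)
    pointsList.knn(Array("x", "y"), Array(1.0, 1.0), 4).show(4)
    sparkSession.stop()
  }
}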


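During execution, an RDD is split into many partitions according to its Partitioner, and the partitions are processed in parallel across the cluster. We therefore implement the index in two parts: a Global Index and a Local Index. The Global Index quickly determines which partition(s) hold the data; the Local Index quickly locates the data inside a partition. This means the index needs its own Partitioner and its own partition type. The data definitions look roughly like this: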

case class IndexPartition(data: Array[InternalRow], index: Index)
type IndexRDD = RDD[IndexPartition]
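Both index levels boil down to a few rectangle predicates. As a minimal sketch, assuming 2-D points and axis-aligned rectangles, here are the checks involved; `MbrBasics`, `Pt`, `Rect` and `candidatePartitions` are hypothetical names, not the project's Point/MBR API, and a linear scan stands in for the global RTree.

```scala
object MbrBasics {
  // Hypothetical 2-D stand-ins for the project's Point and MBR classes.
  final case class Pt(x: Double, y: Double)

  final case class Rect(low: Pt, high: Pt) {
    // p lies inside the rectangle (boundaries included)
    def containsPoint(p: Pt): Boolean =
      low.x <= p.x && p.x <= high.x && low.y <= p.y && p.y <= high.y

    // other lies entirely inside this rectangle: checking both corners suffices
    def containsRect(other: Rect): Boolean =
      containsPoint(other.low) && containsPoint(other.high)

    // the two rectangles share at least one point
    def intersects(other: Rect): Boolean =
      low.x <= other.high.x && other.low.x <= high.x &&
        low.y <= other.high.y && other.low.y <= high.y
  }

  // Global step: ids of the partitions whose MBR intersects the query rectangle.
  // A linear scan stands in for the global RTree.
  def candidatePartitions(global: Seq[(Rect, Int)], query: Rect): Seq[Int] =
    global.collect { case (mbr, pid) if mbr.intersects(query) => pid }
}
```

Only the partitions returned by the global step need a local lookup; `containsRect` is the kind of test that tells you a partition is covered so completely that its local index can be skipped.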


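Spark ships with several Partitioners, such as HashPartitioner and RangePartitioner. To build an RTree index we need to partition the data so that each partition corresponds to one MBR and the partition MBRs do not overlap. So we implement our own Partitioner, STRPartitioner, based on the paper "STR: A Simple and Efficient Algorithm for R-Tree Packing". The idea is as follows: to speed up index construction, we sample the RDD, run recursiveGroupPoint on the sample to obtain an Array[MBR] that tiles the data space, and then assign every record to the partition whose MBR contains it. The core code is as follows: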

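import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.util.SizeEstimator
import scala.collection.mutable

// Point, MBR and RTree are this project's spatial classes
class STRPartitioner(est_partition: Int,
                     sample_rate: Double,
                     dimension: Int,
                     transfer_threshold: Long,
                     max_entries_per_node: Int,
                     rdd: RDD[_ <: Product2[Point, Any]])
  extends Partitioner {
  override def numPartitions: Int = partitions
  private case class Bounds(min: Array[Double], max: Array[Double])

  // mbrBound: each partition's MBR paired with its partition id; partitions: number of MBRs
  var (mbrBound, partitions) = {
    // bounding box of the full data set and its estimated size in bytes
    val (data_bounds, total_size) = rdd.aggregate[(Bounds, Long)]((null, 0L))(
      (acc, data) => {
        val c = data._1.coord
        val b =
          if (acc._1 == null) Bounds(c.clone(), c.clone())
          else Bounds(acc._1.min.zip(c).map(x => Math.min(x._1, x._2)),
                      acc._1.max.zip(c).map(x => Math.max(x._1, x._2)))
        (b, acc._2 + SizeEstimator.estimate(data._1))
      },
      (l, r) =>
        if (l._1 == null) r
        else if (r._1 == null) l
        else (Bounds(l._1.min.zip(r._1.min).map(x => Math.min(x._1, x._2)),
                     l._1.max.zip(r._1.max).map(x => Math.max(x._1, x._2))), l._2 + r._2))

    val seed = System.currentTimeMillis()
    // small data: use every point; otherwise work on a sample
    val sampled: Array[Point] = if (total_size <= 0.2 * transfer_threshold) {
      rdd.mapPartitions(part => part.map(_._1)).collect()
    } else {
      rdd.sample(withReplacement = false, sample_rate, seed).map(_._1).collect()
    }

    // number of slabs per dimension: ceil(remaining ^ (1 / dimensions left))
    val dim = new Array[Int](dimension)
    var remaining = est_partition.toDouble
    for (i <- 0 until dimension) {
      dim(i) = Math.ceil(Math.pow(remaining, 1.0 / (dimension - i))).toInt
      remaining /= dim(i)
    }

    // Sort on cur_dim, cut into dim(cur_dim) slabs and recurse on the next dimension;
    // at the last dimension every slab becomes one MBR. The outermost slabs are
    // stretched to data_bounds so that points missing from the sample are still covered.
    def recursiveGroupPoint(entries: Array[Point], now_min: Array[Double],
                            now_max: Array[Double], cur_dim: Int, until_dim: Int): Array[MBR] = {
      val len = entries.length
      val grouped = entries.sortWith(_.coord(cur_dim) < _.coord(cur_dim))
        .grouped(Math.ceil(len * 1.0 / dim(cur_dim)).toInt).toArray
      val ans = mutable.ArrayBuffer[MBR]()
      if (cur_dim < until_dim) {
        for (i <- grouped.indices) {
          // the arrays are shared on purpose: each level only overwrites its own cur_dim entry
          val cur_min = now_min
          val cur_max = now_max
          if (i == 0 && i == grouped.length - 1) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          ans ++= recursiveGroupPoint(grouped(i), cur_min, cur_max, cur_dim + 1, until_dim)
        }
      } else {
        for (i <- grouped.indices) {
          if (i == 0 && i == grouped.length - 1) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          // copy the bounds, since the arrays keep being overwritten
          ans += MBR(new Point(now_min.clone()), new Point(now_max.clone()))
        }
      }
      ans.toArray
    }

    val cur_min = new Array[Double](dimension)
    val cur_max = new Array[Double](dimension)
    val mbrs = recursiveGroupPoint(sampled, cur_min, cur_max, 0, dimension - 1)
    (mbrs.zipWithIndex, mbrs.length)
  }

  // RTree over the partition MBRs; the entry id is the partition id
  val rt = RTree(mbrBound.map(x => (x._1, x._2, 1)), max_entries_per_node)

  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[Point]
    // a zero-radius circle query returns the MBR(s) containing k; take the first
    rt.circleRange(k, 0.0).head._2
  }
}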
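The STR tiling idea can be tried without Spark. The sketch below is a simplified, single-machine take on what recursiveGroupPoint and getPartition do, under a few assumptions: points are `Vector[Double]`, the sample is non-empty, and a linear scan replaces the RTree lookup in getPartition. `StrSketch`, `slicesPerDim`, `tile` and `partitionOf` are illustrative names, not part of the project.

```scala
object StrSketch {
  type Point = Vector[Double]

  final case class Box(low: Vector[Double], high: Vector[Double]) {
    def contains(p: Point): Boolean = p.indices.forall(d => low(d) <= p(d) && p(d) <= high(d))
  }

  private def ipow(b: Int, e: Int): Long = (1 to e).foldLeft(1L)((acc, _) => acc * b)

  // Slabs per dimension: the smallest s with s^(dims left) >= partitions still to place,
  // i.e. ceil(remaining^(1/k)) computed with integers to avoid floating-point rounding.
  def slicesPerDim(numPartitions: Int, dims: Int): Vector[Int] = {
    var remaining = numPartitions
    (0 until dims).map { i =>
      var s = 1
      while (ipow(s, dims - i) < remaining) s += 1
      remaining = math.ceil(remaining.toDouble / s).toInt
      s
    }.toVector
  }

  // Sort the sample on dimension d, cut it into slices(d) slabs and recurse on d + 1.
  // A slab ends where the next one starts; the first and last slabs are stretched
  // to the global bounds lo/hi so that points missing from the sample are covered too.
  def tile(sample: Seq[Point], lo: Vector[Double], hi: Vector[Double],
           d: Int, slices: Vector[Int]): Seq[Box] = {
    require(sample.nonEmpty, "the sample must not be empty")
    val size = math.ceil(sample.size.toDouble / slices(d)).toInt
    val slabs = sample.sortWith(_(d) < _(d)).grouped(size).toVector
    slabs.indices.flatMap { i =>
      val from = if (i == 0) lo(d) else slabs(i).head(d)
      val to = if (i == slabs.size - 1) hi(d) else slabs(i + 1).head(d)
      val (newLo, newHi) = (lo.updated(d, from), hi.updated(d, to))
      if (d == slices.size - 1) Seq(Box(newLo, newHi))
      else tile(slabs(i), newLo, newHi, d + 1, slices)
    }
  }

  // getPartition stand-in: index of the first box containing p (linear scan instead of an RTree).
  // Neighbouring boxes share edges, so a point on an edge goes to the first match.
  def partitionOf(boxes: IndexedSeq[Box], p: Point): Int = {
    val i = boxes.indexWhere(_.contains(p))
    require(i >= 0, s"point $p is outside the data bounds")
    i
  }
}
```

For a 10 x 10 grid of points and 4 target partitions, `tile` produces four boxes split at x = 5 and y = 5. Because each cut is taken from the sorted sample, the boxes adapt to the data distribution instead of splitting space evenly.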

Implementing the RTree Index

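After the custom STRPartitioner repartitions the RDD, we build the index on the resulting RDD. The apply method of STRPartition returns the repartitioned RDD together with the MBR of each partition, paired with its partition id (Array[(MBR, Int)]). We first build the Global Index over Array[(MBR, Int)], then build a Local Index for each partition of the RDD. The core code is as follows: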

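def buildIndex(): Unit = {
    // index parameters come from the session's SQL configuration
    val numPartitions = sparksession.sessionState.conf.indexPsrtitions
    val maxEntriesPerNode = sparksession.sessionState.conf.maxEntriesPerNode
    val sampleRate = sparksession.sessionState.conf.sampleRate
    val transferThreshold = sparksession.sessionState.conf.transfer_Threshold
    // pair every row with the Point built from the indexed columns (col_keys)
    val tmpRDD = child.execute().map(row => {
      (IndexUtil.getPointFromRow(row, col_keys, child), row)
    })
    // repartition with STRPartitioner; mbrs: Array[(MBR, partition id)]
    val (partitionRDD, mbrs) = STRPartition(tmpRDD, dimension, numPartitions, sampleRate, transferThreshold, maxEntriesPerNode)

    // create a local RTree index for each partition
    val indexedrdd = partitionRDD.mapPartitions { iter =>
      val tmpdata = iter.toArray
      var index: RTree = null
      // each point is keyed by its offset in tmpdata, so range hits map back to rows
      if (tmpdata.length > 0) index = RTree(tmpdata.map(_._1).zipWithIndex, maxEntriesPerNode)
      Array(IndexPartition(tmpdata.map(_._2), index)).iterator
    }

    // number of rows in each partition
    val partitionSize = indexedrdd.mapPartitions(iter => iter.map(_.data.length)).collect()

    // create the global RTree index: (partition MBR, partition id, partition size)
    global_rtree = RTree(mbrs.zip(partitionSize)
      .map(x => (x._1._1, x._1._2, x._2)), maxEntriesPerNode)
    // table_name: Option[String] with the indexed table's name (assumed field name)
    indexedrdd.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    indexRDD_data = indexedrdd
}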
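The build step has the same shape on a single machine: route each point to its partition, build a local index per partition, and record one (MBR, partition id, size) entry per partition for the global index. A minimal sketch, assuming 2-D points and using an array sorted by x as a stand-in for the per-partition RTree; `BuildSketch`, `LocalIndex`, `IndexedPart`, `GlobalEntry` and `build` are hypothetical names.

```scala
object BuildSketch {
  final case class Pt(x: Double, y: Double)
  final case class Box(lx: Double, ly: Double, hx: Double, hy: Double) {
    def contains(p: Pt): Boolean = lx <= p.x && p.x <= hx && ly <= p.y && p.y <= hy
  }

  // Stand-in local index: the partition's points sorted by x (the post uses an RTree).
  final case class LocalIndex(sortedByX: Vector[Pt])

  // Rows of one partition plus its local index; None for an empty partition.
  final case class IndexedPart(data: Vector[Pt], index: Option[LocalIndex])

  // One global-index entry: partition MBR, partition id and record count.
  final case class GlobalEntry(mbr: Box, pid: Int, size: Int)

  def build(points: Seq[Pt], boxes: IndexedSeq[Box]): (IndexedSeq[IndexedPart], IndexedSeq[GlobalEntry]) = {
    // "Shuffle": send each point to the first box that contains it, as getPartition does
    val byPid: Map[Int, Seq[Pt]] = points.groupBy(p => boxes.indexWhere(_.contains(p)))
    require(!byPid.contains(-1), "a point lies outside every partition MBR")
    // Local step: one index per non-empty partition
    val parts = boxes.indices.map { pid =>
      val data = byPid.getOrElse(pid, Seq.empty[Pt]).toVector
      val index = if (data.isEmpty) None else Some(LocalIndex(data.sortWith(_.x < _.x)))
      IndexedPart(data, index)
    }
    // Global step: pair every partition MBR with its id and size
    val global = boxes.indices.map(pid => GlobalEntry(boxes(pid), pid, parts(pid).data.size))
    (parts, global)
  }
}
```

Keeping the partition size in the global entry costs one extra pass over the data, but it lets the planner estimate how many rows a query touches before scheduling any task.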

Using the RTree Index to Speed Up Range Queries

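First we construct query_mbr and run a range query with it on the Global Index, i.e. we look for every partition whose MBR intersects query_mbr. This yields the candidate partition ids (Set[Int]). Each candidate partition then falls into one of two cases:

  • Perfect Cover: the partition's MBR is entirely contained in query_mbr, so every record in the partition is in range and the whole partition is returned as-is.
  • Otherwise: run the Range query for query_mbr on that partition's Local Index and return its result.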

That is how the RTree index speeds up Range queries. For RTree-index optimizations of other operations, see the paper "Simba: Efficient In-Memory Spatial Analytics". The core code is as follows:

val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
//require(data.dimensions == low_point.dimensions && low_point.dimensions == high_point.dimensions)
val query_mbr = MBR(low_point, high_point)
val indexdata = session.sessionState.indexManager.lookupIndexedData(lp.children.head).orNull
val rtree = indexdata.indexrelation.asInstanceOf[RTreeIndexRelation]
val col_keys = rtree.col_keys

// Global Index: ids of the partitions whose MBR intersects query_mbr
val global_part = rtree.global_rtree.range(query_mbr).map(_._2).toSeq

// only the candidate partitions are scheduled
val prdd = new PartitionPruningRDD(rtree.indexRDD_data, global_part.contains)
val tmp = prdd.flatMap { datas =>
  val index = datas.index.asInstanceOf[RTree]
  if (index != null) {
    val root_mbr = index.root.m_mbr
    // Perfect Cover: query_mbr contains the whole partition
    val perfect_cover = query_mbr.contains(root_mbr.low) &&
      query_mbr.contains(root_mbr.high)
    if (perfect_cover) {
      datas.data
    } else {
      // Local Index: find matching entries, then map their offsets back to rows
      index.range(query_mbr).map(x => datas.data(x._2))
    }
  } else Array[InternalRow]()
}
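The two cases of the range query can be sketched the same way. This is a minimal single-machine version, assuming 2-D points: the partition filter plays the role of the global RTree plus PartitionPruningRDD, each partition's tight data MBR plays the role of the local RTree's root MBR, and a linear filter replaces the local RTree search. `RangeSketch`, `Partition`, `dataMbr` and `range` are hypothetical names.

```scala
object RangeSketch {
  final case class Pt(x: Double, y: Double)
  final case class Box(lx: Double, ly: Double, hx: Double, hy: Double) {
    def containsPoint(p: Pt): Boolean = lx <= p.x && p.x <= hx && ly <= p.y && p.y <= hy
    def containsBox(b: Box): Boolean = lx <= b.lx && b.hx <= hx && ly <= b.ly && b.hy <= hy
    def intersects(b: Box): Boolean = lx <= b.hx && b.lx <= hx && ly <= b.hy && b.ly <= hy
  }

  // One partition: its STR MBR and its rows.
  final case class Partition(mbr: Box, rows: Vector[Pt]) {
    // Tight bounding box of the rows, standing in for the local RTree's root MBR.
    lazy val dataMbr: Option[Box] =
      if (rows.isEmpty) None
      else {
        val xs = rows.map(_.x)
        val ys = rows.map(_.y)
        Some(Box(xs.reduce(_ min _), ys.reduce(_ min _), xs.reduce(_ max _), ys.reduce(_ max _)))
      }
  }

  // Returns (matching rows, number of partitions scanned, number perfectly covered).
  def range(parts: IndexedSeq[Partition], q: Box): (Vector[Pt], Int, Int) = {
    // Global step: drop partitions whose MBR does not intersect q
    val candidates = parts.filter(_.mbr.intersects(q))
    // Perfect Cover: q contains the bounding box of all rows, so no per-row check is needed
    val (covered, partial) = candidates.partition(_.dataMbr.exists(m => q.containsBox(m)))
    // Otherwise: search inside the partition (a linear filter here)
    val rows = covered.flatMap(_.rows) ++ partial.flatMap(_.rows.filter(q.containsPoint))
    (rows.toVector, candidates.size, covered.size)
  }
}
```

Testing the tight data MBR rather than the looser partition MBR lets the perfect-cover shortcut fire more often, and it stays safe: if the query contains the box around all rows of a partition, it contains every one of those rows.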