Speeding Up Range Queries in SparkSQL with an RTree Index

This post continues from Add Range Query on Spark DataSet. The previous post only added the basic Range operation; to speed up execution, this post implements an RTree index on top of the RDD to accelerate Range and kNN operations. The full code is on Github. Once the index is in place, it can be used like this:

import org.apache.spark.sql.SparkSession
import scala.util.Random

object Main {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .master("local[4]")
      .appName("SparkSession")
      .getOrCreate()
    import sparkSession.implicits._

    // generate 3000 random points with coordinates in (-30, 30)
    var points = Seq[PointData]()
    for (i <- 0 until 3000) {
      points = points :+ PointData(Random.nextInt() % 30, Random.nextInt() % 30, Random.nextInt() % 30, "point: " + i.toString)
    }
    val pointsList = points.toDS()

    // build an RTree index named "RtreeForData" on columns (x, y)
    pointsList.createIndex("rtree", "RtreeForData", Array("x", "y"))
    // a range query and a 4-nearest-neighbour query on the indexed columns
    pointsList.range(Array("x", "y"), Array(0, 0), Array(10, 10)).show()
    pointsList.knn(Array("x", "y"), Array(1.0, 1.0), 4).show(4)
  }
}

Overview

During execution an RDD is split by its Partitioner into many Partitions, which are then processed in parallel across the cluster. The index implementation therefore has two parts: a Global Index and a Local Index. The Global Index quickly tells us which Partition the data lives in; the Local Index quickly locates the concrete data within that Partition. To support this we implement our own Partitioner and our own Partition type. The rough data definitions are as follows:

// `index` is the per-partition index. Index is a trait, so more index types
// can be added later; RTreeIndex extends Index.
case class IndexPartition(data: Array[InternalRow], index: Index)
type IndexRDD = RDD[IndexPartition]
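
For context, here is a minimal sketch of the trait hierarchy that the comment above hints at; the member lists are placeholders, since the real definitions live in the linked repository:

// Sketch only: members elided, the real definitions are in the repo.
trait Index extends Serializable

// RTree is one concrete Index; further index types can be added later
// simply by extending the same trait.
class RTree /* (root node, max entries per node, ...) */ extends Index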

Implementing the Partitioner

Spark already ships with several Partitioners, such as HashPartitioner and RangePartitioner. To build the RTree index we need to partition the data so that each partition's rows fall inside one MBR (minimum bounding rectangle) and the MBRs of different partitions do not intersect. For this we implement our own Partitioner => STRPartitioner (see the paper STR: A Simple and Efficient Algorithm for R-Tree Packing). The idea is roughly this: to speed up index construction we sample the data in the RDD, run recursiveGroupPoint on the sample to obtain an Array[MBR], and then partition the full dataset according to those MBRs. The core code is as follows:

class STRPartitioner(est_partition: Int,
                     sample_rate: Double,
                     dimension: Int,
                     transfer_threshold: Long,
                     max_entries_per_node: Int,
                     rdd: RDD[_ <: Product2[Point, Any]])
  extends Partitioner {
  override def numPartitions: Int = partitions
  private case class Bounds(min: Array[Double], max: Array[Double])
  var (mbrBound, partitions) = {
    // ...... (estimation of total_size and computation of data_bounds elided)

    val seed = System.currentTimeMillis()
    val sampled = if (total_size <= 0.2 * transfer_threshold) {
      // small inputs: simply collect all the points
      rdd.mapPartitions(part => part.map(_._1)).collect()
    }
    // ...... (else branch: sample the RDD at sample_rate; elided)
    def recursiveGroupPoint(entries: Array[Point], now_min: Array[Double],
                            now_max: Array[Double], cur_dim: Int, until_dim: Int): Array[MBR] = {
      val len = entries.length
      // sort along the current dimension, then cut into dim(cur_dim) equal slices
      val grouped = entries.sortWith(_.coord(cur_dim) < _.coord(cur_dim))
        .grouped(Math.ceil(len * 1.0 / dim(cur_dim)).toInt).toArray
      val ans = mutable.ArrayBuffer[MBR]()
      if (cur_dim < until_dim) {
        for (i <- grouped.indices) {
          // cur_min/cur_max alias now_min/now_max; the cur_dim entry is rewritten
          // before every recursive call and the leaf level clones the arrays, so
          // the sharing is safe. Slice boundaries are taken from the first point
          // of the next group, so adjacent MBRs meet at a shared edge without
          // overlapping.
          val cur_min = now_min
          val cur_max = now_max
          if (i == 0 && i == grouped.length - 1) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          ans ++= recursiveGroupPoint(grouped(i), cur_min, cur_max, cur_dim + 1, until_dim)
        }
        ans.toArray
      } else {
        // last dimension: every group becomes one MBR
        for (i <- grouped.indices) {
          if (i == 0 && i == grouped.length - 1) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          ans += MBR(new Point(now_min.clone()), new Point(now_max.clone()))
        }
        ans.toArray
      }
    }
    val cur_min = new Array[Double](dimension)
    val cur_max = new Array[Double](dimension)
    val mbrs = recursiveGroupPoint(sampled, cur_min, cur_max, 0, dimension - 1)
    (mbrs.zipWithIndex, mbrs.length)
  }

  // a small RTree built over the partition MBRs; a point lookup (circleRange
  // with radius 0) returns the MBR containing the point, i.e. its partition id
  val rt = RTree(mbrBound.map(x => (x._1, x._2, 1)), max_entries_per_node)
  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[Point]
    rt.circleRange(k, 0.0).head._2
  }
}
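
To make the grouping concrete, below is a tiny standalone sketch of the STR idea on plain (x, y) pairs, independent of Spark and of the repository's types (all names here are illustrative). With 3000 sampled points and 16 target partitions, each axis is cut into ceil(16^(1/2)) = 4 slices, yielding a 4 x 4 grid of cells, each of which becomes one MBR:

import scala.util.Random

// Standalone STR sketch: 3000 sampled 2-D points, 16 target partitions.
object STRSketch {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(42)
    val pts = Array.fill(3000)((rnd.nextDouble() * 30, rnd.nextDouble() * 30))
    val slices = 4
    val cells = pts.sortBy(_._1)                                        // sort by x
      .grouped(math.ceil(pts.length / slices.toDouble).toInt).toArray   // 4 vertical slices
      .flatMap(_.sortBy(_._2)                                           // sort each slice by y
        .grouped(math.ceil(pts.length / (slices * slices).toDouble).toInt)) // 4 cells per slice
    // each cell becomes one MBR: the min/max of its coordinates
    cells.foreach { c =>
      val (xs, ys) = (c.map(_._1), c.map(_._2))
      println(f"${c.length} pts, x: [${xs.min}%.1f, ${xs.max}%.1f], y: [${ys.min}%.1f, ${ys.max}%.1f]")
    }
  }
}

Unlike this sketch, recursiveGroupPoint above additionally stretches the boundary cells out to data_bounds, so that the partition MBRs tile the full data space.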

Implementing the RTree Index

After partitioning the RDD with the custom STRPartitioner we obtain a new RDD, on which the index is then built. STRPartition's apply method is written so that it returns both the new RDD and each partition's MBR paired with its partition id (Array[(MBR, Int)]). We first build the Global Index over this Array[(MBR, Int)], then build a Local Index for each Partition of the RDD. The core code is as follows:

def buildIndex(): Unit = {
    val numPartitions = sparksession.sessionState.conf.indexPartitions
    val maxEntriesPerNode = sparksession.sessionState.conf.maxEntriesPerNode
    val sampleRate = sparksession.sessionState.conf.sampleRate
    val transferThreshold = sparksession.sessionState.conf.transfer_Threshold
    // pair every row with the Point extracted from the indexed columns
    val tmpRDD = child.execute().map(row => {
      (IndexUtil.getPointFromRow(row, col_keys, child), row)
    })
    val (partitionRDD, mbrs) = STRPartition(tmpRDD, dimension, numPartitions, sampleRate, transferThreshold, maxEntriesPerNode)
    // create a local RTree index for each partition
    val indexedrdd = partitionRDD.mapPartitions { iter =>
      val tmpdata = iter.toArray
      var index: RTree = null
      if (tmpdata.length > 0) index = RTree(tmpdata.map(_._1).zipWithIndex, maxEntriesPerNode)
      Array(IndexPartition(tmpdata.map(_._2), index)).iterator
    }.persist(StorageLevel.MEMORY_AND_DISK_SER)

    val partitionSize = indexedrdd.mapPartitions(iter => iter.map(_.data.length)).collect()

    // create the global RTree index over (partition MBR, partition id, row count)
    global_rtree = RTree(mbrs.zip(partitionSize)
      .map(x => (x._1._1, x._1._2, x._2)), maxEntriesPerNode)
    indexedrdd.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
    indexRDD_data = indexedrdd
}
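
To illustrate how the two levels cooperate once buildIndex has run, here is a minimal sketch of an exact point lookup. It reuses only calls that appear elsewhere in this post (circleRange, PartitionPruningRDD), but it is a sketch of the flow, not code from the repository:

// Minimal sketch, not repo code: the global tree narrows the search to the
// partitions whose MBR contains p, and each pruned partition's local tree
// narrows it further to concrete rows.
def lookup(p: Point): Array[InternalRow] = {
  val partIds = global_rtree.circleRange(p, 0.0).map(_._2).toSet
  val pruned = new PartitionPruningRDD(indexRDD_data, partIds.contains)
  pruned.flatMap { part =>
    if (part.index != null)
      part.index.asInstanceOf[RTree].circleRange(p, 0.0).map(x => part.data(x._2))
    else Array.empty[InternalRow]
  }.collect()
}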

Using the RTree Index to Speed Up Range Queries

First build query_mbr and run a range query with it against the Global Index, i.e. find every partition whose MBR intersects query_mbr, which yields the matching partition ids (Set[Int]). Each matching partition then falls into one of two cases:

  • Perfect Cover: the partition's entire MBR is contained in query_mbr, so every row of that partition lies in the query range and is returned directly
  • Otherwise: run a Range query with query_mbr against that partition's Local Index and return the matching rows

That is the RTree-accelerated Range query. For RTree optimizations of the other operations, see the paper Simba: Efficient In-Memory Spatial Analytics. The core code is as follows:

val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
//require(data.dimensions == low_point.dimensions && low_point.dimensions == high_point.dimensions)
val query_mbr = MBR(low_point, high_point)
// ...
var indexdata = session.sessionState.indexManager.lookupIndexedData(lp.children.head).orNull
// ...
val rtree = indexdata.indexrelation.asInstanceOf[RTreeIndexRelation]
val col_keys = rtree.col_keys

// global index: ids of every partition whose MBR intersects query_mbr
var global_part = rtree.global_rtree.range(query_mbr).map(_._2).toSeq

// only the matching partitions are ever read
val prdd = new PartitionPruningRDD(rtree.indexRDD_data, global_part.contains)
val tmp = prdd.flatMap { datas =>
  val index = datas.index.asInstanceOf[RTree]
  if (index != null) {
    val root_mbr = index.root.m_mbr
    // perfect cover: the partition's whole MBR lies inside the query range,
    // so all of its rows match without probing the local tree
    val perfect_cover = query_mbr.contains(root_mbr.low) &&
      query_mbr.contains(root_mbr.high)
    if (perfect_cover) {
      datas.data
    } else {
      // otherwise ask the partition's local RTree for the matching rows
      index.range(query_mbr).map(x => datas.data(x._2))
    }
  } else Array[InternalRow]()
}
// ...
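
To see why the perfect-cover branch pays off, consider a concrete case. This is a small sketch that relies only on the MBR(low, high) and Point constructors and the contains call already used above:

// query covers [0, 10] x [0, 10]; the partition's root MBR covers [2, 8] x [2, 8]
val query = MBR(new Point(Array(0.0, 0.0)), new Point(Array(10.0, 10.0)))
val part = MBR(new Point(Array(2.0, 2.0)), new Point(Array(8.0, 8.0)))
// both corners of the partition MBR fall inside the query MBR, so every row of
// that partition matches and its local tree is never probed
val perfect_cover = query.contains(part.low) && query.contains(part.high) // true

When the query box is large relative to the partition grid, most matching partitions hit this branch, so the per-row work collapses into a bulk copy of each partition's rows.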