This post continues from Add Range Query on Spark DataSet. The previous post only added the basic Range operation; to speed up execution, this one implements an RTree index on top of the RDD to accelerate the Range and Knn operations. The full code is on Github. With the index in place, the example looks like this:
```scala
import org.apache.spark.sql.SparkSession
import scala.util.Random

object Main {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .master("local[4]")
      .appName("SparkSession")
      .getOrCreate()

    import sparkSession.implicits._

    // Generate 3000 random points with coordinates in (-30, 30).
    var points = Seq[PointData]()
    for (i <- 0 until 3000) {
      points = points :+ PointData(Random.nextInt() % 30, Random.nextInt() % 30,
        Random.nextInt() % 30, "point: " + i.toString)
    }
    val pointsList = points.toDS()

    // Build an RTree index named "RtreeForData" on columns x and y.
    pointsList.createIndex("rtree", "RtreeForData", Array("x", "y"))
    pointsList.range(Array("x", "y"), Array(0, 0), Array(10, 10)).show()
    pointsList.knn(Array("x", "y"), Array(1.0, 1.0), 4).show(4)
  }
}
```
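Here createIndex, range, and knn are extension methods this project adds to Dataset: createIndex builds an RTree over columns x and y, range keeps the rows whose (x, y) falls inside the rectangle [0, 10] × [0, 10], and knn returns the 4 points nearest to (1.0, 1.0).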
Overview
During execution, an RDD is split into many Partitions according to its Partitioner and processed in parallel. We therefore implement the index in two parts: a Global Index and a Local Index. The Global Index quickly determines which Partition holds the data; the Local Index then quickly locates the concrete rows within that Partition. To make this work we implement our own Partitioner and our own Partition. The rough data definitions are as follows:
```scala
// Here `index` is the local index. Index is a trait, so more index types can
// be added later; RTreeIndex extends Index.
case class IndexPartition(data: Array[InternalRow], index: Index)

type IndexRDD = RDD[IndexPartition]
```
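To make the two-level flow concrete, here is a minimal sketch with simplified, hypothetical types (SimpleIndex, SimplePartition, and twoLevelRange are illustrative names, not the project's API): the global index narrows a query down to candidate partitions, and the local index inside each candidate narrows it down to rows.

```scala
trait SimpleIndex {
  // Returns the ids of the entries matching a rectangle query.
  def range(low: Array[Double], high: Array[Double]): Seq[Int]
}

case class SimplePartition[T](data: Array[T], index: SimpleIndex)

def twoLevelRange[T](globalIndex: SimpleIndex,
                     partitions: IndexedSeq[SimplePartition[T]],
                     low: Array[Double], high: Array[Double]): Seq[T] =
  globalIndex.range(low, high).flatMap { pid =>  // global: surviving partitions
    val part = partitions(pid)
    part.index.range(low, high).map(i => part.data(i))  // local: matching rows
  }
```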
Implementing the Partitioner
Spark ships with several Partitioners of its own, such as HashPartitioner and RangePartitioner. To build an RTree index we have to partition the data so that each partition's records belong to one MBR and the MBRs of different partitions do not intersect. For that we implement our own Partitioner, STRPartitioner (see the paper STR: A Simple and Efficient Algorithm for R-Tree Packing). The idea is roughly this: to speed up index construction, we sample the data in the RDD, run recursiveGroupPoint on the sample to obtain an Array[MBR] covering the sampled data, and then partition the full data set by those MBRs. The core code is as follows:
```scala
class STRPartitioner(est_partition: Int,
                     sample_rate: Double,
                     dimension: Int,
                     transfer_threshold: Long,
                     max_entries_per_node: Int,
                     rdd: RDD[_ <: Product2[Point, Any]]) extends Partitioner {

  override def numPartitions: Int = partitions

  private case class Bounds(min: Array[Double], max: Array[Double])

  var (mbrBound, partitions) = {
    // ...
    // Sample the points; if the data is small enough, just collect them all.
    val seed = System.currentTimeMillis()
    val sampled = if (total_size <= 0.2 * transfer_threshold) {
      rdd.mapPartitions(part => part.map(_._1)).collect()
    }
    // ...

    // Sort the points on the current dimension and cut them into slices,
    // then recurse on each slice for the next dimension. On the last
    // dimension, emit one MBR per slice, stretched to the data bounds at
    // the edges.
    def recursiveGroupPoint(entries: Array[Point],
                            now_min: Array[Double],
                            now_max: Array[Double],
                            cur_dim: Int,
                            until_dim: Int): Array[MBR] = {
      val len = entries.length
      val grouped = entries.sortWith(_.coord(cur_dim) < _.coord(cur_dim))
        .grouped(Math.ceil(len * 1.0 / dim(cur_dim)).toInt).toArray
      var ans = mutable.ArrayBuffer[MBR]()
      if (cur_dim < until_dim) {
        for (i <- grouped.indices) {
          val cur_min = now_min
          val cur_max = now_max
          if (i == 0 && i == grouped.length - 1) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            cur_min(cur_dim) = data_bounds.min(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            cur_min(cur_dim) = grouped(i).head.coord(cur_dim)
            cur_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          ans ++= recursiveGroupPoint(grouped(i), cur_min, cur_max, cur_dim + 1, until_dim)
        }
        ans.toArray
      } else {
        for (i <- grouped.indices) {
          if (i == 0 && i == grouped.length - 1) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else if (i == 0) {
            now_min(cur_dim) = data_bounds.min(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          } else if (i == grouped.length - 1) {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = data_bounds.max(cur_dim)
          } else {
            now_min(cur_dim) = grouped(i).head.coord(cur_dim)
            now_max(cur_dim) = grouped(i + 1).head.coord(cur_dim)
          }
          ans += MBR(new Point(now_min.clone()), new Point(now_max.clone()))
        }
        ans.toArray
      }
    }

    val cur_min = new Array[Double](dimension)
    val cur_max = new Array[Double](dimension)
    val mbrs = recursiveGroupPoint(sampled, cur_min, cur_max, 0, dimension - 1)
    (mbrs.zipWithIndex, mbrs.length)
  }

  // RTree over the partition MBRs, used to route points to partitions.
  val rt = RTree(mbrBound.map(x => (x._1, x._2, 1)), max_entries_per_node)

  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[Point]
    // A radius-0 circle query returns the MBRs that contain the point;
    // partition MBRs are disjoint, so the first hit is the owning partition.
    rt.circleRange(k, 0.0).head._2
  }
}
```
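To see what recursiveGroupPoint computes without the surrounding machinery, here is a standalone sketch of the STR idea on plain 2D tuples (strGroups is a hypothetical helper, not the project's code). The real code derives the per-dimension slice count dim(cur_dim) from the estimated partition count; for 2D the usual STR choice is roughly ceil(sqrt(P)) slices per dimension, which is what the sketch uses.

```scala
// Hypothetical STR sketch: sort on x, cut into vertical slices, then sort
// each slice on y and cut again; every final group becomes one partition MBR.
def strGroups(points: Seq[(Double, Double)],
              numPartitions: Int): Seq[Seq[(Double, Double)]] = {
  val cuts = math.ceil(math.sqrt(numPartitions)).toInt
  points.sortBy(_._1)                                        // sort on x
    .grouped(math.ceil(points.length.toDouble / cuts).toInt) // vertical slices
    .toSeq
    .flatMap { slice =>
      slice.sortBy(_._2)                                     // sort slice on y
        .grouped(math.ceil(slice.length.toDouble / cuts).toInt)
        .toSeq
    }
}

// 6 points, 4 target partitions -> 2 x-slices, each cut into 2 y-groups.
val groups = strGroups(
  Seq((1.0, 9.0), (2.0, 2.0), (3.0, 7.0), (8.0, 1.0), (9.0, 8.0), (7.0, 3.0)), 4)
```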
Implementing the RTree Index
Repartitioning the RDD with the custom STRPartitioner yields a new RDD, and the index is built on top of it. The apply method of STRPartition is written to return both the new RDD and each partition's MBR paired with its partition number (Array[(MBR, Int)]). We first build the Global Index from that Array[(MBR, Int)], and then build a Local Index for every Partition of the RDD. The core code is as follows:
```scala
def buildIndex(): Unit = {
  val numPartitions = sparksession.sessionState.conf.indexPartitions
  val maxEntriesPerNode = sparksession.sessionState.conf.maxEntriesPerNode
  val sampleRate = sparksession.sessionState.conf.sampleRate
  val transferThreshold = sparksession.sessionState.conf.transfer_Threshold

  // Pair every row with the Point extracted from the indexed columns.
  val tmpRDD = child.execute().map(row => {
    (IndexUtil.getPointFromRow(row, col_keys, child), row)
  })
  val (partitionRDD, mbrs) = STRPartition(tmpRDD, dimension, numPartitions,
    sampleRate, transferThreshold, maxEntriesPerNode)

  // Create a local RTree index for each partition.
  val indexedrdd = partitionRDD.mapPartitions { iter =>
    val tmpdata = iter.toArray
    var index: RTree = null
    if (tmpdata.length > 0)
      index = RTree(tmpdata.map(_._1).zipWithIndex, maxEntriesPerNode)
    Array(IndexPartition(tmpdata.map(_._2), index)).iterator
  }.persist(StorageLevel.MEMORY_AND_DISK_SER)

  val partitionSize = indexedrdd.mapPartitions(iter => iter.map(_.data.length)).collect()

  // Create the global RTree index: one entry per partition, carrying the
  // partition's MBR, its id, and the number of rows it holds.
  global_rtree = RTree(mbrs.zip(partitionSize)
    .map(x => (x._1._1, x._1._2, x._2)), maxEntriesPerNode)

  indexedrdd.setName(table_name.map(n => s"$n $index_name").getOrElse(child.toString))
  indexRDD_data = indexedrdd
}
```
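The result of mapPartitions is worth noting: each Spark partition collapses into a single IndexPartition element holding both its rows and its local RTree, and global_rtree is then built from one (MBR, partition id, row count) triple per partition. A toy illustration of that global-index input, assuming the MBR/Point/RTree signatures used above (the values are made up):

```scala
// Hypothetical input for the global index: two 2D partitions with their
// MBRs, partition ids, and row counts.
val entries = Array(
  (MBR(new Point(Array(0.0, 0.0)), new Point(Array(5.0, 5.0))), 0, 1480),
  (MBR(new Point(Array(5.0, 0.0)), new Point(Array(10.0, 5.0))), 1, 1520))
val globalRtree = RTree(entries, 25) // 25 = max entries per node
```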
Using the RTree Index to Accelerate Range
First we build the query_mbr and run a range query with it on the Global Index, i.e. we find every partition whose MBR intersects query_mbr and collect the corresponding partition numbers (a Set[Int] of partition ids). Two cases follow:

- Perfect cover: the partition's whole MBR is contained in query_mbr, so every row of the partition is in range and can be returned without touching its Local Index.
- Partial overlap: for each partition returned by the Global Index, run a Range query for query_mbr on its Local Index and return the matches.
That is the RTree-accelerated Range query; for the RTree optimizations of the other operations, see the paper Simba: Efficient In-Memory Spatial Analytics. The core code is as follows:
```scala
val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
// require(data.dimensions == low_point.dimensions && low_point.dimensions == high_point.dimensions)
val query_mbr = MBR(low_point, high_point)
// ...
var indexdata = session.sessionState.indexManager.lookupIndexedData(lp.children.head).orNull
// ...
val rtree = indexdata.indexrelation.asInstanceOf[RTreeIndexRelation]
val col_keys = rtree.col_keys

// Global index: ids of the partitions whose MBR intersects the query MBR.
var global_part = rtree.global_rtree.range(query_mbr).map(_._2).toSeq
// Keep only the surviving partitions.
val prdd = new PartitionPruningRDD(rtree.indexRDD_data, global_part.contains)

val tmp = prdd.flatMap { datas =>
  val index = datas.index.asInstanceOf[RTree]
  if (index != null) {
    val root_mbr = index.root.m_mbr
    // Perfect cover: the whole partition lies inside the query MBR.
    val perfect_cover = query_mbr.contains(root_mbr.low) && query_mbr.contains(root_mbr.high)
    if (perfect_cover) {
      datas.data
    } else {
      // Partial overlap: query the partition's local RTree.
      index.range(query_mbr).map(x => datas.data(x._2))
    }
  } else Array[InternalRow]()
}
// ...
```
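Two things make this fast. PartitionPruningRDD only schedules tasks on the partitions whose ids the Global Index returned, so pruned partitions are never read at all, and the perfect-cover shortcut hands back a fully covered partition's rows without walking its local tree. The cost of a range query therefore grows with the number of partitions the query rectangle actually touches rather than with the size of the whole data set.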