Add Range Query on Spark DataSet


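First download the code and compile it, then import the compiled project into IDEA. If compilation fails inside IDEA, it is usually because some sources are only generated at build time; after importing, regenerate them by clicking Generate Sources and Update Folders, then rebuild. (I used the Spark 2.1 code.)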

git clone -b branch-2.1
./build/mvn -DskipTests clean package



import org.apache.spark.sql.SparkSession

object Range {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  def main(args: Array[String]): Unit = {
    // Assumed session settings (local master, app name); adjust for your environment.
    val sparkSession = SparkSession.builder().master("local").appName("Range").getOrCreate()

    import sparkSession.implicits._
    val caseClassDS = Seq(PointData(1.0, 1.0, 3.0, "1"), PointData(2.0, 2.0, 3.0, "2"), PointData(2.0, 2.0, 3.0, "3"),
      PointData(2.0, 2.0, 3.0, "4"), PointData(3.0, 3.0, 3.0, "5"), PointData(4.0, 4.0, 3.0, "6")).toDS()
    // Keep the rows whose (x, y) lies in the rectangle from (0.0, 0.0) to (3.0, 3.0).
    caseClassDS.range(Array("x", "y"), Array(0.0, 0.0), Array(3.0, 3.0)).show()
  }
}

Implementing the UDT (User-Defined Type)

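A UDT adds a user-defined complex type to Spark SQL. Defining the ShapeType class tells Spark that the user-defined Point and Shape both belong to the ShapeType type. Implementing ShapeType requires two functions, one for serialization and one for deserialization; I use Kryo for both.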

private[sql] class ShapeType extends UserDefinedType[Shape] {
  // Shapes are stored as a non-nullable byte array.
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))
  }

  override def userClass: Class[Shape] = classOf[Shape]

  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData =>
        // Assumed counterpart of ShapeSerializer.serialize; the helper is not shown in this post.
        ShapeSerializer.deserialize(values.toByteArray())
    }
  }
}

case object ShapeType extends ShapeType
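The serialize/deserialize pair only has to guarantee that a shape survives a round trip through a byte array. Here is a minimal sketch of that contract in plain Scala. Note the assumptions: it uses standard Java serialization instead of Kryo so that it runs without extra dependencies, and SimplePoint and ShapeSerializerSketch are hypothetical stand-ins, not the classes used in the Spark code above.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object ShapeSerializerSketch {
  // Hypothetical stand-in for Point; Vector is used so that equality is structural.
  final case class SimplePoint(coord: Vector[Double])

  // Turn a point into bytes (Java serialization here, Kryo in the post).
  def serialize(p: SimplePoint): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    try out.writeObject(p) finally out.close()
    buffer.toByteArray
  }

  // Rebuild the point from the bytes produced by serialize.
  def deserialize(bytes: Array[Byte]): SimplePoint = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try in.readObject().asInstanceOf[SimplePoint] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val original = SimplePoint(Vector(1.0, 2.0, 3.0))
    val restored = deserialize(serialize(original))
    // The round trip should give back an equal point.
    println(restored == original)
  }
}
```

In the UDT, the byte array produced this way is what gets wrapped in GenericArrayData, which is why sqlType is declared as an array of bytes.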

@SQLUserDefinedType(udt = classOf[ShapeType])
abstract class Shape extends Serializable {
  val dimensions: Int
  def minDist(other: Shape): Double
}

@SQLUserDefinedType(udt = classOf[ShapeType])
case class Point(coord: Array[Double]) extends Shape {
  override val dimensions: Int = coord.length

  override def minDist(other: Shape): Double = {
    other match {
      case p: Point => minDist(p)
      case mbr: MBR => mbr.minDist(this)
    }
  }

  // Distance between two points of the same dimension.
  def minDist(other: Point): Double = {
    require(coord.length == other.coord.length)
    var ans = 0.0
    for (i <- coord.indices)
      ans += (coord(i) - other.coord(i)) * (coord(i) - other.coord(i))
    // Assuming Euclidean distance: take the square root of the summed squares.
    Math.sqrt(ans)
  }

  // True when both points have the same coordinates.
  def ==(other: Point): Boolean = {
    other match {
      case p: Point =>
        if (p.coord.length != coord.length) false
        else {
          for (i <- coord.indices)
            if (coord(i) != p.coord(i)) return false
          true
        }
      case _ => false
    }
  }

  // True when every coordinate is less than or equal to the matching one in `other`.
  def <=(other: Point): Boolean = {
    for (i <- coord.indices)
      if (coord(i) > other.coord(i)) return false
    true
  }
}


  • The PointWrapper expression. Its eval function reads the values of the given attributes (the exps parameter) from the InternalRow and returns them as a Point.
case class PointWrapper(exps: Seq[Expression]) extends Expression with CodegenFallback {
  override def nullable: Boolean = false

  override def dataType: DataType = ShapeType

  override def children: Seq[Expression] = exps

  override def eval(input: InternalRow): Any = {
    // Evaluate each attribute against the row and collect the values as coordinates.
    val coord = exps.map(_.eval(input).asInstanceOf[Double]).toArray
    Point(coord)
  }
}
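  • The ExpRange expression, which is the core of the range query. Its shape parameter holds the attributes being queried, and low and high bound the query range. eval builds a Point from the values of the shape attributes in the InternalRow, takes the two points given by low and high, combines them into an MBR (minimum bounding rectangle), and calls the MBR's contains to check whether the row's values fall inside the given range.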
case class ExpRange(shape: Expression, low: Expression, high: Expression) extends Predicate with CodegenFallback {
  override def nullable: Boolean = false

  override def toString: String = s" **($shape) IN Rectangle ($low) - ($high)**  "

  override def children: Seq[Expression] = Seq(shape, low, high)

  override def eval(input: InternalRow): Any = {
    val data = ShapeUtil.getShape(shape, input)
    val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
    val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
    require(data.dimensions == low_point.dimensions && low_point.dimensions == high_point.dimensions)
    val mbr = MBR(low_point, high_point)
    // Whether the row's point lies inside the query rectangle.
    val inside = mbr.contains(data)
    println(s"eval: $inside") // debug output
    inside
  }
}
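The containment check at the heart of ExpRange can be sketched in plain Scala. The MBR class is not shown in this post, so P and Box below are hypothetical stand-ins, and treating both bounds as inclusive is an assumption based on the <= operator defined on Point.

```scala
object RangeCheckSketch {
  // Hypothetical stand-in for Point.
  final case class P(coord: Vector[Double]) {
    // True when every coordinate is <= the matching coordinate of `other`.
    def <=(other: P): Boolean = {
      require(coord.length == other.coord.length, "dimension mismatch")
      coord.indices.forall(i => coord(i) <= other.coord(i))
    }
  }

  // Hypothetical stand-in for MBR: an axis-aligned box given by two corners.
  final case class Box(low: P, high: P) {
    require(low <= high, "low corner must not exceed high corner")

    // Assumption: both bounds are inclusive.
    def contains(p: P): Boolean = low <= p && p <= high
  }

  def main(args: Array[String]): Unit = {
    val box = Box(P(Vector(0.0, 0.0)), P(Vector(3.0, 3.0)))
    // The (x, y) columns of the six sample rows from the test program.
    val points = Seq((1.0, 1.0), (2.0, 2.0), (2.0, 2.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0))
      .map { case (x, y) => P(Vector(x, y)) }
    val inside = points.filter(box.contains)
    println(inside.size) // prints 5: only (4.0, 4.0) falls outside
  }
}
```

Under these assumptions, the range call in the test program would keep the first five rows and drop the row with x = 4.0, y = 4.0.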

Implementing the Physical Plan

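When Spark SQL executes a query, it calls planner.plan(ReturnAnswer(optimizedPlan)).next() to convert the optimized LogicalPlan into a SparkPlan, then prepareForExecution(sparkPlan) applies the existing preparation rules to produce the executedPlan, and finally executedPlan.execute() yields the result. We first implement our own SparkPlan, SpatialFilterExec, then our own strategy, SpatialFilter, and add SpatialFilter to the strategies sequence of SparkPlanner. SpatialFilter converts the LogicalPlan into a SpatialFilterExec. SpatialFilterExec overrides doExecute, which checks whether condition is an ExpRange expression and, if so, evaluates that expression on the RDD to filter the rows.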

case class SpatialFilterExec(condition: Expression, child: SparkPlan) extends SparkPlan with PredicateHelper {
  override def output: Seq[Attribute] = child.output

  override protected def doExecute(): RDD[InternalRow] = {
    val root_rdd = child.execute()
    condition match {
      // Only ExpRange conditions are handled here.
      case ExpRange(_, _, _) =>
        root_rdd.mapPartitions { iter =>
          // Build the predicate once per partition rather than once per row.
          val predicate = newPredicate(condition, child.output)
          iter.filter(row => predicate.eval(row))
        }
    }
  }

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  override def children: Seq[SparkPlan] = child :: Nil
  override def outputPartitioning: Partitioning = child.outputPartitioning
}

object SpatialFilter extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case logical.Filter(condition, child) =>
      SpatialFilterExec(condition, planLater(child)) :: Nil
    case _ => Nil
  }
}

class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  def strategies: Seq[Strategy] =
    extraStrategies ++ (
      // SpatialFilter is the newly added strategy; it is tried before the built-in ones.
      SpatialFilter ::
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)
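Why the position of SpatialFilter in the list matters can be seen in a toy model of strategy-based planning. This is only a sketch in plain Scala, not Catalyst code: Logical, Physical, the string conditions and the plan function are all hypothetical simplifications. Also note that the guard on the condition is an addition of this sketch; the SpatialFilter strategy above matches every logical.Filter.

```scala
object PlannerSketch {
  // Toy logical plan nodes (hypothetical, not Catalyst classes).
  sealed trait Logical
  final case class Scan(table: String) extends Logical
  final case class Filter(condition: String, child: Logical) extends Logical

  // Toy physical plan nodes.
  sealed trait Physical
  final case class ScanExec(table: String) extends Physical
  final case class FilterExec(condition: String, child: Physical) extends Physical
  final case class SpatialFilterExec(condition: String, child: Physical) extends Physical

  // A strategy returns candidate physical plans, or Nil if it does not apply.
  type Strategy = Logical => Seq[Physical]

  // Plans only filters whose condition looks like a range predicate.
  val spatialFilter: Strategy = {
    case Filter(cond, child) if cond.startsWith("range") =>
      Seq(SpatialFilterExec(cond, plan(child)))
    case _ => Nil
  }

  // Fallback strategy that can plan every node.
  val basicOperators: Strategy = {
    case Scan(table) => Seq(ScanExec(table))
    case Filter(cond, child) => Seq(FilterExec(cond, plan(child)))
  }

  // Order matters: earlier strategies get the first chance to plan a node.
  val strategies: Seq[Strategy] = Seq(spatialFilter, basicOperators)

  // Try the strategies in order and take the first candidate produced.
  def plan(logical: Logical): Physical =
    strategies.iterator.flatMap(strategy => strategy(logical)).next()

  def main(args: Array[String]): Unit = {
    println(plan(Filter("range(x, y)", Scan("points"))))
    println(plan(Filter("z > 1", Scan("points"))))
  }
}
```

In this toy version the range filter is planned as SpatialFilterExec because spatialFilter comes first, while the ordinary filter falls through to basicOperators. Putting SpatialFilter at the head of the built-in strategies has the same effect in SparkPlanner.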



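def range(keys: Array[String], low: Array[Double], high: Array[Double]): DataFrame = withPlan {
  val attrs = getAttributes(keys)
  attrs.foreach(attr => assert(attr != null, "column not found"))
  // Assumed reconstruction: wrap the current plan in a Filter whose condition is an ExpRange over the key columns.
  Filter(ExpRange(PointWrapper(attrs),
    Literal.create(new Point(low), ShapeType),
    Literal.create(new Point(high), ShapeType)),
    logicalPlan)
}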

