plutolove's diary

Add Range Query on Spark DataSet

Preparation

First, download and build the source, then import the built project into IDEA. If compilation fails inside IDEA, it is usually because some sources are only generated at build time; after importing, click Generate Sources and Update Folders to regenerate them and rebuild. (I am working on the Spark 2.1 branch.)

git clone -b branch-2.1 https://github.com/apache/spark.git
./build/mvn -DskipTests clean package

Goal

Add a range query operation to DataSet, so that a DataSet supports queries like the following:

import org.apache.spark.sql.SparkSession
object Range {
  case class PointData(x: Double, y: Double, z: Double, other: String)
  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .master("local[4]")
      .appName("SparkSession")
      .getOrCreate()

    import sparkSession.implicits._
    val caseClassDS = Seq(
      PointData(1.0, 1.0, 3.0, "1"), PointData(2.0, 2.0, 3.0, "2"), PointData(2.0, 2.0, 3.0, "3"),
      PointData(2.0, 2.0, 3.0, "4"), PointData(3.0, 3.0, 3.0, "5"), PointData(4.0, 4.0, 3.0, "6")
    ).toDS()
    caseClassDS.range(Array("x", "y"), Array(0.0, 0.0), Array(3.0, 3.0)).show()
  }
}

Implementing a UDT (User-Defined Type)

A UDT lets users register their own complex types with Spark SQL. Defining the ShapeType class tells Spark that the user-defined Point and Shape both belong to the ShapeType type. ShapeType has to implement two methods, one for serialization and one for deserialization; I use Kryo for both.

private[sql] class ShapeType extends UserDefinedType[Shape] {
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))
  }

  override def userClass: Class[Shape] = classOf[Shape]

  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData =>
        ShapeSerializer.deserialize(values.toByteArray)
    }
  }
}

case object ShapeType extends ShapeType

@SQLUserDefinedType(udt = classOf[ShapeType])
abstract class Shape extends Serializable {
  val dimensions: Int
  def minDist(other: Shape): Double
}

@SQLUserDefinedType(udt = classOf[ShapeType])
case class Point(coord: Array[Double]) extends Shape {
  override val dimensions: Int = coord.length

  override def minDist(other: Shape): Double = {
    other match {
      case p: Point => minDist(p)
      case mbr: MBR => mbr.minDist(this)
    }
  }

  def minDist(other: Point): Double = {
    require(coord.length == other.coord.length)
    var ans = 0.0
    for (i <- coord.indices)
      ans += (coord(i) - other.coord(i)) * (coord(i) - other.coord(i))
    Math.sqrt(ans)
  }

  def ==(other: Point): Boolean = {
    other match {
      case p: Point =>
        if (p.coord.length != coord.length) false
        else {
          for (i <- coord.indices)
            if (coord(i) != p.coord(i)) return false
          true
        }
      case _ => false
    }
  }

  def <=(other: Point): Boolean = {
    for (i <- coord.indices)
      if (coord(i) > other.coord(i)) return false
    true
  }
}
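
ShapeType delegates serialization to a ShapeSerializer, which is not listed in this post. A minimal sketch of what it could look like with Kryo, as mentioned above (the registration details and the single shared Kryo instance are my assumptions, not the original implementation):

import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Hypothetical sketch: convert a Shape to a byte array and back with Kryo.
// A single Kryo instance is not thread-safe; a real implementation would pool it.
object ShapeSerializer {
  private val kryo = new Kryo()
  kryo.register(classOf[Shape])
  kryo.register(classOf[Point])

  def serialize(shape: Shape): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val output = new Output(bytes)
    kryo.writeClassAndObject(output, shape)
    output.close()
    bytes.toByteArray
  }

  def deserialize(data: Array[Byte]): Shape = {
    val input = new Input(data)
    val shape = kryo.readClassAndObject(input).asInstanceOf[Shape]
    input.close()
    shape
  }
}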

Adding Expressions

  • The PointWrapper expression: its eval function reads the values of the given attributes (the exps parameter) from the InternalRow and wraps them into a Point that it returns.
case class PointWrapper(exps: Seq[Expression]) extends Expression with CodegenFallback {
  override def nullable: Boolean = false

  override def dataType: DataType = ShapeType

  override def children: Seq[Expression] = exps

  override def eval(input: InternalRow): Any = {
    val coord = exps.map(_.eval(input).asInstanceOf[Double]).toArray
    Point(coord)
  }
}
  • The ExpRange expression: this is the core of the range query. Its shape parameter holds the attributes the query operates on, and low and high give the query range. eval builds a Point from the corresponding attribute values in the InternalRow, builds two more Points from low and high, combines them into an MBR, and uses MBR.contains to check whether the row's values fall inside the given range (a sketch of MBR and ShapeUtil follows the code below).
case class ExpRange(shape: Expression, low: Expression, high: Expression) extends Predicate with CodegenFallback {
  override def nullable: Boolean = false

  override def toString: String = s"($shape) IN Rectangle ($low) - ($high)"

  override def children: Seq[Expression] = Seq(shape, low, high)

  override def eval(input: InternalRow): Any = {
    val data = ShapeUtil.getShape(shape, input)
    val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
    val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
    require(data.dimensions == low_point.dimensions && low_point.dimensions == high_point.dimensions)
    val mbr = MBR(low_point, high_point)
    println("eval: ", mbr.contains(data))
    mbr.contains(data)
  }
}
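
The MBR type and the ShapeUtil.getShape helper used above are also not listed in this post. A rough sketch of the minimum they need to provide, assuming inclusive bounds for contains (everything here is an assumption about the original code):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData

// Hypothetical sketch of the minimum bounding rectangle used by ExpRange and Point.minDist.
case class MBR(low: Point, high: Point) extends Shape {
  require(low.dimensions == high.dimensions && low <= high)
  override val dimensions: Int = low.dimensions

  // Inclusive containment test on every dimension.
  def contains(p: Point): Boolean = low <= p && p <= high

  def contains(s: Shape): Boolean = s match {
    case p: Point => contains(p)
    case _ => false
  }

  // Distance from a point to the rectangle; 0 when the point lies inside.
  def minDist(p: Point): Double = {
    var ans = 0.0
    for (i <- low.coord.indices) {
      if (p.coord(i) < low.coord(i))
        ans += (low.coord(i) - p.coord(i)) * (low.coord(i) - p.coord(i))
      else if (p.coord(i) > high.coord(i))
        ans += (p.coord(i) - high.coord(i)) * (p.coord(i) - high.coord(i))
    }
    Math.sqrt(ans)
  }

  override def minDist(other: Shape): Double = other match {
    case p: Point => minDist(p)
  }
}

// Hypothetical sketch: evaluate the expression against the row and unwrap the result
// into a Shape, deserializing when the value arrives in its UDT (byte array) form.
object ShapeUtil {
  def getShape(expression: Expression, input: InternalRow): Shape = {
    expression.eval(input) match {
      case s: Shape => s
      case a: ArrayData => ShapeSerializer.deserialize(a.toByteArray)
    }
  }
}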

Implementing the Physical Plan

When Spark SQL executes a query, it calls planner.plan(ReturnAnswer(optimizedPlan)).next() to turn the optimized LogicalPlan into a SparkPlan, then prepareForExecution(sparkPlan) rewrites that SparkPlan with the existing preparation rules into the executedPlan, and finally executedPlan.execute() produces the result. We first implement our own SparkPlan, SpatialFilterExec, then implement our own strategy, SpatialFilter, and add it to the strategies list of SparkPlanner. SpatialFilter converts the logical Filter into a SpatialFilterExec; SpatialFilterExec overrides doExecute, checks that the condition is an ExpRange expression, and evaluates it over the child RDD to filter the rows.

case class SpatialFilterExec(condition: Expression, child: SparkPlan) extends SparkPlan with PredicateHelper {
  override def output: Seq[Attribute] = child.output
  override protected def doExecute(): RDD[InternalRow] = {
    val root_rdd = child.execute()
    condition match {
      case ExpRange(_, _, _) =>
        root_rdd.mapPartitions(iter => iter.filter(newPredicate(condition, child.output).eval(_)))
    }
  }
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  override def children: Seq[SparkPlan] = child :: Nil
  override def outputPartitioning: Partitioning = child.outputPartitioning
}

object SpatialFilter extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case logical.Filter(condition, child) =>
      SpatialFilterExec(condition, planLater(child)) :: Nil
    case _ => Nil
  }
}

class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  def strategies: Seq[Strategy] =
    extraStrategies ++ (
      SpatialFilter ::
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // ...
}

Adding the DataSet API

Finally, add the DataSet-facing method, which builds a Filter over an ExpRange expression and returns the result.

def range(keys: Array[String], low: Array[Double], high: Array[Double]): DataFrame = withPlan {
  val attrs = getAttributes(keys)
  attrs.foreach(attr => assert(attr != null, "column not found"))
  Filter(ExpRange(PointWrapper(attrs),
    Literal.create(new Point(low), ShapeType),
    Literal.create(new Point(high), ShapeType)),
    logicalPlan
  )
}
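
getAttributes is another helper that the post does not show. A minimal sketch, assuming it lives alongside range in Dataset.scala and simply resolves each column name against the analyzed plan's output (the name matching is my assumption):

// Hypothetical sketch: map each column name to the corresponding attribute of this
// Dataset, returning null for unknown names so the assert in range can report them.
private def getAttributes(keys: Array[String]): Array[Attribute] = {
  keys.map { key =>
    queryExecution.analyzed.output
      .find(attr => attr.name.equalsIgnoreCase(key))
      .orNull
  }
}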

Summary

Calling the range operation with its arguments builds a Filter whose children are an ExpRange expression and the current logicalPlan. During execution, planner.plan(ReturnAnswer(optimizedPlan)).next() turns the optimized LogicalPlan into a SparkPlan, and the SpatialFilter strategy registered in SparkPlanner converts the Filter into our SpatialFilterExec; prepareForExecution(sparkPlan) then produces the executedPlan, and executedPlan.execute() invokes SpatialFilterExec's doExecute and returns the result. All of the code is on GitHub.