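Preparation
First, download the source and build it, then import the built project into IntelliJ IDEA. If compilation fails inside IDEA, the usual cause is that some sources are only generated during the build. After importing, regenerate them with Generate Sources and Update Folders, then rebuild. (I used the Spark 2.1 code.)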
```shell
git clone -b branch-2.1 https://github.com/apache/spark.git
cd spark
./build/mvn -DskipTests clean package
```
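Goal
Add a range query operation to Dataset, so that a Dataset supports range queries. Given some columns and a lower and upper bound for each, the query keeps the rows whose values fall inside that range.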
```scala
import org.apache.spark.sql.SparkSession

object Range {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .master("local[4]")
      .appName("SparkSession")
      .getOrCreate()
    import sparkSession.implicits._

    val caseClassDS = Seq(
      PointData(1.0, 1.0, 3.0, "1"),
      PointData(2.0, 2.0, 3.0, "2"),
      PointData(2.0, 2.0, 3.0, "3"),
      PointData(2.0, 2.0, 3.0, "4"),
      PointData(3.0, 3.0, 3.0, "5"),
      PointData(4.0, 4.0, 3.0, "6")).toDS()

    caseClassDS.range(Array("x", "y"), Array(0.0, 0.0), Array(3.0, 3.0)).show()
  }
}
```
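The sketch below shows in plain Scala, without a Spark build, what the `range` call above is meant to return. It assumes the bounds are inclusive, because the article does not show `MBR.contains`. It also uses a hypothetical `column` helper in place of Spark's column resolution. Under those assumptions, rows 1 through 5 fall inside the rectangle from (0, 0) to (3, 3), and row 6, at (4, 4), is filtered out.

```scala
// Plain-Scala model (no Spark) of the intended range(...) semantics:
// keep a row when every selected key lies between low(i) and high(i).
// Inclusive bounds are an assumption; the article does not show MBR.contains.
object RangeSemantics {
  case class PointData(x: Double, y: Double, z: Double, other: String)

  // Hypothetical helper: look up a numeric column by name.
  def column(p: PointData, key: String): Double = key match {
    case "x" => p.x
    case "y" => p.y
    case "z" => p.z
    case _   => throw new IllegalArgumentException(s"column not found: $key")
  }

  def range(rows: Seq[PointData], keys: Array[String],
            low: Array[Double], high: Array[Double]): Seq[PointData] = {
    require(keys.length == low.length && low.length == high.length)
    rows.filter { r =>
      keys.indices.forall { i =>
        val v = column(r, keys(i))
        low(i) <= v && v <= high(i)
      }
    }
  }

  val data = Seq(
    PointData(1.0, 1.0, 3.0, "1"), PointData(2.0, 2.0, 3.0, "2"),
    PointData(2.0, 2.0, 3.0, "3"), PointData(2.0, 2.0, 3.0, "4"),
    PointData(3.0, 3.0, 3.0, "5"), PointData(4.0, 4.0, 3.0, "6"))

  def main(args: Array[String]): Unit =
    range(data, Array("x", "y"), Array(0.0, 0.0), Array(3.0, 3.0)).foreach(println)
}
```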
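Implementing the UDT (User-Defined Type)
A UDT lets you add a user-defined complex type to Spark SQL. The ShapeType class tells Spark that the user-defined Point and Shape classes are stored as ShapeType. ShapeType must declare sqlType (the underlying storage type, here an array of bytes) and userClass. Its two main functions are serialize and deserialize. I use Kryo for both, through the ShapeSerializer object.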
```scala
// UDT that stores any Shape as a byte array produced by ShapeSerializer
private[sql] class ShapeType extends UserDefinedType[Shape] {
  override def sqlType: DataType = ArrayType(ByteType, containsNull = false)

  override def serialize(s: Shape): Any = {
    new GenericArrayData(ShapeSerializer.serialize(s))
  }

  override def userClass: Class[Shape] = classOf[Shape]

  override def deserialize(datum: Any): Shape = {
    datum match {
      case values: ArrayData => ShapeSerializer.deserialize(values.toByteArray)
    }
  }
}

case object ShapeType extends ShapeType

@SQLUserDefinedType(udt = classOf[ShapeType])
abstract class Shape extends Serializable {
  val dimensions: Int
  def minDist(other: Shape): Double
}

@SQLUserDefinedType(udt = classOf[ShapeType])
case class Point(coord: Array[Double]) extends Shape {
  override val dimensions: Int = coord.length

  override def minDist(other: Shape): Double = {
    other match {
      case p: Point => minDist(p)
      case mbr: MBR => mbr.minDist(this)
    }
  }

  // Euclidean distance between two points of the same dimension
  def minDist(other: Point): Double = {
    require(coord.length == other.coord.length)
    var ans = 0.0
    for (i <- coord.indices)
      ans += (coord(i) - other.coord(i)) * (coord(i) - other.coord(i))
    Math.sqrt(ans)
  }

  // Coordinate-wise equality
  def ==(other: Point): Boolean = {
    other match {
      case p: Point =>
        if (p.coord.length != coord.length) false
        else {
          for (i <- coord.indices)
            if (coord(i) != p.coord(i)) return false
          true
        }
      case _ => false
    }
  }

  // True if every coordinate is <= the matching coordinate of other
  def <=(other: Point): Boolean = {
    for (i <- coord.indices)
      if (coord(i) > other.coord(i)) return false
    true
  }
}
```
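`ShapeSerializer` itself is not listed. The sketch below shows the contract `ShapeType` relies on: `Shape => Array[Byte]` and back. It uses the stand-in classes `DemoShape`, `DemoPoint` and `DemoShapeSerializer`, all hypothetical. It deliberately swaps in Java's built-in serialization instead of Kryo, so it runs with only the JDK. The article's version uses Kryo.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-ins for the article's Shape/Point (no Spark UDT annotations).
abstract class DemoShape extends Serializable { val dimensions: Int }

case class DemoPoint(coord: Array[Double]) extends DemoShape {
  override val dimensions: Int = coord.length
}

// Same shape of API as ShapeSerializer (Shape => bytes => Shape),
// but implemented with Java serialization rather than Kryo.
object DemoShapeSerializer {
  def serialize(s: DemoShape): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(s) finally out.close() // close() flushes into `bytes`
    bytes.toByteArray
  }

  def deserialize(data: Array[Byte]): DemoShape = {
    val in = new ObjectInputStream(new ByteArrayInputStream(data))
    try in.readObject().asInstanceOf[DemoShape] finally in.close()
  }
}
```

In the real UDT, `serialize` wraps these bytes in `GenericArrayData`, and `deserialize` recovers them with `ArrayData.toByteArray`, as in the `ShapeType` code above.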
Adding the expressions
- PointWrapper: its eval function evaluates the given attributes (the exps parameter) against the InternalRow and builds a Point from the resulting values.
```scala
case class PointWrapper(exps: Seq[Expression])
  extends Expression with CodegenFallback {
  override def nullable: Boolean = false
  override def dataType: DataType = ShapeType
  override def children: Seq[Expression] = exps

  override def eval(input: InternalRow): Any = {
    val coord = exps.map(_.eval(input).asInstanceOf[Double]).toArray
    Point(coord)
  }
}
```
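The toy model below shows how `PointWrapper.eval` composes its children. It is plain Scala, not Spark's classes. Each child pulls one column out of the row, and the wrapper packs the values into coordinates. `Row`, `Expr`, `BoundRef` and `PointWrapperModel` are hypothetical stand-ins.

Two things differ in Spark itself:
- The children are attribute references that are bound to column ordinals before evaluation.
- The coordinates are wrapped in a `Point` rather than returned as a bare array.

```scala
// Toy model of expression evaluation (hypothetical types, not Spark's API).
object ExprModel {
  // Stand-in for InternalRow: column values by position.
  type Row = IndexedSeq[Any]

  sealed trait Expr { def eval(row: Row): Any }

  // Stand-in for a bound attribute reference: reads the column at `ordinal`.
  final case class BoundRef(ordinal: Int) extends Expr {
    def eval(row: Row): Any = row(ordinal)
  }

  // Mirrors PointWrapper.eval: evaluate each child as a Double and collect the coordinates.
  final case class PointWrapperModel(children: Seq[Expr]) extends Expr {
    def eval(row: Row): Any = children.map(_.eval(row).asInstanceOf[Double]).toArray
  }

  def main(args: Array[String]): Unit = {
    // Row for PointData(2.0, 2.0, 3.0, "2"): columns x, y, z, other.
    val row: Row = IndexedSeq(2.0, 2.0, 3.0, "2")
    val coords = PointWrapperModel(Seq(BoundRef(0), BoundRef(1))).eval(row)
    println(coords.asInstanceOf[Array[Double]].mkString("(", ", ", ")")) // prints (2.0, 2.0)
  }
}
```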
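- ExpRange: the core of the range query. The shape parameter holds the attributes the query is over, and low and high are the bounds of the range. eval builds a Point from the values of shape's attributes in the InternalRow. It then takes the two corner points given by low and high and combines them into an MBR (minimum bounding rectangle). Finally, MBR.contains checks whether the row's point lies within the range.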
```scala
case class ExpRange(shape: Expression, low: Expression, high: Expression)
  extends Predicate with CodegenFallback {
  override def nullable: Boolean = false
  override def toString: String = s" **($shape) IN Rectangle ($low) - ($high)** "
  override def children: Seq[Expression] = Seq(shape, low, high)

  override def eval(input: InternalRow): Any = {
    // Point built from the current row's key columns
    val data = ShapeUtil.getShape(shape, input)
    // low and high are literals holding the corner points of the query rectangle
    val low_point = low.asInstanceOf[Literal].value.asInstanceOf[Point]
    val high_point = high.asInstanceOf[Literal].value.asInstanceOf[Point]
    require(data.dimensions == low_point.dimensions &&
      low_point.dimensions == high_point.dimensions)
    val mbr = MBR(low_point, high_point)
    mbr.contains(data)
  }
}
```
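The `MBR` class used by `ExpRange` (and by `Point.minDist`) is not listed either. Below is a minimal sketch. It assumes inclusive bounds and Euclidean distance, and it uses the hypothetical names `SketchPoint` and `SketchMBR` so that it compiles on its own.

```scala
// Hypothetical minimal point type so the sketch is self-contained.
final case class SketchPoint(coord: Array[Double]) {
  def dimensions: Int = coord.length
}

// Hypothetical MBR (minimum bounding rectangle) given by its low and high corners.
final case class SketchMBR(low: SketchPoint, high: SketchPoint) {
  require(low.dimensions == high.dimensions, "low and high must have the same dimensions")
  require(low.coord.indices.forall(i => low.coord(i) <= high.coord(i)), "low must be <= high")

  // True when low(i) <= p(i) <= high(i) in every dimension (inclusive bounds assumed).
  def contains(p: SketchPoint): Boolean = {
    require(p.dimensions == low.dimensions)
    p.coord.indices.forall(i => low.coord(i) <= p.coord(i) && p.coord(i) <= high.coord(i))
  }

  // Euclidean distance from p to the nearest point of the rectangle (0.0 if p is inside).
  def minDist(p: SketchPoint): Double = {
    require(p.dimensions == low.dimensions)
    val squared = p.coord.indices.map { i =>
      val d =
        if (p.coord(i) < low.coord(i)) low.coord(i) - p.coord(i)
        else if (p.coord(i) > high.coord(i)) p.coord(i) - high.coord(i)
        else 0.0
      d * d
    }.sum
    math.sqrt(squared)
  }
}
```

Take the rectangle from (0, 0) to (3, 3). It contains the point (3, 3) but not (4, 4), and the distance from (4, 4) to the rectangle is sqrt(2).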
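Implementing the physical plan
Spark SQL executes a query in three steps:
- planner.plan(ReturnAnswer(optimizedPlan)).next() turns the optimized LogicalPlan into a SparkPlan. This step applies the planner's strategies.
- prepareForExecution(sparkPlan) applies preparation rules, such as EnsureRequirements and CollapseCodegenStages, to produce executedPlan.
- executedPlan.execute() computes the result.

First we implement our own SparkPlan, SpatialFilterExec. Then we implement our own strategy, SpatialFilter, and add it to the strategies list of SparkPlanner. SpatialFilter converts a logical Filter whose condition is an ExpRange into a SpatialFilterExec. SpatialFilterExec overrides doExecute: it checks that the condition is an ExpRange and then evaluates that expression over the child RDD to filter the rows.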
```scala
case class SpatialFilterExec(condition: Expression, child: SparkPlan)
  extends SparkPlan with PredicateHelper {
  override def output: Seq[Attribute] = child.output

  override protected def doExecute(): RDD[InternalRow] = {
    val root_rdd = child.execute()
    condition match {
      case ExpRange(_, _, _) =>
        root_rdd.mapPartitions { iter =>
          // Build the predicate once per partition, not once per row
          val predicate = newPredicate(condition, child.output)
          iter.filter(row => predicate.eval(row))
        }
      case other =>
        throw new UnsupportedOperationException(s"SpatialFilterExec cannot evaluate $other")
    }
  }

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  override def children: Seq[SparkPlan] = child :: Nil
  override def outputPartitioning: Partitioning = child.outputPartitioning
}

object SpatialFilter extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // Claim only Filters whose condition is an ExpRange; any other Filter
    // falls through to the built-in strategies (BasicOperators plans it as FilterExec)
    case logical.Filter(condition @ ExpRange(_, _, _), child) =>
      SpatialFilterExec(condition, planLater(child)) :: Nil
    case _ => Nil
  }
}

class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val extraStrategies: Seq[Strategy])
  extends SparkStrategies {

  def numPartitions: Int = conf.numShufflePartitions

  def strategies: Seq[Strategy] =
    extraStrategies ++ (
      SpatialFilter ::
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)

  // ...
}
```
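The order of `strategies` matters because the planner uses the first plan any strategy yields. The toy model below shows that selection with hypothetical types, not Spark's. Strategies are tried in order via `strategies.iterator.flatMap(...).next()`, and child plans are planned recursively as a simplified stand-in for `planLater`. The model also shows why `SpatialFilter` should claim only `ExpRange` filters: any other filter falls through to the stand-in for `BasicOperators`.

```scala
// Toy model of strategy-based planning (hypothetical types, not Spark's classes).
object PlannerModel {
  sealed trait Cond
  final case class RangeCond(keys: Seq[String]) extends Cond // stands in for ExpRange
  final case class OtherCond(sql: String) extends Cond       // any other predicate

  sealed trait Logical
  final case class Filter(cond: Cond, child: Logical) extends Logical
  case object Relation extends Logical

  sealed trait Physical
  final case class SpatialFilterExecModel(cond: Cond, child: Physical) extends Physical
  final case class FilterExecModel(cond: Cond, child: Physical) extends Physical
  case object ScanExecModel extends Physical

  // A strategy returns candidate physical plans, or Nil if it does not apply.
  type Strategy = Logical => Seq[Physical]

  // Like SpatialFilter: claims only Filters whose condition is a range predicate.
  val spatialFilter: Strategy = {
    case Filter(c: RangeCond, child) => Seq(SpatialFilterExecModel(c, plan(child)))
    case _                           => Nil
  }

  // Like BasicOperators: plans any Filter generically, plus the leaf relation.
  val basicOperators: Strategy = {
    case Filter(c, child) => Seq(FilterExecModel(c, plan(child)))
    case Relation         => Seq(ScanExecModel)
  }

  // Order matters: spatialFilter is tried first, as SpatialFilter heads the strategies list.
  val strategies: Seq[Strategy] = Seq(spatialFilter, basicOperators)

  // Use the first plan produced by any strategy; children are planned recursively
  // (a simplification of planLater).
  def plan(p: Logical): Physical =
    strategies.iterator.flatMap(s => s(p)).next()
}
```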
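Adding the Dataset API
Finally, add the method to Dataset. It wraps the current logical plan in a Filter whose condition is an ExpRange and returns the result as a DataFrame.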
```scala
def range(keys: Array[String], low: Array[Double], high: Array[Double]): DataFrame = withPlan {
  // Map each key name to an attribute of this Dataset's plan (getAttributes is not shown here)
  val attrs = getAttributes(keys)
  attrs.foreach(attr => assert(attr != null, "column not found"))
  Filter(
    ExpRange(
      PointWrapper(attrs),
      Literal.create(new Point(low), ShapeType),
      Literal.create(new Point(high), ShapeType)),
    logicalPlan)
}
```
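`getAttributes` is not shown in the article. Presumably it maps each key to the matching output attribute of the Dataset's plan. The plain-Scala sketch below models that lookup with a hypothetical `Attr` type, and it fails fast with a descriptive error instead of returning null. Inside `Dataset` itself, the `private[sql]` helper `resolve(colName)` may be an alternative, since it throws an `AnalysisException` for unknown columns.

```scala
// Hypothetical model of getAttributes: resolve column names against a plan's output.
object KeyResolution {
  // Stand-in for a Catalyst Attribute: a column name plus its position in the row.
  final case class Attr(name: String, ordinal: Int)

  // Returns one Attr per key, in key order; throws if a key is not an output column.
  def getAttributes(keys: Array[String], output: Seq[Attr]): Array[Attr] =
    keys.map { key =>
      output.find(_.name == key).getOrElse(
        throw new IllegalArgumentException(s"column not found: $key"))
    }

  def main(args: Array[String]): Unit = {
    val output = Seq(Attr("x", 0), Attr("y", 1), Attr("z", 2), Attr("other", 3))
    println(getAttributes(Array("x", "y"), output).mkString(", "))
  }
}
```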
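Summary
Calling range with its arguments creates a logical Filter whose arguments are an ExpRange expression and the current logicalPlan. During execution, planner.plan(ReturnAnswer(optimizedPlan)).next() converts the optimized LogicalPlan into a SparkPlan. This is the step where the SpatialFilter strategy turns the Filter into our SpatialFilterExec. prepareForExecution(sparkPlan) then produces executedPlan. Finally, executedPlan.execute() runs SpatialFilterExec.doExecute, which returns the filtered result. All the code is on GitHub.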