From 0f7175def985a7f1e37198680f893e749612ab76 Mon Sep 17 00:00:00 2001
From: Liwei Lin
Date: Thu, 7 Jul 2016 10:40:42 -0700
Subject: [SPARK-16350][SQL] Fix support for incremental planning in
 writeStream.foreach()

## What changes were proposed in this pull request?

There are cases where the `complete` output mode does not output the updated aggregated value;
for details please refer to [SPARK-16350](https://issues.apache.org/jira/browse/SPARK-16350).

The cause is that `ForeachSink.addBatch()` calls `data.as[T].foreachPartition { iter => ... }`,
and `foreachPartition()` does not support incremental planning for now.

This patch makes `foreachPartition()` support incremental planning in `ForeachSink`, by using a
special version of `Dataset` whose `rdd()` method supports incremental planning.

## How was this patch tested?

Added a unit test which failed before the change.

Author: Liwei Lin

Closes #14030 from lw-lin/fix-foreach-complete.
---
 .../sql/execution/streaming/ForeachSink.scala      | 40 +++++++++-
 .../execution/streaming/IncrementalExecution.scala |  4 +-
 .../sql/execution/streaming/ForeachSinkSuite.scala | 86 +++++++++++++++++++---
 3 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
index 14b9b1cb09..082664aa23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
+import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
 
 /**
  * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
@@ -30,7 +32,41 @@ import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
 class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    data.as[T].foreachPartition { iter =>
+    // TODO: Refine this method when SPARK-16264 is resolved; see comments below.
+
+    // This logic should've been as simple as:
+    // ```
+    //   data.as[T].foreachPartition { iter => ... }
+    // ```
+    //
+    // Unfortunately, doing that would just break the incremental planning. The reason is,
+    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just
+    // does not support `IncrementalExecution`.
+    //
+    // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()`
+    // method supporting incremental planning. But in the long run, we should generally make newly
+    // created Datasets use `IncrementalExecution` where necessary (which is what SPARK-16264 tries
+    // to resolve).
+
+    val datasetWithIncrementalExecution =
+      new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) {
+        override lazy val rdd: RDD[T] = {
+          val objectType = exprEnc.deserializer.dataType
+          val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+
+          // was originally: sparkSession.sessionState.executePlan(deserialized) ...
+          val incrementalExecution = new IncrementalExecution(
+            this.sparkSession,
+            deserialized,
+            data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
+            data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
+            data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId)
+          incrementalExecution.toRdd.mapPartitions { rows =>
+            rows.map(_.get(0, objectType))
+          }.asInstanceOf[RDD[T]]
+        }
+      }
+    datasetWithIncrementalExecution.foreachPartition { iter =>
       if (writer.open(TaskContext.getPartitionId(), batchId)) {
         var isFailed = false
         try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 0ce00552bf..7367c68d0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -30,8 +30,8 @@ import org.apache.spark.sql.streaming.OutputMode
 class IncrementalExecution private[sql](
     sparkSession: SparkSession,
     logicalPlan: LogicalPlan,
-    outputMode: OutputMode,
-    checkpointLocation: String,
+    val outputMode: OutputMode,
+    val checkpointLocation: String,
     val currentBatchId: Long)
   extends QueryExecution(sparkSession, logicalPlan) {
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 6ff597c16b..7928b8e877 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.ForeachWriter
-import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
@@ -35,35 +35,103 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
     sqlContext.streams.active.foreach(_.stop())
   }
 
-  test("foreach") {
+  test("foreach() with `append` output mode") {
     withTempDir { checkpointDir =>
       val input = MemoryStream[Int]
       val query = input.toDS().repartition(2).writeStream
         .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Append)
         .foreach(new TestForeachWriter())
         .start()
+
+      // -- batch 0 ---------------------------------------
       input.addData(1, 2, 3, 4)
       query.processAllAvailable()
 
-      val expectedEventsForPartition0 = Seq(
+      var expectedEventsForPartition0 = Seq(
         ForeachSinkSuite.Open(partition = 0, version = 0),
         ForeachSinkSuite.Process(value = 1),
         ForeachSinkSuite.Process(value = 3),
         ForeachSinkSuite.Close(None)
       )
-      val expectedEventsForPartition1 = Seq(
+      var expectedEventsForPartition1 = Seq(
         ForeachSinkSuite.Open(partition = 1, version = 0),
         ForeachSinkSuite.Process(value = 2),
         ForeachSinkSuite.Process(value = 4),
         ForeachSinkSuite.Close(None)
       )
 
-      val allEvents = ForeachSinkSuite.allEvents()
+      var allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 2)
+      assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+
+      ForeachSinkSuite.clear()
+
+      // -- batch 1 ---------------------------------------
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+
+      expectedEventsForPartition0 = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 1),
+        ForeachSinkSuite.Process(value = 5),
+        ForeachSinkSuite.Process(value = 7),
+        ForeachSinkSuite.Close(None)
+      )
+      expectedEventsForPartition1 = Seq(
+        ForeachSinkSuite.Open(partition = 1, version = 1),
+        ForeachSinkSuite.Process(value = 6),
+        ForeachSinkSuite.Process(value = 8),
+        ForeachSinkSuite.Close(None)
+      )
+
+      allEvents = ForeachSinkSuite.allEvents()
       assert(allEvents.size === 2)
-      assert {
-        allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
-          allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
-      }
+      assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+
+      query.stop()
+    }
+  }
+
+  test("foreach() with `complete` output mode") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+
+      val query = input.toDS()
+        .groupBy().count().as[Long].map(_.toInt)
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Complete)
+        .foreach(new TestForeachWriter())
+        .start()
+
+      // -- batch 0 ---------------------------------------
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      var allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      var expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 4),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+
+      ForeachSinkSuite.clear()
+
+      // -- batch 1 ---------------------------------------
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+
+      allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 1),
+        ForeachSinkSuite.Process(value = 8),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+      query.stop()
     }
   }
 
-- 
cgit v1.2.3
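
For context, below is a minimal, self-contained sketch of the user-level scenario this commit fixes, mirroring the new `complete`-mode unit test: a streaming running count written out through `writeStream.foreach()`. It is not part of the patch; the object name, local master setting, and temp checkpoint directory are made up for illustration, and it assumes the Spark 2.0-era APIs touched by this change (`MemoryStream` is an internal test source). Before this fix, the `ForeachWriter` could keep seeing the batch-0 aggregate instead of the updated running count in later batches.

```scala
import java.nio.file.Files

import org.apache.spark.sql.{ForeachWriter, SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.OutputMode

// Illustrative sketch only -- not part of this patch.
object ForeachCompleteModeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-complete-mode-sketch")
      .getOrCreate()
    import spark.implicits._
    // MemoryStream is an internal test source and needs an implicit SQLContext.
    implicit val sqlContext: SQLContext = spark.sqlContext

    val input = MemoryStream[Int]
    val query = input.toDS()
      .groupBy().count()          // running count over all rows seen so far
      .as[Long].map(_.toInt)
      .writeStream
      .option("checkpointLocation", Files.createTempDirectory("foreach-sketch").toString)
      .outputMode(OutputMode.Complete)
      .foreach(new ForeachWriter[Int] {
        def open(partitionId: Long, version: Long): Boolean = true
        def process(value: Int): Unit = println(s"running count = $value")
        def close(errorOrNull: Throwable): Unit = ()
      })
      .start()

    input.addData(1, 2, 3, 4)
    query.processAllAvailable()   // prints "running count = 4"

    input.addData(5, 6, 7, 8)
    // With this fix the aggregate is re-planned incrementally and the writer
    // sees the new total; before the fix it could keep seeing the old value.
    query.processAllAvailable()   // prints "running count = 8"

    query.stop()
    spark.stop()
  }
}
```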