author     Shixiong Zhu <shixiong@databricks.com>       2016-06-10 00:11:46 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>  2016-06-10 00:11:46 -0700
commit     00c310133df4f3893dd90d801168c2ab9841b102 (patch)
tree       ba13cb409c9cab4b214181340c7eedf6276b8388
parent     5a3533e779d8e43ce0980203dfd3cbe343cc7d0a (diff)
download   spark-00c310133df4f3893dd90d801168c2ab9841b102.tar.gz
           spark-00c310133df4f3893dd90d801168c2ab9841b102.tar.bz2
           spark-00c310133df4f3893dd90d801168c2ab9841b102.zip
[SPARK-15593][SQL] Add DataFrameWriter.foreach to allow the user to consume data in ContinuousQuery
## What changes were proposed in this pull request?

* Add DataFrameWriter.foreach to allow the user to consume data in a ContinuousQuery
* ForeachWriter is the interface for the user to consume partitions of data
* Add a type parameter T to DataFrameWriter

Usage:

```scala
val ds = spark.read....stream().as[String]

ds.....write
  .queryName(...)
  .option("checkpointLocation", ...)
  .foreach(new ForeachWriter[String] {

    def open(partitionId: Long, version: Long): Boolean = {
      // Prepare some resources for a partition.
      // Check `version` if possible and return `false` if this is duplicated data,
      // to skip the data processing.
    }

    override def process(value: String): Unit = {
      // process data
    }

    def close(errorOrNull: Throwable): Unit = {
      // Release resources for a partition.
      // Check `errorOrNull` and handle the error if necessary.
    }
  })
```

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #13342 from zsxwing/foreach.
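The snippet below expands the usage above into a self-contained sketch against the API added in this patch. It assumes a streaming `Dataset[String]` named `lines` created elsewhere (the source setup is elided, as in the example above); the query name and checkpoint path are placeholders.

```scala
import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.streaming.ContinuousQuery

// `lines` is an assumed streaming Dataset[String]; how it is created is elided here.
def startForeachQuery(lines: Dataset[String]): ContinuousQuery = {
  lines.write
    .queryName("foreach-example")                        // placeholder query name
    .option("checkpointLocation", "/tmp/foreach-ckpt")   // placeholder checkpoint directory
    .foreach(new ForeachWriter[String] {
      def open(partitionId: Long, version: Long): Boolean = {
        // No per-partition resources in this sketch; process every partition.
        true
      }
      def process(value: String): Unit = {
        // Side effect per record; a real writer would push to an external system.
        println(value)
      }
      def close(errorOrNull: Throwable): Unit = {
        // Nothing to release; a real writer would inspect errorOrNull before committing.
      }
    })
}
```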
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                       150
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala                         105
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala        53
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala  141
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala               4
6 files changed, 413 insertions(+), 42 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 1dd8818ded..32e2fdc3f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
*
* @since 1.4.0
*/
-final class DataFrameWriter private[sql](df: DataFrame) {
+final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
+
+ private val df = ds.toDF()
/**
* Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def mode(saveMode: SaveMode): DataFrameWriter = {
+ def mode(saveMode: SaveMode): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def mode(saveMode: String): DataFrameWriter = {
+ def mode(saveMode: String): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
- def outputMode(outputMode: OutputMode): DataFrameWriter = {
+ def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode
this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
- def outputMode(outputMode: String): DataFrameWriter = {
+ def outputMode(outputMode: String): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode.toLowerCase match {
case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
- def trigger(trigger: Trigger): DataFrameWriter = {
+ def trigger(trigger: Trigger): DataFrameWriter[T] = {
assertStreaming("trigger() can only be called on continuous queries")
this.trigger = trigger
this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def format(source: String): DataFrameWriter = {
+ def format(source: String): DataFrameWriter[T] = {
this.source = source
this
}
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def option(key: String, value: String): DataFrameWriter = {
+ def option(key: String, value: String): DataFrameWriter[T] = {
this.extraOptions += (key -> value)
this
}
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 2.0.0
*/
- def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
+ def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
- def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
+ def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
- def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
+ def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
/**
* (Scala-specific) Adds output options for the underlying data source.
*
* @since 1.4.0
*/
- def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
+ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
this.extraOptions ++= options
this
}
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def options(options: java.util.Map[String, String]): DataFrameWriter = {
+ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
this.options(options.asScala)
this
}
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
@scala.annotation.varargs
- def partitionBy(colNames: String*): DataFrameWriter = {
+ def partitionBy(colNames: String*): DataFrameWriter[T] = {
this.partitioningColumns = Option(colNames)
this
}
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
- def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
+ def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
this.numBuckets = Option(numBuckets)
this.bucketColumnNames = Option(colName +: colNames)
this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
- def sortBy(colName: String, colNames: String*): DataFrameWriter = {
+ def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
this.sortColumnNames = Option(colName +: colNames)
this
}
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
- def queryName(queryName: String): DataFrameWriter = {
+ def queryName(queryName: String): DataFrameWriter[T] = {
assertStreaming("queryName() can only be called on continuous queries")
this.extraOptions += ("queryName" -> queryName)
this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val queryName =
extraOptions.getOrElse(
"queryName", throw new AnalysisException("queryName must be specified for memory sink"))
- val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
- new Path(userSpecified).toUri.toString
- }.orElse {
- val checkpointConfig: Option[String] =
- df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)
-
- checkpointConfig.map { location =>
- new Path(location, queryName).toUri.toString
- }
- }.getOrElse {
+ val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
}
@@ -378,21 +371,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
-
val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
- val checkpointLocation = extraOptions.get("checkpointLocation")
- .orElse {
- df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
- new Path(l, queryName).toUri.toString
- }
- }.getOrElse {
- throw new AnalysisException("checkpointLocation must be specified either " +
- "through option() or SQLConf")
- }
-
df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
- checkpointLocation,
+ getCheckpointLocation(queryName, failIfNotSet = true).get,
df,
dataSource.createSink(outputMode),
outputMode,
@@ -401,6 +383,94 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
/**
+ * :: Experimental ::
+ * Starts the execution of the streaming query, which will continually send results to the given
+ * [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data
+ * generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned
+ * [[ContinuousQuery]] object can be used to interact with the stream.
+ *
+ * Scala example:
+ * {{{
+ * datasetOfString.write.foreach(new ForeachWriter[String] {
+ *
+ * def open(partitionId: Long, version: Long): Boolean = {
+ * // open connection
+ * }
+ *
+ * def process(record: String) = {
+ * // write string to connection
+ * }
+ *
+ * def close(errorOrNull: Throwable): Unit = {
+ * // close the connection
+ * }
+ * })
+ * }}}
+ *
+ * Java example:
+ * {{{
+ * datasetOfString.write().foreach(new ForeachWriter<String>() {
+ *
+ * @Override
+ * public boolean open(long partitionId, long version) {
+ * // open connection
+ * }
+ *
+ * @Override
+ * public void process(String value) {
+ * // write string to connection
+ * }
+ *
+ * @Override
+ * public void close(Throwable errorOrNull) {
+ * // close the connection
+ * }
+ * });
+ * }}}
+ *
+ * @since 2.0.0
+ */
+ @Experimental
+ def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
+ assertNotBucketed("foreach")
+ assertStreaming(
+ "foreach() can only be called on streaming Datasets/DataFrames.")
+
+ val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+ val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
+ df.sparkSession.sessionState.continuousQueryManager.startQuery(
+ queryName,
+ getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
+ Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
+ },
+ df,
+ sink,
+ outputMode,
+ trigger)
+ }
+
+ /**
+ * Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
+ * location is not set, [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`, `None`
+ * will be returned if the checkpoint location is not set.
+ */
+ private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
+ val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+ new Path(userSpecified).toUri.toString
+ }.orElse {
+ df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
+ new Path(location, queryName).toUri.toString
+ }
+ }
+ if (failIfNotSet && checkpointLocation.isEmpty) {
+ throw new AnalysisException("checkpointLocation must be specified either " +
+ """through option("checkpointLocation", ...) or """ +
+ s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
+ }
+ checkpointLocation
+ }
+
+ /**
* Inserts the content of the [[DataFrame]] to the specified table. It requires that
* the schema of the [[DataFrame]] is the same as the schema of the table.
*
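The `getCheckpointLocation` helper added above resolves the checkpoint directory either from the per-query `checkpointLocation` option or from the session-wide `SQLConf.CHECKPOINT_LOCATION` setting, appending the query name in the latter case. Below is a minimal sketch of both paths, assuming a `SparkSession` named `spark`, a streaming `Dataset[String]` named `events`, and a `ForeachWriter[String]` named `writer`; the directories and query name are placeholders.

```scala
import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.ContinuousQuery

// 1. Per-query checkpoint: the option takes precedence over any session default.
def startWithOption(events: Dataset[String], writer: ForeachWriter[String]): ContinuousQuery = {
  events.write
    .queryName("events")                                       // placeholder query name
    .option("checkpointLocation", "/tmp/checkpoints/events")   // placeholder directory
    .foreach(writer)
}

// 2. Session-wide default: the query name is appended as a subdirectory,
//    i.e. /tmp/checkpoints/events in this sketch.
def startWithSessionDefault(
    spark: SparkSession,
    events: Dataset[String],
    writer: ForeachWriter[String]): ContinuousQuery = {
  spark.conf.set(SQLConf.CHECKPOINT_LOCATION.key, "/tmp/checkpoints")
  events.write.queryName("events").foreach(writer)
}
```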
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 162524a9ef..16bbf30a94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def write: DataFrameWriter = new DataFrameWriter(toDF())
+ def write: DataFrameWriter[T] = new DataFrameWriter[T](this)
/**
* Returns the content of the Dataset as a Dataset of JSON strings.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
new file mode 100644
index 0000000000..09f07426a6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.streaming.ContinuousQuery
+
+/**
+ * :: Experimental ::
+ * A class to consume data generated by a [[ContinuousQuery]]. Typically this is used to send the
+ * generated data to external systems. Each partition will use a new deserialized instance, so you
+ * usually should do all the initialization (e.g. opening a connection or initiating a transaction)
+ * in the `open` method.
+ *
+ * Scala example:
+ * {{{
+ * datasetOfString.write.foreach(new ForeachWriter[String] {
+ *
+ * def open(partitionId: Long, version: Long): Boolean = {
+ * // open connection
+ * }
+ *
+ * def process(record: String) = {
+ * // write string to connection
+ * }
+ *
+ * def close(errorOrNull: Throwable): Unit = {
+ * // close the connection
+ * }
+ * })
+ * }}}
+ *
+ * Java example:
+ * {{{
+ * datasetOfString.write().foreach(new ForeachWriter<String>() {
+ *
+ * @Override
+ * public boolean open(long partitionId, long version) {
+ * // open connection
+ * }
+ *
+ * @Override
+ * public void process(String value) {
+ * // write string to connection
+ * }
+ *
+ * @Override
+ * public void close(Throwable errorOrNull) {
+ * // close the connection
+ * }
+ * });
+ * }}}
+ * @since 2.0.0
+ */
+@Experimental
+abstract class ForeachWriter[T] extends Serializable {
+
+ /**
+   * Called when starting to process one partition of new data in the executor. The `version` is
+   * for data deduplication when there are failures. When recovering from a failure, some data may
+   * be generated multiple times, but it will always have the same version.
+   *
+   * If this method determines, based on the `partitionId` and `version`, that this partition has
+   * already been processed, it can return `false` to skip further data processing. However,
+   * `close` will still be called to clean up resources.
+ *
+ * @param partitionId the partition id.
+ * @param version a unique id for data deduplication.
+ * @return `true` if the corresponding partition and version id should be processed. `false`
+ * indicates the partition should be skipped.
+ */
+ def open(partitionId: Long, version: Long): Boolean
+
+ /**
+   * Called to process the data on the executor side. This method will be called only when `open`
+ * returns `true`.
+ */
+ def process(value: T): Unit
+
+ /**
+   * Called when the processing of one partition of new data on the executor side has stopped.
+   * This is guaranteed to be called whether `open` returns `true` or `false`. However,
+   * `close` won't be called in the following cases:
+ * - JVM crashes without throwing a `Throwable`
+ * - `open` throws a `Throwable`.
+ *
+   * @param errorOrNull the error thrown while processing the data, or null if there was no error.
+ */
+ def close(errorOrNull: Throwable): Unit
+}
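The `open` contract above suggests using `partitionId` and `version` for deduplication but leaves the mechanics to the implementation. The sketch below shows one way to honor it; `alreadyCommitted` and `commit` are hypothetical stubs standing in for an idempotent external store, not part of this patch.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.ForeachWriter

class DedupForeachWriter extends ForeachWriter[String] {

  private val buffer = ArrayBuffer[String]()
  private var partitionId: Long = _
  private var version: Long = _

  override def open(partitionId: Long, version: Long): Boolean = {
    this.partitionId = partitionId
    this.version = version
    // Skip the whole partition if this (partitionId, version) pair was written before.
    !alreadyCommitted(partitionId, version)
  }

  override def process(value: String): Unit = {
    buffer += value
  }

  override def close(errorOrNull: Throwable): Unit = {
    // close() is called even when open() returned false or process() threw, so only
    // commit when processing succeeded and there is something to write.
    if (errorOrNull == null && buffer.nonEmpty) {
      commit(partitionId, version, buffer)
    }
    buffer.clear()
  }

  // Hypothetical stubs for an idempotent external store; replace with real lookups/writes.
  private def alreadyCommitted(partitionId: Long, version: Long): Boolean = false
  private def commit(partitionId: Long, version: Long, records: Seq[String]): Unit = ()
}
```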
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
new file mode 100644
index 0000000000..14b9b1cb09
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+
+/**
+ * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
+
+ override def addBatch(batchId: Long, data: DataFrame): Unit = {
+ data.as[T].foreachPartition { iter =>
+ if (writer.open(TaskContext.getPartitionId(), batchId)) {
+ var isFailed = false
+ try {
+ while (iter.hasNext) {
+ writer.process(iter.next())
+ }
+ } catch {
+ case e: Throwable =>
+ isFailed = true
+ writer.close(e)
+ }
+ if (!isFailed) {
+ writer.close(null)
+ }
+ } else {
+ writer.close(null)
+ }
+ }
+ }
+}
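Because the writer is captured by the `foreachPartition` closure above (after passing through `sparkContext.clean` in `DataFrameWriter.foreach`), each task operates on a deserialized copy of the `ForeachWriter`. The sketch below illustrates the resulting pattern with a hypothetical non-serializable `Connection` type: heavyweight resources are created in `open` and released in `close`, never in the constructor.

```scala
import org.apache.spark.sql.ForeachWriter

// Hypothetical, non-serializable handle to an external system.
trait Connection {
  def send(s: String): Unit
  def shutdown(): Unit
}

class ConnectionWriter(connect: () => Connection) extends ForeachWriter[String] {

  // Transient because the writer itself is serialized and shipped to executors;
  // the connection is created per partition in open(), never on the driver.
  @transient private var connection: Connection = _

  override def open(partitionId: Long, version: Long): Boolean = {
    // `connect` must close over only serializable state (e.g. host/port strings).
    connection = connect()
    true
  }

  override def process(value: String): Unit = connection.send(value)

  override def close(errorOrNull: Throwable): Unit = {
    if (connection != null) connection.shutdown()
  }
}
```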
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
new file mode 100644
index 0000000000..e1fb3b9478
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+ import testImplicits._
+
+ after {
+ sqlContext.streams.active.foreach(_.stop())
+ }
+
+ test("foreach") {
+ withTempDir { checkpointDir =>
+ val input = MemoryStream[Int]
+ val query = input.toDS().repartition(2).write
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .foreach(new TestForeachWriter())
+ input.addData(1, 2, 3, 4)
+ query.processAllAvailable()
+
+ val expectedEventsForPartition0 = Seq(
+ ForeachSinkSuite.Open(partition = 0, version = 0),
+ ForeachSinkSuite.Process(value = 1),
+ ForeachSinkSuite.Process(value = 3),
+ ForeachSinkSuite.Close(None)
+ )
+ val expectedEventsForPartition1 = Seq(
+ ForeachSinkSuite.Open(partition = 1, version = 0),
+ ForeachSinkSuite.Process(value = 2),
+ ForeachSinkSuite.Process(value = 4),
+ ForeachSinkSuite.Close(None)
+ )
+
+ val allEvents = ForeachSinkSuite.allEvents()
+ assert(allEvents.size === 2)
+ assert {
+ allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
+ allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
+ }
+ query.stop()
+ }
+ }
+
+ test("foreach with error") {
+ withTempDir { checkpointDir =>
+ val input = MemoryStream[Int]
+ val query = input.toDS().repartition(1).write
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .foreach(new TestForeachWriter() {
+ override def process(value: Int): Unit = {
+ super.process(value)
+ throw new RuntimeException("error")
+ }
+ })
+ input.addData(1, 2, 3, 4)
+ query.processAllAvailable()
+
+ val allEvents = ForeachSinkSuite.allEvents()
+ assert(allEvents.size === 1)
+ assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
+ assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
+ val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
+ assert(errorEvent.error.get.isInstanceOf[RuntimeException])
+ assert(errorEvent.error.get.getMessage === "error")
+ query.stop()
+ }
+ }
+}
+
+/** A global object to collect events in the executor */
+object ForeachSinkSuite {
+
+ trait Event
+
+ case class Open(partition: Long, version: Long) extends Event
+
+ case class Process[T](value: T) extends Event
+
+ case class Close(error: Option[Throwable]) extends Event
+
+ private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
+
+ def addEvents(events: Seq[Event]): Unit = {
+ _allEvents.add(events)
+ }
+
+ def allEvents(): Seq[Seq[Event]] = {
+ _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
+ }
+
+ def clear(): Unit = {
+ _allEvents.clear()
+ }
+}
+
+/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
+class TestForeachWriter extends ForeachWriter[Int] {
+ ForeachSinkSuite.clear()
+
+ private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
+
+ override def open(partitionId: Long, version: Long): Boolean = {
+ events += ForeachSinkSuite.Open(partition = partitionId, version = version)
+ true
+ }
+
+ override def process(value: Int): Unit = {
+ events += ForeachSinkSuite.Process(value)
+ }
+
+ override def close(errorOrNull: Throwable): Unit = {
+ events += ForeachSinkSuite.Close(error = Option(errorOrNull))
+ ForeachSinkSuite.addEvents(events)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index bab0092c37..fc01ff3f5a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -238,7 +238,9 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
shuffleLeft: Boolean,
shuffleRight: Boolean): Unit = {
withTable("bucketed_table1", "bucketed_table2") {
- def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+ def withBucket(
+ writer: DataFrameWriter[Row],
+ bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = {
bucketSpec.map { spec =>
writer.bucketBy(
spec.numBuckets,