author    Matei Zaharia <matei@databricks.com>    2016-06-19 21:27:04 -0700
committer Reynold Xin <rxin@databricks.com>       2016-06-19 21:27:04 -0700
commit    4f17fddcd57adeae0d7e31bd14423283d4b625e9 (patch)
tree      fead6917eb4150c6a80bcfc4f4a2c5a4ba8a7f4e /sql
parent    5930d7a2e95b2fe4d470cf39546e5a12306553fe (diff)
[SPARK-16031] Add debug-only socket source in Structured Streaming
## What changes were proposed in this pull request?

This patch adds a text-based socket source similar to the one in Spark Streaming for debugging and tutorials. The source is clearly marked as debug-only so that users don't try to run it in production applications, because this type of source cannot provide HA without storing a lot of state in Spark.

## How was this patch tested?

Unit tests and manual tests in spark-shell.

Author: Matei Zaharia <matei@databricks.com>

Closes #13748 from mateiz/socket-source.
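For context, here is a minimal sketch of how the new source is typically exercised from spark-shell; the host, port, and the netcat command are illustrative placeholders rather than part of this patch:

    // In a separate terminal, start a simple line server, for example: nc -lk 9999
    // Then, in spark-shell:
    val lines = spark.readStream
      .format("socket")              // shortName() registered by TextSocketSourceProvider
      .option("host", "localhost")   // required; omitting it raises an AnalysisException
      .option("port", "9999")        // required; omitting it raises an AnalysisException
      .load()

    val query = lines.writeStream
      .format("console")
      .start()

Each line typed into the netcat session should then show up as a row in the single `value` column printed by the console sink.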
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala       |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala                 |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala        |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala                 |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala                 | 144
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala  | 136
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala                      |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala |   2
9 files changed, 293 insertions, 0 deletions
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 9f8bb5d38f..27d32b5dca 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -4,3 +4,4 @@ org.apache.spark.sql.execution.datasources.json.JsonFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
+org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index bef56160f6..9886ad0b41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -128,4 +128,6 @@ class FileStreamSource(
override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1)
override def toString: String = s"FileStreamSource[$qualifiedBasePath]"
+
+ override def stop() {}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index 14450c2e2f..971147840d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -39,4 +39,7 @@ trait Source {
* same data for a particular `start` and `end` pair.
*/
def getBatch(start: Option[Offset], end: Offset): DataFrame
+
+ /** Stop this source and free any resources it has allocated. */
+ def stop(): Unit
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 4aefd39b36..bb42a11759 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -399,6 +399,7 @@ class StreamExecution(
microBatchThread.interrupt()
microBatchThread.join()
}
+ uniqueSources.foreach(_.stop())
logInfo(s"Query $name was stopped")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 77fd043ef7..e37f0c7779 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -110,6 +110,8 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
sys.error("No data selected!")
}
}
+
+ override def stop() {}
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
new file mode 100644
index 0000000000..d07d88dcdc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+object TextSocketSource {
+ val SCHEMA = StructType(StructField("value", StringType) :: Nil)
+}
+
+/**
+ * A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
+ * This source will *not* work in production applications for several reasons, including the lack
+ * of fault-recovery support and the fact that it keeps all of the text it has read in memory forever.
+ */
+class TextSocketSource(host: String, port: Int, sqlContext: SQLContext)
+ extends Source with Logging
+{
+ @GuardedBy("this")
+ private var socket: Socket = null
+
+ @GuardedBy("this")
+ private var readThread: Thread = null
+
+ @GuardedBy("this")
+ private var lines = new ArrayBuffer[String]
+
+ initialize()
+
+ private def initialize(): Unit = synchronized {
+ socket = new Socket(host, port)
+ val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
+ readThread = new Thread(s"TextSocketSource($host, $port)") {
+ setDaemon(true)
+
+ override def run(): Unit = {
+ try {
+ while (true) {
+ val line = reader.readLine()
+ if (line == null) {
+ // End of file reached
+ logWarning(s"Stream closed by $host:$port")
+ return
+ }
+ TextSocketSource.this.synchronized {
+ lines += line
+ }
+ }
+ } catch {
+ case e: IOException =>
+ }
+ }
+ }
+ readThread.start()
+ }
+
+ /** Returns the schema of the data from this source */
+ override def schema: StructType = TextSocketSource.SCHEMA
+
+ /** Returns the maximum available offset for this source. */
+ override def getOffset: Option[Offset] = synchronized {
+ if (lines.isEmpty) None else Some(LongOffset(lines.size - 1))
+ }
+
+ /** Returns the data that is between the offsets (`start`, `end`]. */
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
+ val startIdx = start.map(_.asInstanceOf[LongOffset].offset.toInt + 1).getOrElse(0)
+ val endIdx = end.asInstanceOf[LongOffset].offset.toInt + 1
+ val data = synchronized { lines.slice(startIdx, endIdx) }
+ import sqlContext.implicits._
+ data.toDF("value")
+ }
+
+ /** Stop this source. */
+ override def stop(): Unit = synchronized {
+ if (socket != null) {
+ try {
+ // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
+ // stop the readThread is to close the socket.
+ socket.close()
+ } catch {
+ case e: IOException =>
+ }
+ socket = null
+ }
+ }
+}
+
+class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+ /** Returns the name and schema of the source that can be used to continually read data. */
+ override def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) = {
+ logWarning("The socket source should not be used for production applications! " +
+ "It does not support recovery and stores state indefinitely.")
+ if (!parameters.contains("host")) {
+ throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
+ }
+ if (!parameters.contains("port")) {
+ throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
+ }
+ ("textSocket", TextSocketSource.SCHEMA)
+ }
+
+ override def createSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ val host = parameters("host")
+ val port = parameters("port").toInt
+ new TextSocketSource(host, port, sqlContext)
+ }
+
+ /** String that represents the format that this data source provider uses. */
+ override def shortName(): String = "socket"
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
new file mode 100644
index 0000000000..ca57763185
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{IOException, OutputStreamWriter}
+import java.net.ServerSocket
+import java.util.concurrent.LinkedBlockingQueue
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {
+ import testImplicits._
+
+ override def afterEach() {
+ sqlContext.streams.active.foreach(_.stop())
+ if (serverThread != null) {
+ serverThread.interrupt()
+ serverThread.join()
+ serverThread = null
+ }
+ if (source != null) {
+ source.stop()
+ source = null
+ }
+ }
+
+ private var serverThread: ServerThread = null
+ private var source: Source = null
+
+ test("basic usage") {
+ serverThread = new ServerThread()
+ serverThread.start()
+
+ val provider = new TextSocketSourceProvider
+ val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
+ val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2
+ assert(schema === StructType(StructField("value", StringType) :: Nil))
+
+ source = provider.createSource(sqlContext, "", None, "", parameters)
+
+ failAfter(streamingTimeout) {
+ serverThread.enqueue("hello")
+ while (source.getOffset.isEmpty) {
+ Thread.sleep(10)
+ }
+ val offset1 = source.getOffset.get
+ val batch1 = source.getBatch(None, offset1)
+ assert(batch1.as[String].collect().toSeq === Seq("hello"))
+
+ serverThread.enqueue("world")
+ while (source.getOffset.get === offset1) {
+ Thread.sleep(10)
+ }
+ val offset2 = source.getOffset.get
+ val batch2 = source.getBatch(Some(offset1), offset2)
+ assert(batch2.as[String].collect().toSeq === Seq("world"))
+
+ val both = source.getBatch(None, offset2)
+ assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
+
+ // Try stopping the source to make sure this does not block forever.
+ source.stop()
+ source = null
+ }
+ }
+
+ test("params not given") {
+ val provider = new TextSocketSourceProvider
+ intercept[AnalysisException] {
+ provider.sourceSchema(sqlContext, None, "", Map())
+ }
+ intercept[AnalysisException] {
+ provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost"))
+ }
+ intercept[AnalysisException] {
+ provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234"))
+ }
+ }
+
+ test("no server up") {
+ val provider = new TextSocketSourceProvider
+ val parameters = Map("host" -> "localhost", "port" -> "0")
+ intercept[IOException] {
+ source = provider.createSource(sqlContext, "", None, "", parameters)
+ }
+ }
+
+ private class ServerThread extends Thread with Logging {
+ private val serverSocket = new ServerSocket(0)
+ private val messageQueue = new LinkedBlockingQueue[String]()
+
+ val port = serverSocket.getLocalPort
+
+ override def run(): Unit = {
+ try {
+ val clientSocket = serverSocket.accept()
+ clientSocket.setTcpNoDelay(true)
+ val out = new OutputStreamWriter(clientSocket.getOutputStream)
+ while (true) {
+ val line = messageQueue.take()
+ out.write(line + "\n")
+ out.flush()
+ }
+ } catch {
+ case e: InterruptedException =>
+ } finally {
+ serverSocket.close()
+ }
+ }
+
+ def enqueue(line: String): Unit = {
+ messageQueue.put(line)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 786404a589..b8e40e71bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -282,6 +282,8 @@ class FakeDefaultSource extends StreamSourceProvider {
val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
}
+
+ override def stop() {}
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 1aee1934c0..943e7b761e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -84,6 +84,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
Seq[Int]().toDS().toDF()
}
+
+ override def stop() {}
}
}