 sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala |  36
 sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala            | 280
 sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala       | 127
 3 files changed, 436 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index d9bf4d3ccf..f9d20ad090 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -17,18 +17,21 @@
package org.apache.spark.sql.execution.python
+import java.io.File
+
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import net.razorvine.pickle.{Pickler, Unpickler}
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
/**
@@ -37,9 +40,25 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
* Python evaluation works by sending the necessary (projected) input data via a socket to an
* external Python process, and combine the result from the Python process with the original row.
*
- * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * For each row we send to Python, we also put it in a queue first. For each output row from Python,
* we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and eventually run out of memory.
+ * slow, this could lead to the queue growing unbounded and spilling to disk once memory runs out.
+ *
+ * Here is a diagram to show how this works:
+ *
+ *            Downstream (for parent)
+ *             /      \
+ *            /     socket  (output of UDF)
+ *           /         \
+ *        RowQueue    Python
+ *           \         /
+ *            \     socket  (input of UDF)
+ *             \       /
+ *             upstream (from child)
+ *
+ * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
+ * so there are always some rows buffered in the socket or in the Python process; as a result,
+ * pulling from the RowQueue ALWAYS happens after pushing into it.
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {
@@ -70,7 +89,11 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
- val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+ val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+ TaskContext.get().addTaskCompletionListener({ ctx =>
+ queue.close()
+ })
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
@@ -98,7 +121,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { inputRow =>
- queue.add(inputRow)
+ queue.add(inputRow.asInstanceOf[UnsafeRow])
val row = projection(inputRow)
if (needConversion) {
EvaluatePython.toJava(row, schema)
@@ -132,7 +155,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)
-
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
@@ -144,7 +166,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
- resultProj(joined(queue.poll(), row))
+ resultProj(joined(queue.remove(), row))
}
}
}
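
The comment above describes the contract that makes this buffering safe: every input row is pushed into the RowQueue before the matching UDF output row is pulled. Below is a minimal, illustrative sketch of that contract (not part of this patch), mirroring the setup of RowQueueSuite further down. It assumes the snippet lives in the org.apache.spark.sql.execution.python package, since HybridRowQueue is private[python], and uses the test-only TestMemoryManager.

    import org.apache.spark.SparkConf
    import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.util.Utils

    // Build the queue the same way RowQueueSuite does: a test memory manager plus a temp dir.
    val memoryManager = new TestMemoryManager(new SparkConf())
    val queue = HybridRowQueue(new TaskMemoryManager(memoryManager, 0),
      Utils.createTempDir().getCanonicalFile, 1)

    // A single-field UnsafeRow: 8-byte null bitset plus one 8-byte long field.
    val row = new UnsafeRow(1)
    row.pointTo(new Array[Byte](16), 16)

    row.setLong(0, 42L)
    assert(queue.add(row))            // push the input row before it is sent to Python
    val buffered = queue.remove()     // drain it once the corresponding UDF output arrives
    assert(buffered.getLong(0) == 42L)
    queue.close()                     // frees pages and deletes any spill files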
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
new file mode 100644
index 0000000000..422a3f862d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -0,0 +1,280 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import com.google.common.io.Closeables
+
+import org.apache.spark.SparkException
+import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.memory.MemoryBlock
+
+/**
+ * A RowQueue is a FIFO queue for UnsafeRows.
+ *
+ * This RowQueue is ONLY designed and used for Python UDF evaluation, which has a single writer
+ * and a single reader, and the reader ALWAYS runs behind the writer. See the doc of class
+ * [[BatchEvalPythonExec]] on how it works.
+ */
+private[python] trait RowQueue {
+
+ /**
+ * Add a row to the end of the queue; returns true iff the row has been added.
+ */
+ def add(row: UnsafeRow): Boolean
+
+ /**
+ * Retrieve and remove the first row; returns null if the queue is empty.
+ *
+ * It can only be called after add() has been called, otherwise it will fail (NPE).
+ */
+ def remove(): UnsafeRow
+
+ /**
+ * Clean up all the resources.
+ */
+ def close(): Unit
+}
+
+/**
+ * A RowQueue that is backed by an in-memory page. UnsafeRows are appended to it until it is full.
+ * Another thread can read from it at the same time (always behind the writer).
+ *
+ * The format of UnsafeRow in page:
+ * [4 bytes to hold length of record (N)] [N bytes to hold record] [...]
+ *
+ * -1 length means end of page.
+ */
+private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int)
+ extends RowQueue {
+ private val base: AnyRef = page.getBaseObject
+ private val endOfPage: Long = page.getBaseOffset + page.size
+ // the first location where a new row would be written
+ private var writeOffset = page.getBaseOffset
+ // points to the start of the next row to read
+ private var readOffset = page.getBaseOffset
+ private val resultRow = new UnsafeRow(numFields)
+
+ def add(row: UnsafeRow): Boolean = synchronized {
+ val size = row.getSizeInBytes
+ if (writeOffset + 4 + size > endOfPage) {
+ // if there is not enough space in this page to hold the new record
+ if (writeOffset + 4 <= endOfPage) {
+ // if there's extra space at the end of the page, store a special "end-of-page" length (-1)
+ Platform.putInt(base, writeOffset, -1)
+ }
+ false
+ } else {
+ Platform.putInt(base, writeOffset, size)
+ Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size)
+ writeOffset += 4 + size
+ true
+ }
+ }
+
+ def remove(): UnsafeRow = synchronized {
+ assert(readOffset <= writeOffset, "reader should not go beyond writer")
+ if (readOffset + 4 > endOfPage || Platform.getInt(base, readOffset) < 0) {
+ null
+ } else {
+ val size = Platform.getInt(base, readOffset)
+ resultRow.pointTo(base, readOffset + 4, size)
+ readOffset += 4 + size
+ resultRow
+ }
+ }
+}
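
As an illustration (not part of this patch), the page layout described in the comment above can be exercised directly against a MemoryBlock with the same Platform calls that add() and remove() use:

    import org.apache.spark.unsafe.Platform
    import org.apache.spark.unsafe.memory.MemoryBlock

    val page = MemoryBlock.fromLongArray(new Array[Long](1 << 10))  // an 8 KB page, as in RowQueueSuite
    val base = page.getBaseObject
    var writeOffset = page.getBaseOffset

    val record = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8)
    Platform.putInt(base, writeOffset, record.length)               // [4 bytes: length N]
    Platform.copyMemory(record, Platform.BYTE_ARRAY_OFFSET,
      base, writeOffset + 4, record.length)                         // [N bytes: record]
    writeOffset += 4 + record.length
    Platform.putInt(base, writeOffset, -1)                          // end-of-page marker, written by
                                                                    // add() when the next record does not fit
    val readOffset = page.getBaseOffset
    assert(Platform.getInt(base, readOffset) == record.length)      // the reader sees the length first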
+
+/**
+ * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
+ * reader has begun reading from the queue.
+ */
+private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
+ private var out = new DataOutputStream(
+ new BufferedOutputStream(new FileOutputStream(file.toString)))
+ private var unreadBytes = 0L
+
+ private var in: DataInputStream = _
+ private val resultRow = new UnsafeRow(fields)
+
+ def add(row: UnsafeRow): Boolean = synchronized {
+ if (out == null) {
+ // A reader has already started consuming this queue; reject further writes
+ return false
+ }
+ out.writeInt(row.getSizeInBytes)
+ out.write(row.getBytes)
+ unreadBytes += 4 + row.getSizeInBytes
+ true
+ }
+
+ def remove(): UnsafeRow = synchronized {
+ if (out != null) {
+ out.close()
+ out = null
+ in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString)))
+ }
+
+ if (unreadBytes > 0) {
+ val size = in.readInt()
+ val bytes = new Array[Byte](size)
+ in.readFully(bytes)
+ unreadBytes -= 4 + size
+ resultRow.pointTo(bytes, size)
+ resultRow
+ } else {
+ null
+ }
+ }
+
+ def close(): Unit = synchronized {
+ Closeables.close(out, true)
+ out = null
+ Closeables.close(in, true)
+ in = null
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+}
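
A short sketch (not part of this patch, condensed from the "disk queue" test added below) of the write-then-read discipline stated above: the first remove() closes the writer, after which add() returns false. Like the other snippets, it assumes the org.apache.spark.sql.execution.python package, since DiskRowQueue is private[python].

    import java.io.File
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.util.Utils

    val queue = DiskRowQueue(new File(Utils.createTempDir(), "buffer"), 1)
    val row = new UnsafeRow(1)
    row.pointTo(new Array[Byte](16), 16)
    row.setLong(0, 1L)

    assert(queue.add(row))           // rows are appended to the buffered output stream
    assert(queue.remove() != null)   // the first read flips the queue into read mode
    assert(!queue.add(row))          // further writes are rejected
    queue.close()                    // closes both streams and deletes the backing file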
+
+/**
+ * A RowQueue that is composed of a list of RowQueues, which can be in memory or on disk.
+ *
+ * A HybridRowQueue can safely be appended to in one thread and pulled from in another thread at
+ * the same time.
+ */
+private[python] case class HybridRowQueue(
+ memManager: TaskMemoryManager,
+ tempDir: File,
+ numFields: Int)
+ extends MemoryConsumer(memManager) with RowQueue {
+
+ // Each buffer should have at least one row
+ private var queues = new java.util.LinkedList[RowQueue]()
+
+ private var writing: RowQueue = _
+ private var reading: RowQueue = _
+
+ // exposed for testing
+ private[python] def numQueues(): Int = queues.size()
+
+ def spill(size: Long, trigger: MemoryConsumer): Long = {
+ if (trigger == this) {
+ // When it's triggered by itself, it should write upcoming rows into disk instead of copying
+ // the rows already in the queue.
+ return 0L
+ }
+ var released = 0L
+ synchronized {
+ // poll out all the buffers and add them back in the same order to make sure that the rows
+ // are in the correct order.
+ val newQueues = new java.util.LinkedList[RowQueue]()
+ while (!queues.isEmpty) {
+ val queue = queues.remove()
+ val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) {
+ val diskQueue = createDiskQueue()
+ var row = queue.remove()
+ while (row != null) {
+ diskQueue.add(row)
+ row = queue.remove()
+ }
+ released += queue.asInstanceOf[InMemoryRowQueue].page.size()
+ queue.close()
+ diskQueue
+ } else {
+ queue
+ }
+ newQueues.add(newQueue)
+ }
+ queues = newQueues
+ }
+ released
+ }
+
+ private def createDiskQueue(): RowQueue = {
+ DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+ }
+
+ private def createNewQueue(required: Long): RowQueue = {
+ val page = try {
+ allocatePage(required)
+ } catch {
+ case _: OutOfMemoryError =>
+ null
+ }
+ val buffer = if (page != null) {
+ new InMemoryRowQueue(page, numFields) {
+ override def close(): Unit = {
+ freePage(page)
+ }
+ }
+ } else {
+ createDiskQueue()
+ }
+
+ synchronized {
+ queues.add(buffer)
+ }
+ buffer
+ }
+
+ def add(row: UnsafeRow): Boolean = {
+ if (writing == null || !writing.add(row)) {
+ writing = createNewQueue(4 + row.getSizeInBytes)
+ if (!writing.add(row)) {
+ throw new SparkException(s"failed to push a row into $writing")
+ }
+ }
+ true
+ }
+
+ def remove(): UnsafeRow = {
+ var row: UnsafeRow = null
+ if (reading != null) {
+ row = reading.remove()
+ }
+ if (row == null) {
+ if (reading != null) {
+ reading.close()
+ }
+ synchronized {
+ reading = queues.remove()
+ }
+ assert(reading != null, s"queue should not be empty")
+ row = reading.remove()
+ assert(row != null, s"$reading should have at least one row")
+ }
+ row
+ }
+
+ def close(): Unit = {
+ if (reading != null) {
+ reading.close()
+ reading = null
+ }
+ synchronized {
+ while (!queues.isEmpty) {
+ queues.remove().close()
+ }
+ }
+ }
+}
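
To make the spill contract concrete, here is an illustrative sketch (not part of this patch, condensed from the "hybrid queue" test added below): after a spill requested by another memory consumer, the buffered rows still come out in FIFO order, now served from DiskRowQueues.

    import org.apache.spark.SparkConf
    import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.util.Utils

    val memoryManager = new TestMemoryManager(new SparkConf())
    memoryManager.limit(4 << 10)                    // small budget so later pages go straight to disk
    val queue = HybridRowQueue(new TaskMemoryManager(memoryManager, 0),
      Utils.createTempDir().getCanonicalFile, 1)

    val row = new UnsafeRow(1)
    row.pointTo(new Array[Byte](16), 16)
    val n = (4 << 10) / 16 * 3                      // roughly three pages worth of rows
    (0 until n).foreach { i => row.setLong(0, i); assert(queue.add(row)) }

    queue.spill(1 << 20, null)                      // null trigger stands in for another consumer,
                                                    // as in the test; in-memory pages move to disk
    (0 until n).foreach { i => assert(queue.remove().getLong(0) == i) }  // FIFO order is preserved
    queue.close()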
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
new file mode 100644
index 0000000000..ffda33cf90
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.util.Utils
+
+class RowQueueSuite extends SparkFunSuite {
+
+ test("in-memory queue") {
+ val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
+ val queue = new InMemoryRowQueue(page, 1) {
+ override def close() {}
+ }
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](16), 16)
+ val n = page.size() / (4 + row.getSizeInBytes)
+ var i = 0
+ while (i < n) {
+ row.setLong(0, i)
+ assert(queue.add(row), "fail to add")
+ i += 1
+ }
+ assert(!queue.add(row), "should not add more")
+ i = 0
+ while (i < n) {
+ val row = queue.remove()
+ assert(row != null, "fail to poll")
+ assert(row.getLong(0) == i, "does not match")
+ i += 1
+ }
+ assert(queue.remove() == null, "should be empty")
+ queue.close()
+ }
+
+ test("disk queue") {
+ val dir = Utils.createTempDir().getCanonicalFile
+ dir.mkdirs()
+ val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](16), 16)
+ val n = 1000
+ var i = 0
+ while (i < n) {
+ row.setLong(0, i)
+ assert(queue.add(row), "fail to add")
+ i += 1
+ }
+ val first = queue.remove()
+ assert(first != null, "first should not be null")
+ assert(first.getLong(0) == 0, "first should be 0")
+ assert(!queue.add(row), "should not add more")
+ i = 1
+ while (i < n) {
+ val row = queue.remove()
+ assert(row != null, "fail to poll")
+ assert(row.getLong(0) == i, "does not match")
+ i += 1
+ }
+ assert(queue.remove() == null, "should be empty")
+ queue.close()
+ }
+
+ test("hybrid queue") {
+ val mem = new TestMemoryManager(new SparkConf())
+ mem.limit(4<<10)
+ val taskM = new TaskMemoryManager(mem, 0)
+ val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](16), 16)
+ val n = (4<<10) / 16 * 3
+ var i = 0
+ while (i < n) {
+ row.setLong(0, i)
+ assert(queue.add(row), "fail to add")
+ i += 1
+ }
+ assert(queue.numQueues() > 1, "should have more than one queue")
+ queue.spill(1<<20, null)
+ i = 0
+ while (i < n) {
+ val row = queue.remove()
+ assert(row != null, "fail to poll")
+ assert(row.getLong(0) == i, "does not match")
+ i += 1
+ }
+
+ // fill again and spill
+ i = 0
+ while (i < n) {
+ row.setLong(0, i)
+ assert(queue.add(row), "fail to add")
+ i += 1
+ }
+ assert(queue.numQueues() > 1, "should have more than one queue")
+ queue.spill(1<<20, null)
+ assert(queue.numQueues() > 1, "should have more than one queue")
+ i = 0
+ while (i < n) {
+ val row = queue.remove()
+ assert(row != null, "fail to poll")
+ assert(row.getLong(0) == i, "does not match")
+ i += 1
+ }
+ queue.close()
+ }
+}