author     Holden Karau <holden@us.ibm.com>  2016-09-23 09:44:30 +0100
committer  Sean Owen <sowen@cloudera.com>    2016-09-23 09:44:30 +0100
commit     90d5754212425d55f992c939a2bc7d9ac6ef92b8 (patch)
tree       59d36048cc576bce47d1003e35186204532bf894 /core/src/main/scala
parent     5c5396cb4725ba5ceee26ed885e8b941d219757b (diff)
[SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2
## What changes were proposed in this pull request?

Move the internals of the PySpark accumulator API off the old, deprecated Accumulator API and onto the new AccumulatorV2 API.

## How was this patch tested?

The existing PySpark accumulator tests (both the unit tests and the doc tests at the start of accumulator.py).

Author: Holden Karau <holden@us.ibm.com>

Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api.
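For orientation, the following is a minimal, hypothetical sketch (not part of this commit) of the AccumulatorV2 pattern the PySpark accumulator now builds on: a CollectionAccumulator subclass that overrides copyAndReset and merge, the same two hooks PythonAccumulatorV2 overrides in the diff below. BytesAccumulator and BytesAccumulatorExample are illustrative names, not Spark APIs.

import java.util.{List => JList}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.{AccumulatorV2, CollectionAccumulator}

// Illustrative accumulator mirroring the shape of PythonAccumulatorV2 in the patch.
class BytesAccumulator extends CollectionAccumulator[Array[Byte]] {

  // Return a zero-valued copy of the same concrete type, so code holding a
  // BytesAccumulator reference (as PythonFunction holds a PythonAccumulatorV2)
  // keeps the subtype after Spark copies and resets the accumulator.
  override def copyAndReset(): BytesAccumulator = new BytesAccumulator

  // Fold another accumulator's partial value into this one. CollectionAccumulator's
  // own merge already appends the other list; the patch layers driver-side socket
  // forwarding on top of this behaviour.
  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit =
    super.merge(other)
}

object BytesAccumulatorExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("acc-sketch").setMaster("local[2]"))
    val acc = new BytesAccumulator
    sc.register(acc, "byte updates")  // V2 accumulators are registered with the SparkContext
    sc.parallelize(Seq("a", "bb", "ccc")).foreach(s => acc.add(s.getBytes("UTF-8")))
    println(acc.value.size())  // 3: one collected update per record
    sc.stop()
  }
}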
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 42
1 file changed, 23 insertions(+), 19 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d841091a31..0ca91b9bf8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]])
+ accumulator: PythonAccumulatorV2)
/**
* A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
- accumulator += Collections.singletonList(update)
+ accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
try {
- val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
}
/**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
- extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+ @transient private val serverHost: String,
+ private val serverPort: Int)
+ extends CollectionAccumulator[Array[Byte]] {
Utils.checkHost(serverHost, "Expected hostname")
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
- @transient var socket: Socket = _
+ @transient private var socket: Socket = _
- def openSocket(): Socket = synchronized {
+ private def openSocket(): Socket = synchronized {
if (socket == null || socket.isClosed) {
socket = new Socket(serverHost, serverPort)
}
socket
}
- override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+ // Need to override so the types match with PythonFunction
+ override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
- override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
- : JList[Array[Byte]] = synchronized {
+ override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+ val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+ // This conditional isn't strictly needed - merging currently only happens on the
+ // driver program - but that isn't guaranteed, so handle the case where it changes.
if (serverHost == null) {
- // This happens on the worker node, where we just want to remember all the updates
- val1.addAll(val2)
- val1
+ // We are on the worker
+ super.merge(otherPythonAccumulator)
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = openSocket()
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
- out.writeInt(val2.size)
- for (array <- val2.asScala) {
+ val values = other.value
+ out.writeInt(values.size)
+ for (array <- values.asScala) {
out.writeInt(array.length)
out.write(array)
}
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
- null
}
}
}
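As a companion to the driver-side merge above, the following standalone sketch restates the wire format it uses to push updates to the Python process: the update count, then each update as a length-prefixed byte array, then a one-byte acknowledgement read back. It assumes a Python-side server is listening on the given socket; AccumulatorWire and sendUpdates are illustrative names, not part of the patch.

import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.Socket
import java.util.{List => JList}

import scala.collection.JavaConverters._

// Illustrative restatement of the wire format used by the driver-side merge:
// the number of updates, then each update as a length-prefixed byte array,
// followed by a one-byte acknowledgement from the Python server.
object AccumulatorWire {
  def sendUpdates(socket: Socket, values: JList[Array[Byte]], bufferSize: Int = 65536): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
    out.writeInt(values.size)      // how many pickled updates follow
    for (array <- values.asScala) {
      out.writeInt(array.length)   // length prefix
      out.write(array)             // pickled payload
    }
    out.flush()
    val byteRead = socket.getInputStream.read()  // wait for the server's acknowledgement
    if (byteRead == -1) {
      throw new RuntimeException("EOF reached before Python server acknowledged")
    }
  }
}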