From 90d5754212425d55f992c939a2bc7d9ac6ef92b8 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Fri, 23 Sep 2016 09:44:30 +0100
Subject: [SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top
 of Accumulator V2

## What changes were proposed in this pull request?

Move the internals of the PySpark accumulator API off the old, deprecated API and onto the new accumulator API.

## How was this patch tested?

The existing PySpark accumulator tests (both the unit tests and the doc tests at the start of accumulator.py).

Author: Holden Karau

Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api.
---
 .../org/apache/spark/api/python/PythonRDD.scala | 42 ++++++++++++----------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d841091a31..0ca91b9bf8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
 
 
 private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    accumulator: PythonAccumulatorV2)
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
                 val updateLen = stream.readInt()
                 val update = new Array[Byte](updateLen)
                 stream.readFully(update)
-                accumulator += Collections.singletonList(update)
+                accumulator.add(update)
               }
               // Check whether the worker is ready to be re-used.
               if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
       JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     try {
-      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
           val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
 }
 
 /**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
  * collects a list of pickled strings that we pass to Python through a socket.
  */
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
-  extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+    @transient private val serverHost: String,
+    private val serverPort: Int)
+  extends CollectionAccumulator[Array[Byte]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
    * We try to reuse a single Socket to transfer accumulator updates, as they are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient var socket: Socket = _
+  @transient private var socket: Socket = _
 
-  def openSocket(): Socket = synchronized {
+  private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
     }
     socket
   }
 
-  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+  // Need to override so the types match with PythonFunction
+  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
 
-  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = synchronized {
+  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+    val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+    // This conditional isn't strictly speaking needed - merging currently only happens on the
+    // driver program - but that isn't guaranteed, so keep it in case this changes.
     if (serverHost == null) {
-      // This happens on the worker node, where we just want to remember all the updates
-      val1.addAll(val2)
-      val1
+      // We are on the worker
+      super.merge(otherPythonAccumulator)
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
-      out.writeInt(val2.size)
-      for (array <- val2.asScala) {
+      val values = other.value
+      out.writeInt(values.size)
+      for (array <- values.asScala) {
         out.writeInt(array.length)
         out.write(array)
       }
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      null
     }
   }
 }
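Note (editorial, not part of the patch): below is a minimal sketch of the AccumulatorV2 contract that the new PythonAccumulatorV2 builds on, using the same CollectionAccumulator base class it extends. The object name, accumulator name, local master setting, and sample data are illustrative assumptions, not anything taken from the patch.

import java.util.{List => JList}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.CollectionAccumulator

object AccumulatorV2Sketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("accumulator-v2-sketch"))

    // CollectionAccumulator[T] simply gathers the elements added by tasks; here byte arrays
    // stand in for the pickled updates that PythonAccumulatorV2 receives from Python workers.
    val acc = new CollectionAccumulator[Array[Byte]]
    sc.register(acc, "pickledUpdates")

    sc.parallelize(Seq("a", "bb", "ccc"), numSlices = 2)
      .foreach(s => acc.add(s.getBytes("UTF-8")))

    // Each task works on a zero-valued copy produced by copyAndReset(); the driver folds the
    // task-side results back in via merge(), which is the hook the patch overrides to forward
    // the collected updates to the Python driver process over a socket.
    val merged: JList[Array[Byte]] = acc.value
    println(s"received ${merged.size()} accumulator updates")

    sc.stop()
  }
}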