aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala31
-rw-r--r--core/src/test/scala/org/apache/spark/util/UtilsSuite.scala37
2 files changed, 43 insertions, 25 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 7e204fa218..1a9dbcae8c 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2344,29 +2344,24 @@ private[spark] class RedirectThread(
* the toString method.
*/
private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream {
- var pos: Int = 0
- var buffer = new Array[Int](sizeInBytes)
+ private var pos: Int = 0
+ private var isBufferFull = false
+ private val buffer = new Array[Byte](sizeInBytes)
- def write(i: Int): Unit = {
- buffer(pos) = i
+ def write(input: Int): Unit = {
+ buffer(pos) = input.toByte
pos = (pos + 1) % buffer.length
+ isBufferFull = isBufferFull || (pos == 0)
}
override def toString: String = {
- val (end, start) = buffer.splitAt(pos)
- val input = new java.io.InputStream {
- val iterator = (start ++ end).iterator
-
- def read(): Int = if (iterator.hasNext) iterator.next() else -1
- }
- val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8))
- val stringBuilder = new StringBuilder
- var line = reader.readLine()
- while (line != null) {
- stringBuilder.append(line)
- stringBuilder.append("\n")
- line = reader.readLine()
+ if (!isBufferFull) {
+ return new String(buffer, 0, pos, StandardCharsets.UTF_8)
}
- stringBuilder.toString()
+
+ val nonCircularBuffer = new Array[Byte](sizeInBytes)
+ System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos)
+ System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos)
+ new String(nonCircularBuffer, StandardCharsets.UTF_8)
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 4aa4854c36..6698749866 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream}
import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
@@ -681,14 +681,37 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(!Utils.isInDirectory(nullFile, childFile3))
}
- test("circular buffer") {
+ test("circular buffer: if nothing was written to the buffer, display nothing") {
+ val buffer = new CircularBuffer(4)
+ assert(buffer.toString === "")
+ }
+
+ test("circular buffer: if the buffer isn't full, print only the contents written") {
+ val buffer = new CircularBuffer(10)
+ val stream = new PrintStream(buffer, true, "UTF-8")
+ stream.print("test")
+ assert(buffer.toString === "test")
+ }
+
+ test("circular buffer: data written == size of the buffer") {
+ val buffer = new CircularBuffer(4)
+ val stream = new PrintStream(buffer, true, "UTF-8")
+
+ // fill the buffer to its exact size so that it just hits overflow
+ stream.print("test")
+ assert(buffer.toString === "test")
+
+ // add more data to the buffer
+ stream.print("12")
+ assert(buffer.toString === "st12")
+ }
+
+ test("circular buffer: multiple overflow") {
val buffer = new CircularBuffer(25)
- val stream = new java.io.PrintStream(buffer, true, "UTF-8")
+ val stream = new PrintStream(buffer, true, "UTF-8")
- // scalastyle:off println
- stream.println("test circular test circular test circular test circular test circular")
- // scalastyle:on println
- assert(buffer.toString === "t circular test circular\n")
+ stream.print("test circular test circular test circular test circular test circular")
+ assert(buffer.toString === "st circular test circular")
}
test("nanSafeCompareDoubles") {