Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/Utils.scala')
-rw-r--r--    core/src/main/scala/org/apache/spark/util/Utils.scala    781
1 files changed, 781 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
new file mode 100644
index 0000000000..bb47fc0a2c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -0,0 +1,781 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io._
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.util.{Locale, Random, UUID}
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.regex.Pattern
+
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
+import scala.io.Source
+
+import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
+
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+import org.apache.spark.deploy.SparkHadoopUtil
+import java.nio.ByteBuffer
+import org.apache.spark.{SparkEnv, SparkException, Logging}
+
+
+/**
+ * Various utility methods used by Spark.
+ */
+private[spark] object Utils extends Logging {
+
+ /** Serialize an object using Java serialization */
+ def serialize[T](o: T): Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ val oos = new ObjectOutputStream(bos)
+ oos.writeObject(o)
+ oos.close()
+ return bos.toByteArray
+ }
+
+ /** Deserialize an object using Java serialization */
+ def deserialize[T](bytes: Array[Byte]): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis)
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ /** Deserialize an object using Java serialization and the given ClassLoader */
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
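+  // Illustrative round trip for the helpers above: deserialize[String](serialize("foo")) == "foo";
+  // the ClassLoader variant behaves the same but resolves classes through the given loader.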
+
+ /** Serialize via nested stream using specific serializer */
+ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
+ val osWrapper = ser.serializeStream(new OutputStream {
+ def write(b: Int) = os.write(b)
+
+ override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len)
+ })
+ try {
+ f(osWrapper)
+ } finally {
+ osWrapper.close()
+ }
+ }
+
+ /** Deserialize via nested stream using specific serializer */
+ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = {
+ val isWrapper = ser.deserializeStream(new InputStream {
+ def read(): Int = is.read()
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len)
+ })
+ try {
+ f(isWrapper)
+ } finally {
+ isWrapper.close()
+ }
+ }
+
+ /**
+ * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
+ */
+ def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
+ if (bb.hasArray) {
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ } else {
+ val bbval = new Array[Byte](bb.remaining())
+ bb.get(bbval)
+ out.write(bbval)
+ }
+ }
+
+ def isAlpha(c: Char): Boolean = {
+ (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+ }
+
+ /** Split a string into words at non-alphabetic characters */
+ def splitWords(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var i = 0
+ while (i < s.length) {
+ var j = i
+ while (j < s.length && isAlpha(s.charAt(j))) {
+ j += 1
+ }
+ if (j > i) {
+ buf += s.substring(i, j)
+ }
+ i = j
+ while (i < s.length && !isAlpha(s.charAt(i))) {
+ i += 1
+ }
+ }
+ return buf
+ }
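+  // For example, splitWords("spark-0.8, utils!") == Seq("spark", "utils"); digits and punctuation
+  // both act as separators because only ASCII letters count as alphabetic here.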
+
+ private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+  // Is the path already registered to be deleted via a shutdown hook?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+  // Note: if file is a child of some registered path, while not equal to it, then return true;
+  // else false. This is to ensure that two shutdown hooks do not try to delete each other's
+  // paths, which would result in IOExceptions and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }.isDefined
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
+ /** Create a temporary directory inside the given parent directory */
+ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
+ var attempts = 0
+ val maxAttempts = 10
+ var dir: File = null
+ while (dir == null) {
+ attempts += 1
+ if (attempts > maxAttempts) {
+ throw new IOException("Failed to create a temp directory (under " + root + ") after " +
+ maxAttempts + " attempts!")
+ }
+ try {
+ dir = new File(root, "spark-" + UUID.randomUUID.toString)
+ if (dir.exists() || !dir.mkdirs()) {
+ dir = null
+ }
+ } catch { case e: IOException => ; }
+ }
+
+ registerShutdownDeleteDir(dir)
+
+ // Add a shutdown hook to delete the temp dir when the JVM exits
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
+ override def run() {
+        // Attempt to delete only if no registered parent path already covers this directory.
+        if (!hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
+ }
+ })
+ dir
+ }
+
+ /** Copy all data from an InputStream to an OutputStream */
+ def copyStream(in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false)
+ {
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = in.read(buf)
+ if (n != -1) {
+ out.write(buf, 0, n)
+ }
+ }
+ if (closeStreams) {
+ in.close()
+ out.close()
+ }
+ }
+
+ /**
+ * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val tempDir = getLocalDir
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
+ val targetFile = new File(targetDir, filename)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + tempFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ case "file" | null =>
+ // In the case of a local file, copy the local file to the target directory.
+ // Note the difference between uri vs url.
+ val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
+ if (targetFile.exists) {
+          // If the target file already exists, fail unless its contents match the source file.
+ if (!Files.equal(sourceFile, targetFile)) {
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ // Do nothing if the file contents are the same, i.e. this file has been copied
+ // previously.
+ logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
+ + targetFile.getAbsolutePath)
+ }
+ } else {
+ // The file does not exist in the target directory. Copy it there.
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ Files.copy(sourceFile, targetFile)
+ }
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val env = SparkEnv.get
+ val uri = new URI(url)
+ val conf = env.hadoop.newConfiguration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ }
+    // Decompress the file if it's a .tar, .tar.gz or .tgz archive
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+    // Make the file executable; that's necessary for scripts
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
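+  // Example (the URL is illustrative): fetchFile("hdfs://namenode:8020/jars/app.jar", new File("."))
+  // downloads app.jar into the current directory and marks it executable; if a file with the same
+  // name already exists there with different contents, a SparkException is thrown instead.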
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ }
+
+ /**
+ * Shuffle the elements of a collection into a random order, returning the
+ * result in a new collection. Unlike scala.util.Random.shuffle, this method
+ * uses a local random number generator, avoiding inter-thread contention.
+ */
+ def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ randomizeInPlace(seq.toArray)
+ }
+
+ /**
+ * Shuffle the elements of an array into a random order, modifying the
+ * original array. Returns the original array.
+ */
+ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+ for (i <- (arr.length - 1) to 1 by -1) {
+      // Pick j uniformly from [0, i] so the shuffle is unbiased (Fisher-Yates).
+      val j = rand.nextInt(i + 1)
+ val tmp = arr(j)
+ arr(j) = arr(i)
+ arr(i) = tmp
+ }
+ arr
+ }
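+  // For example, randomize(1 to 10) returns the same ten elements in random order, while
+  // randomizeInPlace shuffles the passed-in array directly and returns it.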
+
+ /**
+ * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+   * Note that this is typically not used from within core Spark.
+ */
+ lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
+
+ private def findLocalIpAddress(): String = {
+ val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
+ if (defaultIpOverride != null) {
+ defaultIpOverride
+ } else {
+ val address = InetAddress.getLocalHost
+ if (address.isLoopbackAddress) {
+ // Address resolves to something like 127.0.1.1, which happens on Debian; try to find
+ // a better address using the local network interfaces
+ for (ni <- NetworkInterface.getNetworkInterfaces) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
+ // We've found an address that looks reasonable!
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
+ " instead (on interface " + ni.getName + ")")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ return addr.getHostAddress
+ }
+ }
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + ", but we couldn't find any" +
+ " external IP address!")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ }
+ address.getHostAddress
+ }
+ }
+
+ private var customHostname: Option[String] = None
+
+ /**
+ * Allow setting a custom host name because when we run on Mesos we need to use the same
+ * hostname it reports to the master.
+ */
+ def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
+ customHostname = Some(hostname)
+ }
+
+ /**
+ * Get the local machine's hostname.
+ */
+ def localHostName(): String = {
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+ def checkHost(host: String, message: String = "") {
+ assert(host.indexOf(':') == -1, message)
+ }
+
+ def checkHostPort(hostPort: String, message: String = "") {
+ assert(hostPort.indexOf(':') != -1, message)
+ }
+
+  // Used by DEBUG code: remove when all testing is done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ }
+
+  // Typically, this will be on the order of the number of nodes in the cluster.
+  // If not, we should change it to an LRU cache or something similar.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
+ var cached = hostPortParseResults.get(hostPort)
+ if (cached != null) return cached
+ }
+
+ val indx: Int = hostPort.lastIndexOf(':')
+    // This is potentially broken when dealing with IPv6 addresses, for example, but then
+    // Hadoop does not support IPv6 right now.
+    // For now, we assume that if a port exists it is valid; we do not check that it is an int > 0.
+ if (-1 == indx) {
+ val retval = (hostPort, 0)
+ hostPortParseResults.put(hostPort, retval)
+ return retval
+ }
+
+ val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
+ hostPortParseResults.putIfAbsent(hostPort, retval)
+ hostPortParseResults.get(hostPort)
+ }
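+  // For example, parseHostPort("host.example.com:7077") == ("host.example.com", 7077), while a
+  // bare "host.example.com" yields ("host.example.com", 0); results are cached per input string.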
+
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
+
+ /**
+ * Wrapper over newCachedThreadPool.
+ */
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+
+ /**
+   * Return a string reporting how much time has passed, in milliseconds, since the given
+   * start time (which should also be in milliseconds).
+ */
+ def getUsedTimeMs(startTimeMs: Long): String = {
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
+ }
+
+ /**
+ * Wrapper over newFixedThreadPool.
+ */
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+
+ /**
+ * Delete a file or directory and its contents recursively.
+ */
+ def deleteRecursively(file: File) {
+ if (file.isDirectory) {
+ for (child <- file.listFiles()) {
+ deleteRecursively(child)
+ }
+ }
+ if (!file.delete()) {
+ throw new IOException("Failed to delete: " + file)
+ }
+ }
+
+ /**
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * environment variable.
+ */
+ def memoryStringToMb(str: String): Int = {
+ val lower = str.toLowerCase
+ if (lower.endsWith("k")) {
+ (lower.substring(0, lower.length-1).toLong / 1024).toInt
+ } else if (lower.endsWith("m")) {
+ lower.substring(0, lower.length-1).toInt
+ } else if (lower.endsWith("g")) {
+ lower.substring(0, lower.length-1).toInt * 1024
+ } else if (lower.endsWith("t")) {
+ lower.substring(0, lower.length-1).toInt * 1024 * 1024
+    } else { // no suffix, so it's just a number in bytes
+ (lower.toLong / 1024 / 1024).toInt
+ }
+ }
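+  // For example, memoryStringToMb("300m") == 300 and memoryStringToMb("1g") == 1024, while a
+  // plain number such as "268435456" is read as bytes and yields 256.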
+
+ /**
+ * Convert a quantity in bytes to a human-readable string such as "4.0 MB".
+ */
+ def bytesToString(size: Long): String = {
+ val TB = 1L << 40
+ val GB = 1L << 30
+ val MB = 1L << 20
+ val KB = 1L << 10
+
+ val (value, unit) = {
+ if (size >= 2*TB) {
+ (size.asInstanceOf[Double] / TB, "TB")
+ } else if (size >= 2*GB) {
+ (size.asInstanceOf[Double] / GB, "GB")
+ } else if (size >= 2*MB) {
+ (size.asInstanceOf[Double] / MB, "MB")
+ } else if (size >= 2*KB) {
+ (size.asInstanceOf[Double] / KB, "KB")
+ } else {
+ (size.asInstanceOf[Double], "B")
+ }
+ }
+ "%.1f %s".formatLocal(Locale.US, value, unit)
+ }
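+  // For example, bytesToString(5L * 1024 * 1024) == "5.0 MB", but bytesToString(1L << 20) ==
+  // "1024.0 KB": a unit is only used once the value reaches twice that unit.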
+
+ /**
+   * Returns a human-readable string representing a duration such as "35 ms".
+ */
+ def msDurationToString(ms: Long): String = {
+ val second = 1000
+ val minute = 60 * second
+ val hour = 60 * minute
+
+ ms match {
+ case t if t < second =>
+ "%d ms".format(t)
+ case t if t < minute =>
+ "%.1f s".format(t.toFloat / second)
+ case t if t < hour =>
+ "%.1f m".format(t.toFloat / minute)
+ case t =>
+ "%.2f h".format(t.toFloat / hour)
+ }
+ }
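+  // For example, msDurationToString(35) == "35 ms", msDurationToString(1500) == "1.5 s",
+  // msDurationToString(90 * 1000) == "1.5 m" and msDurationToString(5400 * 1000) == "1.50 h".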
+
+ /**
+ * Convert a quantity in megabytes to a human-readable string such as "4.0 MB".
+ */
+ def megabytesToString(megabytes: Long): String = {
+ bytesToString(megabytes * 1024L * 1024L)
+ }
+
+ /**
+ * Execute a command in the given working directory, throwing an exception if it completes
+ * with an exit code other than 0.
+ */
+ def execute(command: Seq[String], workingDir: File) {
+ val process = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ .redirectErrorStream(true)
+ .start()
+ new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val exitCode = process.waitFor()
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ }
+
+ /**
+ * Execute a command in the current working directory, throwing an exception if it completes
+ * with an exit code other than 0.
+ */
+ def execute(command: Seq[String]) {
+ execute(command, new File("."))
+ }
+
+ /**
+ * Execute a command and get its output, throwing an exception if it yields a code other than 0.
+ */
+ def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."),
+ extraEnvironment: Map[String, String] = Map.empty): String = {
+ val builder = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ val environment = builder.environment()
+ for ((key, value) <- extraEnvironment) {
+ environment.put(key, value)
+ }
+ val process = builder.start()
+ new Thread("read stderr for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val output = new StringBuffer
+ val stdoutThread = new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ output.append(line)
+ }
+ }
+ }
+ stdoutThread.start()
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ output.toString
+ }
+
+ /**
+ * A regular expression to match classes of the "core" Spark API that we want to skip when
+ * finding the call site of a method.
+ */
+  private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r
+
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
+
+ /**
+ * When called inside a class in the spark package, returns the name of the user code class
+ * (outside the spark package) that called into Spark, as well as which Spark method they called.
+ * This is used, for example, to tell users where in their code each RDD got created.
+ */
+ def getCallSiteInfo: CallSiteInfo = {
+ val trace = Thread.currentThread.getStackTrace().filter( el =>
+ (!el.getMethodName.contains("getStackTrace")))
+
+ // Keep crawling up the stack trace until we find the first function not inside of the spark
+ // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
+ // transformation, a SparkContext function (such as parallelize), or anything else that leads
+ // to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
+ var lastSparkMethod = "<unknown>"
+ var firstUserFile = "<unknown>"
+ var firstUserLine = 0
+ var finished = false
+ var firstUserClass = "<unknown>"
+
+ for (el <- trace) {
+ if (!finished) {
+ if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) {
+ lastSparkMethod = if (el.getMethodName == "<init>") {
+ // Spark method is a constructor; get its class name
+ el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
+ } else {
+ el.getMethodName
+ }
+ }
+ else {
+ firstUserLine = el.getLineNumber
+ firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
+ finished = true
+ }
+ }
+ }
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
+ }
+
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
+
+ /** Return a string containing part of a file from byte 'start' to 'end'. */
+ def offsetBytes(path: String, start: Long, end: Long): String = {
+ val file = new File(path)
+ val length = file.length()
+ val effectiveEnd = math.min(length, end)
+ val effectiveStart = math.max(0, start)
+ val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt)
+ val stream = new FileInputStream(file)
+
+ stream.skip(effectiveStart)
+ stream.read(buff)
+ stream.close()
+ Source.fromBytes(buff).mkString
+ }
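+  // For example, offsetBytes("/tmp/app.log", 0, 1024) returns (at most) the first kilobyte of the
+  // file as a string; the path here is purely illustrative.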
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
+
+ /**
+ * Detect whether this thread might be executing a shutdown hook. Will always return true if
+   * the current thread is running a shutdown hook, but may spuriously return true otherwise (e.g.
+ * if System.exit was just called by a concurrent thread).
+ *
+ * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing
+ * an IllegalStateException.
+ */
+ def inShutdown(): Boolean = {
+ try {
+ val hook = new Thread {
+ override def run() {}
+ }
+ Runtime.getRuntime.addShutdownHook(hook)
+ Runtime.getRuntime.removeShutdownHook(hook)
+ } catch {
+ case ise: IllegalStateException => return true
+ }
+ return false
+ }
+
+ def isSpace(c: Char): Boolean = {
+ " \t\r\n".indexOf(c) != -1
+ }
+
+ /**
+ * Split a string of potentially quoted arguments from the command line the way that a shell
+ * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
+ * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
+ */
+ def splitCommandString(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var inWord = false
+ var inSingleQuote = false
+ var inDoubleQuote = false
+ var curWord = new StringBuilder
+ def endWord() {
+ buf += curWord.toString
+ curWord.clear()
+ }
+ var i = 0
+ while (i < s.length) {
+      val nextChar = s.charAt(i)
+ if (inDoubleQuote) {
+ if (nextChar == '"') {
+ inDoubleQuote = false
+ } else if (nextChar == '\\') {
+ if (i < s.length - 1) {
+ // Append the next character directly, because only " and \ may be escaped in
+ // double quotes after the shell's own expansion
+ curWord.append(s.charAt(i + 1))
+ i += 1
+ }
+ } else {
+ curWord.append(nextChar)
+ }
+ } else if (inSingleQuote) {
+ if (nextChar == '\'') {
+ inSingleQuote = false
+ } else {
+ curWord.append(nextChar)
+ }
+ // Backslashes are not treated specially in single quotes
+ } else if (nextChar == '"') {
+ inWord = true
+ inDoubleQuote = true
+ } else if (nextChar == '\'') {
+ inWord = true
+ inSingleQuote = true
+ } else if (!isSpace(nextChar)) {
+ curWord.append(nextChar)
+ inWord = true
+ } else if (inWord && isSpace(nextChar)) {
+ endWord()
+ inWord = false
+ }
+ i += 1
+ }
+ if (inWord || inDoubleQuote || inSingleQuote) {
+ endWord()
+ }
+ return buf
+ }
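+  // For example, splitCommandString("a \"b c\" d") == Seq("a", "b c", "d") and
+  // splitCommandString("arg1 'two words'") == Seq("arg1", "two words").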
+
+  /* Calculates 'x' modulo 'mod', taking the sign of x into consideration: if 'x' is negative,
+   * then 'x' % 'mod' is negative too, so the function returns (x % mod) + mod in that case.
+ */
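+  // For example, nonNegativeMod(7, 3) == 1 and nonNegativeMod(-7, 3) == 2 (since -7 % 3 == -1).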
+ def nonNegativeMod(x: Int, mod: Int): Int = {
+ val rawMod = x % mod
+ rawMod + (if (rawMod < 0) mod else 0)
+ }
+}