aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2013-11-25 00:57:51 -0800
committerReynold Xin <rxin@apache.org>2013-11-25 00:57:51 -0800
commit088995f917548ac549397f24916ea72b0c3fc9d0 (patch)
tree71ea8fb0e790254228129053007effd5bf6721a6 /core
parent6af03edcf15e517f69598c4e974cca69b63904fa (diff)
parent6bcac986b20477fcb8cc011ecff19f482e033794 (diff)
downloadspark-088995f917548ac549397f24916ea72b0c3fc9d0.tar.gz
spark-088995f917548ac549397f24916ea72b0c3fc9d0.tar.bz2
spark-088995f917548ac549397f24916ea72b0c3fc9d0.zip
Merge pull request #77 from amplab/upgrade
Sync with Spark master
Diffstat (limited to 'core')
-rw-r--r--core/pom.xml4
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala82
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala117
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala664
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala36
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockInfo.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala93
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala94
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala46
-rw-r--r--core/src/test/scala/org/apache/spark/LocalSparkContext.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala45
-rw-r--r--core/src/test/scala/org/apache/spark/PartitioningSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala86
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala76
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala117
43 files changed, 1229 insertions, 617 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 8621d257e5..6af229c71d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -159,6 +159,10 @@
<artifactId>metrics-ganglia</artifactId>
</dependency>
<dependency>
+ <groupId>com.codahale.metrics</groupId>
+ <artifactId>metrics-graphite</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.apache.derby</groupId>
<artifactId>derby</artifactId>
<scope>test</scope>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 880b49e8ef..42b2985b50 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -145,6 +144,14 @@ class SparkContext(
executorEnvs ++= environment
}
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
+ }
+ executorEnvs("SPARK_USER") = sparkUser
+
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -238,7 +245,6 @@ class SparkContext(
taskScheduler.start()
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
- dagScheduler.start()
ui.start()
@@ -272,6 +278,12 @@ class SparkContext(
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ private[spark] def getLocalProperties(): Properties = localProperties.get()
+
+ private[spark] def setLocalProperties(props: Properties) {
+ localProperties.set(props)
+ }
+
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -293,7 +305,7 @@ class SparkContext(
/** Set a human readable description of the current job. */
@deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setJobGroup("", value)
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
/**
@@ -796,11 +808,10 @@ class SparkContext(
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -982,6 +993,8 @@ object SparkContext {
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 668032a3a2..0aa8852649 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -1,19 +1,19 @@
/*
*
- * * Licensed to the Apache Software Foundation (ASF) under one or more
- * * contributor license agreements. See the NOTICE file distributed with
- * * this work for additional information regarding copyright ownership.
- * * The ASF licenses this file to You under the Apache License, Version 2.0
- * * (the "License"); you may not use this file except in compliance with
- * * the License. You may obtain a copy of the License at
- * *
- * * http://www.apache.org/licenses/LICENSE-2.0
- * *
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS,
- * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * * See the License for the specific language governing permissions and
- * * limitations under the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 308a2bfa22..a724900943 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -17,12 +17,12 @@
package org.apache.spark.deploy
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor.ActorSystem
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
-import org.apache.spark.util.{Utils, AkkaUtils}
-import org.apache.spark.{Logging}
+import org.apache.spark.util.Utils
+import org.apache.spark.Logging
import scala.collection.mutable.ArrayBuffer
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 6bc846aa92..fc1537f796 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,16 +17,39 @@
package org.apache.spark.deploy
+import java.security.PrivilegedExceptionAction
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkContext, SparkException}
/**
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
+
+ def runAsUser(user: String)(func: () => Unit) {
+ // if we are already running as the user intended there is no reason to do the doAs. It
+ // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
+ // the user is UNKNOWN then we shouldn't be creating a remote unknown user
+ // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
+ // in SparkContext.
+ val currentUser = Option(System.getProperty("user.name")).
+ getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+ if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ } else {
+ func()
+ }
+ }
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
@@ -42,9 +65,9 @@ class SparkHadoopUtil {
def isYarnMode(): Boolean = { false }
}
-
+
object SparkHadoopUtil {
- private val hadoop = {
+ private val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
@@ -56,7 +79,7 @@ object SparkHadoopUtil {
new SparkHadoopUtil
}
}
-
+
def get: SparkHadoopUtil = {
hadoop
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 8fabc95665..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -104,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 80ff4c59cb..8332631838 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor.{ActorRef, Actor, Props, Terminated}
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
-import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.Logging
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
@@ -111,7 +111,7 @@ private[spark] object CoarseGrainedExecutorBackend {
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ //the reason we allow the last appid argument is to make it easy to kill rogue executors
System.err.println(
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
"[<appid>]")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b773346df3..5c9bb9db1c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,8 +25,9 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
@@ -129,6 +130,8 @@ private[spark] class Executor(
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, serializedTask)
runningTasks.put(taskId, tr)
@@ -176,7 +179,7 @@ private[spark] class Executor(
}
}
- override def run() {
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index 34ed9c8f73..97176e4f5b 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -20,8 +20,6 @@ package org.apache.spark.executor
import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.hdfs.DistributedFileSystem
-import org.apache.hadoop.fs.LocalFileSystem
import scala.collection.JavaConversions._
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
new file mode 100644
index 0000000000..cdcfec8ca7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+import java.net.InetSocketAddress
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.graphite.{GraphiteReporter, Graphite}
+
+import org.apache.spark.metrics.MetricsSystem
+
+class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val GRAPHITE_DEFAULT_PERIOD = 10
+ val GRAPHITE_DEFAULT_UNIT = "SECONDS"
+ val GRAPHITE_DEFAULT_PREFIX = ""
+
+ val GRAPHITE_KEY_HOST = "host"
+ val GRAPHITE_KEY_PORT = "port"
+ val GRAPHITE_KEY_PERIOD = "period"
+ val GRAPHITE_KEY_UNIT = "unit"
+ val GRAPHITE_KEY_PREFIX = "prefix"
+
+ def propertyToOption(prop: String) = Option(property.getProperty(prop))
+
+ if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
+ throw new Exception("Graphite sink requires 'host' property.")
+ }
+
+ if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) {
+ throw new Exception("Graphite sink requires 'port' property.")
+ }
+
+ val host = propertyToOption(GRAPHITE_KEY_HOST).get
+ val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt
+
+ val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match {
+ case Some(s) => s.toInt
+ case None => GRAPHITE_DEFAULT_PERIOD
+ }
+
+ val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
+ }
+
+ val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX)
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+
+ val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .prefixedWith(prefix)
+ .build(graphite)
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 481ff8c3e0..b1e1576dad 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging {
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
- logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)")
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 9b0c882481..0de22f0e06 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -70,7 +70,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
override def compute(split: Partition, context: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
- y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override def getDependencies: Seq[Dependency[_]] = List(
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 32901a508f..53f77a38f5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -52,7 +52,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
* sources in HBase, or S3).
*
* @param sc The SparkContext to associate the RDD with.
- * @param broadCastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
* variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
* Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
* @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
@@ -132,6 +132,8 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
+ // add the credentials here as this can be called before SparkContext initialized
+ SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index 165cd412fc..574dd4233f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -33,11 +33,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
extends NarrowDependency[T](rdd) {
@transient
- val partitions: Array[Partition] = rdd.partitions.zipWithIndex
- .filter(s => partitionFilterFunc(s._2))
+ val partitions: Array[Partition] = rdd.partitions
+ .filter(s => partitionFilterFunc(s.index)).zipWithIndex
.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
- override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+ override def getParents(partitionId: Int) = {
+ List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4cef0825dd..42bb3884c8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,9 +19,10 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
-import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
+import akka.actor._
+import akka.util.duration._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import org.apache.spark._
@@ -65,12 +66,12 @@ class DAGScheduler(
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- eventQueue.put(BeginEvent(task, taskInfo))
+ eventProcessActor ! BeginEvent(task, taskInfo)
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
- eventQueue.put(GettingResultEvent(task, taskInfo))
+ eventProcessActor ! GettingResultEvent(task, taskInfo)
}
// Called by TaskScheduler to report task completions or failures.
@@ -81,23 +82,23 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+ eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
}
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
- eventQueue.put(ExecutorLost(execId))
+ eventProcessActor ! ExecutorLost(execId)
}
// Called by TaskScheduler when a host is added
def executorGained(execId: String, host: String) {
- eventQueue.put(ExecutorGained(execId, host))
+ eventProcessActor ! ExecutorGained(execId, host)
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
- eventQueue.put(TaskSetFailed(taskSet, reason))
+ eventProcessActor ! TaskSetFailed(taskSet, reason)
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
@@ -109,7 +110,30 @@ class DAGScheduler(
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+ private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
+ override def preStart() {
+ context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) {
+ if (failed.size > 0) {
+ resubmitFailedStages()
+ }
+ }
+ }
+
+ /**
+ * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+ * events and responds by launching tasks. This runs in a dedicated thread and receives events
+ * via the eventQueue.
+ */
+ def receive = {
+ case event: DAGSchedulerEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+
+ if (!processEvent(event))
+ submitWaitingStages()
+ else
+ context.stop(self)
+ }
+ }))
private[scheduler] val nextJobId = new AtomicInteger(0)
@@ -150,16 +174,6 @@ class DAGScheduler(
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
- // Start a thread to run the DAGScheduler event loop
- def start() {
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
- }
- }.start()
- }
-
def addSparkListener(listener: SparkListener) {
listenerBus.addListener(listener)
}
@@ -301,8 +315,7 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
- waiter, properties))
+ eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
}
@@ -337,8 +350,7 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
- listener, properties))
+ eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -347,19 +359,19 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
- eventQueue.put(JobCancelled(jobId))
+ eventProcessActor ! JobCancelled(jobId)
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
- eventQueue.put(JobGroupCancelled(groupId))
+ eventProcessActor ! JobGroupCancelled(groupId)
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
- eventQueue.put(AllJobsCancelled)
+ eventProcessActor ! AllJobsCancelled
}
/**
@@ -417,15 +429,14 @@ class DAGScheduler(
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
- case gettingResult: GettingResultEvent =>
- listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -475,42 +486,6 @@ class DAGScheduler(
}
}
-
- /**
- * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
- * events and responds by launching tasks. This runs in a dedicated thread and receives events
- * via the eventQueue.
- */
- private def run() {
- SparkEnv.set(env)
-
- while (true) {
- val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- if (event != null) {
- logDebug("Got event of type " + event.getClass.getName)
- }
- this.synchronized { // needed in case other threads makes calls into methods of this class
- if (event != null) {
- if (processEvent(event)) {
- return
- }
- }
-
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
- // Periodically resubmit failed stages if some map output fetches have failed and we have
- // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
- // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
- // the same time, so we want to make sure we've identified all the reduce tasks that depend
- // on the failed node.
- if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- resubmitFailedStages()
- } else {
- submitWaitingStages()
- }
- }
- }
- }
-
/**
* Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
@@ -879,7 +854,7 @@ class DAGScheduler(
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
// but this will do for now.
- rdd.dependencies.foreach(_ match {
+ rdd.dependencies.foreach {
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
@@ -887,7 +862,7 @@ class DAGScheduler(
return locs
}
case _ =>
- })
+ }
Nil
}
@@ -910,7 +885,7 @@ class DAGScheduler(
}
def stop() {
- eventQueue.put(StopDAGScheduler)
+ eventProcessActor ! StopDAGScheduler
metadataCleaner.cancel()
taskSched.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 12b0d74fb5..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,280 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{HashMap, ListBuffer}
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * A logger class to record runtime information for jobs in Spark. This class outputs one log file
- * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
- *
- * @param logDirName The base directory for the log files.
- */
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
-
- private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
-
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- // The following 5 functions are used only in testing.
- private[scheduler] def getLogDir = logDir
- private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
- private[scheduler] def getStageIDToJobID = stageIDToJobID
- private[scheduler] def getJobIDToStages = jobIDToStages
- private[scheduler] def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (!dir.exists() && !dir.mkdirs()) {
- logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try {
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- val stageInfo = if (stage.isShuffleMap) {
- "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
- } else {
- "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else {
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics = taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics = taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
- stageCompleted.stage.stageId))
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
+ * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
+ * after the SparkContext is created.
+ * Note that each JobLogger only works for one SparkContext
+ * @param logDirName The base directory for the log files.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+ /** Create a folder for log files, the folder's name is the creation time of jobLogger */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+ // JobLogger should throw a exception rather than continue to construct this object.
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+ * @exception FileNotFoundException Fail to create log file
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close log file, and clean the stage relationship in stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+ * Write info into log file
+ * @param jobID ID of the job
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " +info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+ * Write info into log file
+ * @param stageID ID of the stage
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build stage dependency for a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record stage dependency and RDD dependency for a stage
+ * @param jobID Job ID of the stage
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate indents and convert to String
+ * @param indent Number of indents
+ * @return string of indents
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record stage dependency graph of a job
+ * @param jobID Job ID of the stage
+ * @param stage Root stage of the job
+ * @param indent Indent number before info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+ * When stage is submitted, record stage submit info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+ * When stage is completed, record stage completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When task ends, record task completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+ * When job ends, recording job completion status and close log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if(properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+ * When job starts, record job property and stage graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 1dc71a0428..0f2deb4bcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
+ writer.close()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
@@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask(
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (shuffle != null) {
- shuffle.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
- shuffle.writers.foreach(_.close())
shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 85033958ef..c1e65a3c48 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import akka.util.duration._
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
@@ -119,21 +121,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.start()
if (System.getProperty("spark.speculation", "false").toBoolean) {
- new Thread("ClusterScheduler speculation check") {
- setDaemon(true)
-
- override def run() {
- logInfo("Starting speculative execution thread")
- while (true) {
- try {
- Thread.sleep(SPECULATION_INTERVAL)
- } catch {
- case e: InterruptedException => {}
- }
- checkSpeculatableTasks()
- }
- }
- }.start()
+ logInfo("Starting speculative execution thread")
+
+ sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
+ SPECULATION_INTERVAL milliseconds) {
+ checkSpeculatableTasks()
+ }
}
}
@@ -256,7 +249,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
- var taskFailed = false
synchronized {
try {
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
@@ -276,9 +268,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
taskIdToExecutorId.remove(tid)
}
- if (state == TaskState.FAILED) {
- taskFailed = true
- }
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
@@ -300,10 +289,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
- if (taskFailed) {
- // Also revive offers if a task had failed for some reason other than host lost
- backend.reviveOffers()
- }
}
def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
@@ -323,8 +308,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
- if (taskState == TaskState.FINISHED) {
- // The task finished successfully but the result was lost, so we should revive offers.
+ if (taskState != TaskState.KILLED) {
+ // Need to revive offers again now that the task set manager state has been updated to
+ // reflect failed tasks that need to be re-run.
backend.reviveOffers()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index ee47aaffca..4c5eca8537 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler.cluster
+import java.io.NotSerializableException
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
@@ -484,6 +485,14 @@ private[spark] class ClusterTaskSetManager(
case ef: ExceptionFailure =>
sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ if (ef.className == classOf[NotSerializableException].getName()) {
+ // If the task result wasn't serializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+ taskSet.id, index, ef.description))
+ abort("Task %s:%s had a not serializable result: %s".format(
+ taskSet.id, index, ef.description))
+ return
+ }
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 70f3f88401..d0ba5bf55d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -87,8 +87,14 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
@@ -175,7 +181,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val timeout = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ }
def stopExecutors() {
try {
@@ -191,6 +199,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
}
override def stop() {
+ stopExecutors()
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index d78bdbaa7a..e000531a26 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -45,11 +45,13 @@ private[spark] class SimrSchedulerBackend(
logInfo("Writing to HDFS file: " + driverFilePath)
logInfo("Writing Akka address: " + driverUrl)
+ logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
// Create temporary file to prevent race condition where executors get empty driverUrl file
val temp = fs.create(tmpPath, true)
temp.writeUTF(driverUrl)
temp.writeInt(maxCores)
+ temp.writeUTF(sc.ui.appUIAddress)
temp.close()
// "Atomic" rename
@@ -60,7 +62,6 @@ private[spark] class SimrSchedulerBackend(
val conf = new Configuration()
val fs = FileSystem.get(conf)
fs.delete(new Path(driverFilePath), false)
- super.stopExecutors()
super.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 300fe693f1..cd521e0f2b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -181,6 +181,7 @@ private[spark] class CoarseMesosSchedulerBackend(
!slaveIdsWithExecutors.contains(slaveId)) {
// Launch an executor on the slave
val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired)
+ totalCoresAcquired += cpusToUse
val taskId = newMesosTaskId()
taskIdToSlaveId(taskId) = slaveId
slaveIdsWithExecutors += slaveId
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 55b25f145a..e748c2275d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage._
/**
- * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
+ * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
+ private val bufferSize = {
+ System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ }
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
- val blockId = TestBlockId("1")
- // Register some commonly used classes
- val toRegister: Seq[AnyRef] = Seq(
- ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY,
- PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock(blockId, ByteBuffer.allocate(1)),
- GetBlock(blockId),
- 1 to 10,
- 1 until 10,
- 1L to 10L,
- 1L until 10L
- )
-
- for (obj <- toRegister) kryo.register(obj.getClass)
+ // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
+ // Do this before we invoke the user registrator so the user registrator can override this.
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+
+ for (cls <- KryoSerializer.toRegister) kryo.register(cls)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
@@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
new AllScalaRegistrar().apply(kryo)
kryo.setClassLoader(classLoader)
-
- // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
-
kryo
}
@@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
trait KryoRegistrator {
def registerClasses(kryo: Kryo)
}
+
+private[serializer] object KryoSerializer {
+ // Commonly used classes.
+ private val toRegister: Seq[Class[_]] = Seq(
+ ByteBuffer.allocate(1).getClass,
+ classOf[StorageLevel],
+ classOf[PutBlock],
+ classOf[GotBlock],
+ classOf[GetBlock],
+ classOf[MapStatus],
+ classOf[BlockManagerId],
+ classOf[Array[Byte]],
+ (1 to 10).getClass,
+ (1 until 10).getClass,
+ (1L to 10L).getClass,
+ (1L until 10L).getClass
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
index dbe0bda615..c8f397609a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -19,9 +19,7 @@ package org.apache.spark.storage
import java.util.concurrent.ConcurrentHashMap
-private[storage] trait BlockInfo {
- def level: StorageLevel
- def tellMaster: Boolean
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
// To save space, 'pending' and 'failed' are encoded as special sizes:
@volatile var size: Long = BlockInfo.BLOCK_PENDING
private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
@@ -81,17 +79,3 @@ private object BlockInfo {
private val BLOCK_PENDING: Long = -1L
private val BLOCK_FAILED: Long = -2L
}
-
-// All shuffle blocks have the same `level` and `tellMaster` properties,
-// so we can save space by not storing them in each instance:
-private[storage] class ShuffleBlockInfo extends BlockInfo {
- // These need to be defined using 'def' instead of 'val' in order for
- // the compiler to eliminate the fields:
- def level: StorageLevel = StorageLevel.DISK_ONLY
- def tellMaster: Boolean = false
-}
-
-private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean)
- extends BlockInfo {
- // Intentionally left blank
-} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index fbedfbc446..702aca8323 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -465,13 +465,7 @@ private[spark] class BlockManager(
def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
- val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
- writer.registerCloseEventHandler(() => {
- val myInfo = new ShuffleBlockInfo()
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.fileSegment().length)
- })
- writer
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
@@ -501,7 +495,7 @@ private[spark] class BlockManager(
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
- val tinfo = new BlockInfoImpl(level, tellMaster)
+ val tinfo = new BlockInfo(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
@@ -897,9 +891,9 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster: BlockManagerMaster = null)
: Map[BlockId, Seq[BlockManagerId]] =
{
- // env == null and blockManagerMaster != null is used in tests
+ // blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
- val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) {
+ val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) {
env.blockManager.getLocationBlockIds(blockIds)
} else {
blockManagerMaster.getLocations(blockIds)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index e49c191c70..b4451fc7b8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -34,20 +34,12 @@ import org.apache.spark.serializer.{SerializationStream, Serializer}
*/
abstract class BlockObjectWriter(val blockId: BlockId) {
- var closeEventHandler: () => Unit = _
-
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -101,6 +93,8 @@ class DiskBlockObjectWriter(
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ override def close() = out.close()
+ override def flush() = out.flush()
}
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
@@ -146,8 +140,6 @@ class DiskBlockObjectWriter(
ts = null
objOut = null
}
- // Invoke the close callback handler.
- super.close()
}
override def isOpen: Boolean = objOut != null
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 42e9be6e19..e596690bc3 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
</tr>
}
- val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b)
+ val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
val execTable = UIUtils.listingTable(execHead, execRow, execInfo)
val content =
@@ -99,16 +99,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors)
}
- def getExecInfo(a: Int): Seq[String] = {
- val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId
- val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort
- val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString
- val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString
- val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString
- val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString
- val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0)
- val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0)
- val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0)
+ def getExecInfo(statusId: Int): Seq[String] = {
+ val status = sc.getExecutorStorageStatus(statusId)
+ val execId = status.blockManagerId.executorId
+ val hostPort = status.blockManagerId.hostPort
+ val rddBlocks = status.blocks.size.toString
+ val memUsed = status.memUsed().toString
+ val maxMem = status.maxMem.toString
+ val diskUsed = status.diskUsed().toString
+ val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size
+ val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
+ val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
val totalTasks = activeTasks + failedTasks + completedTasks
Seq(
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 35b5d5fd59..fbd822867f 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -133,7 +133,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
- <h4>Tasks</h4> ++ taskTable;
+ <h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
}
@@ -152,6 +152,22 @@ private[spark] class StagePage(parent: JobProgressUI) {
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+ var shuffleReadSortable: String = ""
+ var shuffleReadReadable: String = ""
+ if (shuffleRead) {
+ shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString()
+ shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
+ Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")
+ }
+
+ var shuffleWriteSortable: String = ""
+ var shuffleWriteReadable: String = ""
+ if (shuffleWrite) {
+ shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString()
+ shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")
+ }
+
<tr>
<td>{info.index}</td>
<td>{info.taskId}</td>
@@ -166,14 +182,17 @@ private[spark] class StagePage(parent: JobProgressUI) {
{if (gcTime > 0) parent.formatDuration(gcTime) else ""}
</td>
{if (shuffleRead) {
- <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+ <td sorttable_customkey={shuffleReadSortable}>
+ {shuffleReadReadable}
+ </td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}
+ </td>
+ <td sorttable_customkey={shuffleWriteSortable}>
+ {shuffleWriteReadable}
+ </td>
}}
<td>{exception.map(e =>
<span>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index d7d0441c38..9ad6de3c6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -79,11 +79,14 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
case None => "Unknown"
}
- val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
+ val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L)
+ val shuffleRead = shuffleReadSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
+
+ val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L)
+ val shuffleWrite = shuffleWriteSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
@@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
<td class="progress-cell">
{makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)}
</td>
- <td>{shuffleRead}</td>
- <td>{shuffleWrite}</td>
+ <td sorttable_customekey={shuffleReadSortable.toString}>{shuffleRead}</td>
+ <td sorttable_customekey={shuffleWriteSortable.toString}>{shuffleWrite}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index f60deafc6f..8bb4ee3bfa 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
+ private var growThreshold = LOAD_FACTOR * capacity
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
return data(2 * pos + 1).asInstanceOf[V]
} else if (curKey.eq(null)) {
return null.asInstanceOf[V]
@@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
haveNullValue = true
return
}
- val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
- if (isNewEntry) {
- incrementSize()
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = k
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ incrementSize() // Since we added a new key
+ return
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
}
}
@@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
@@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
curSize += 1
- if (curSize > LOAD_FACTOR * capacity) {
+ if (curSize > growThreshold) {
growTable()
}
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
private def rehash(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
-
- /**
- * Put an entry into a table represented by data, returning true if
- * this increases the size of the table or false otherwise. Assumes
- * that "data" has at least one empty slot.
- */
- private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
- val mask = (data.length / 2) - 1
- var pos = rehash(key.hashCode) & mask
- var i = 1
- while (true) {
- val curKey = data(2 * pos)
- if (curKey.eq(null)) {
- data(2 * pos) = key
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return true
- } else if (curKey.eq(key) || curKey == key) {
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return false
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- return false // Never reached but needed to keep compiler happy
+ it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
}
/** Double the table's size and re-hash everything */
@@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
throw new Exception("Can't make capacity bigger than 2^29 elements")
}
val newData = new Array[AnyRef](2 * newCapacity)
- var pos = 0
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- putInto(newData, data(2 * pos), data(2 * pos + 1))
+ val newMask = newCapacity - 1
+ // Insert all our old values into the new array. Note that because our old keys are
+ // unique, there's no need to check for equality here when we insert.
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (!data(2 * oldPos).eq(null)) {
+ val key = data(2 * oldPos)
+ val value = data(2 * oldPos + 1)
+ var newPos = rehash(key.hashCode) & newMask
+ var i = 1
+ var keepGoing = true
+ while (keepGoing) {
+ val curKey = newData(2 * newPos)
+ if (curKey.eq(null)) {
+ newData(2 * newPos) = key
+ newData(2 * newPos + 1) = value
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
data = newData
capacity = newCapacity
- mask = newCapacity - 1
+ mask = newMask
+ growThreshold = LOAD_FACTOR * newCapacity
}
private def nextPowerOf2(n: Int): Int = {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fe932d8ede..a79e64e810 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -823,4 +823,28 @@ private[spark] object Utils extends Logging {
return System.getProperties().clone()
.asInstanceOf[java.util.Properties].toMap[String, String]
}
+
+ /**
+ * Method executed for repeating a task for side effects.
+ * Unlike a for comprehension, it permits JVM JIT optimization
+ */
+ def times(numIters: Int)(f: => Unit): Unit = {
+ var i = 0
+ while (i < numIters) {
+ f
+ i += 1
+ }
+ }
+
+ /**
+ * Timing method based on iterations that permit JVM JIT optimization.
+ * @param numIters number of iterations
+ * @param f function to be executed
+ */
+ def timeIt(numIters: Int)(f: => Unit): Long = {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
new file mode 100644
index 0000000000..e9907e6c85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.{Random => JavaRandom}
+import org.apache.spark.util.Utils.timeIt
+
+/**
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
+private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
+
+ def this() = this(System.nanoTime)
+
+ private var seed = init
+
+ // we need to just override next - this will be called by nextInt, nextDouble,
+ // nextGaussian, nextLong, etc.
+ override protected def next(bits: Int): Int = {
+ var nextSeed = seed ^ (seed << 21)
+ nextSeed ^= (nextSeed >>> 35)
+ nextSeed ^= (nextSeed << 4)
+ seed = nextSeed
+ (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
+ }
+}
+
+/** Contains benchmark method and main method to run benchmark of the RNG */
+private[spark] object XORShiftRandom {
+
+ /**
+ * Main method for running benchmark
+ * @param args takes one argument - the number of random numbers to generate
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+ println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+ System.exit(1)
+ }
+ println(benchmark(args(0).toInt))
+ }
+
+ /**
+ * @param numIters Number of random numbers to generate while running the benchmark
+ * @return Map of execution times for {@link java.util.Random java.util.Random}
+ * and XORShift
+ */
+ def benchmark(numIters: Int) = {
+
+ val seed = 1L
+ val million = 1e6.toInt
+ val javaRand = new JavaRandom(seed)
+ val xorRand = new XORShiftRandom(seed)
+
+ // this is just to warm up the JIT - we're not timing anything
+ timeIt(1e6.toInt) {
+ javaRand.nextInt()
+ xorRand.nextInt()
+ }
+
+ val iters = timeIt(numIters)(_)
+
+ /* Return results as a map instead of just printing to screen
+ in case the user wants to do something with them */
+ Map("javaTime" -> iters {javaRand.nextInt()},
+ "xorTime" -> iters {xorRand.nextInt()})
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
index 369519c559..20554f0aab 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -17,35 +17,51 @@
package org.apache.spark.util.collection
-/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+/**
+ * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types.
+ */
private[spark]
class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
- private var numElements = 0
- private var array: Array[V] = _
+ private var _numElements = 0
+ private var _array: Array[V] = _
// NB: This must be separate from the declaration, otherwise the specialized parent class
- // will get its own array with the same initial size. TODO: Figure out why...
- array = new Array[V](initialSize)
+ // will get its own array with the same initial size.
+ _array = new Array[V](initialSize)
def apply(index: Int): V = {
- require(index < numElements)
- array(index)
+ require(index < _numElements)
+ _array(index)
}
def +=(value: V) {
- if (numElements == array.length) { resize(array.length * 2) }
- array(numElements) = value
- numElements += 1
+ if (_numElements == _array.length) {
+ resize(_array.length * 2)
+ }
+ _array(_numElements) = value
+ _numElements += 1
}
- def length = numElements
+ def capacity: Int = _array.length
+
+ def length: Int = _numElements
+
+ def size: Int = _numElements
+
+ /** Gets the underlying array backing this vector. */
+ def array: Array[V] = _array
- def getUnderlyingArray = array
+ /** Trims this vector so that the capacity is equal to the size. */
+ def trim(): PrimitiveVector[V] = resize(size)
/** Resizes the array, dropping elements if the total length decreases. */
- def resize(newLength: Int) {
+ def resize(newLength: Int): PrimitiveVector[V] = {
val newArray = new Array[V](newLength)
- array.copyToArray(newArray)
- array = newArray
+ _array.copyToArray(newArray)
+ _array = newArray
+ if (newLength < _numElements) {
+ _numElements = newLength
+ }
+ this
}
}
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 459e257d79..8dd5786da6 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
@transient var sc: SparkContext = _
override def beforeAll() {
- InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
super.beforeAll()
}
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
deleted file mode 100644
index 21f16ef2c6..0000000000
--- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
-
-
-class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
-
- test("Pruned Partitions inherit locality prefs correctly") {
- class TestPartition(i: Int) extends Partition {
- def index = i
- }
- val rdd = new RDD[Int](sc, Nil) {
- override protected def getPartitions = {
- Array[Partition](
- new TestPartition(1),
- new TestPartition(2),
- new TestPartition(3))
- }
- def compute(split: Partition, context: TaskContext) = {Iterator()}
- }
- val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
- val p = prunedRDD.partitions(0)
- assert(p.index == 2)
- assert(prunedRDD.partitions.length == 1)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 7d938917f2..1374d01774 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
.filter(_ >= 0.0)
// Run the partitions, including the consecutive empty ones, through StatCounter
- val stats: StatCounter = rdd.stats();
- assert(abs(6.0 - stats.sum) < 0.01);
- assert(abs(6.0/2 - rdd.mean) < 0.01);
- assert(abs(1.0 - rdd.variance) < 0.01);
- assert(abs(1.0 - rdd.stdev) < 0.01);
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0/2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
// Add other tests here for classes that should be able to handle empty partitions correctly
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
new file mode 100644
index 0000000000..8f0954122b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -0,0 +1,19 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+import org.scalatest.FunSuite
+import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription}
+
+class ExecutorRunnerTest extends FunSuite {
+ test("command includes appId") {
+ def f(s:String) = new File(s)
+ val sparkHome = sys.env("SPARK_HOME")
+ val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()),
+ sparkHome, "appUiUrl")
+ val appId = "12345-worker321-9876"
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
+ f("ooga"), ExecutorState.RUNNING)
+
+ assert(er.buildCommandSeq().last === appId)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000..53a7b7c44d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.{TaskContext, Partition, SharedSparkContext}
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+
+ test("Pruned Partitions inherit locality prefs correctly") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 1),
+ new TestPartition(1, 1),
+ new TestPartition(2, 1))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ Iterator()
+ }
+ }
+ val prunedRDD = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+ assert(prunedRDD.partitions.length == 1)
+ val p = prunedRDD.partitions(0)
+ assert(p.index == 0)
+ assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
+ }
+
+
+ test("Pruned Partitions can be unioned ") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 4),
+ new TestPartition(1, 5),
+ new TestPartition(2, 6))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ List(split.asInstanceOf[TestPartition].testValue).iterator
+ }
+ }
+ val prunedRDD1 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 0) true else false
+ })
+
+ val prunedRDD2 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+
+ val merged = prunedRDD1 ++ prunedRDD2
+ assert(merged.count() == 2)
+ val take = merged.take(2)
+ assert(take.apply(0) == 4)
+ assert(take.apply(1) == 6)
+ }
+}
+
+class TestPartition(i: Int, value: Int) extends Partition with Serializable {
+ def index = i
+
+ def testValue = this.value
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 00f2fdd657..a4d41ebbff 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -100,7 +100,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
cacheLocations.clear()
results.clear()
mapOutputTracker = new MapOutputTrackerMaster()
- scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
runLocallyWithinThread(job)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 8406093246..984881861c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -65,7 +65,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val rootStageInfo = new StageInfo(rootStage)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
- joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
@@ -91,8 +91,10 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+
+ val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
- joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index f7f599532a..1fd76420ea 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -83,7 +83,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
i
}
- val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)}
d.count()
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+ def fixture = new {
+ val seed = 1L
+ val xorRand = new XORShiftRandom(seed)
+ val hundMil = 1e8.toInt
+ }
+
+ /*
+ * This test is based on a chi-squared test for randomness. The values are hard-coded
+ * so as not to create Spark's dependency on apache.commons.math3 just to call one
+ * method for calculating the exact p-value for a given number of random numbers
+ * and bins. In case one would want to move to a full-fledged test based on
+ * apache.commons.math3, the relevant class is here:
+ * org.apache.commons.math3.stat.inference.ChiSquareTest
+ */
+ test ("XORShift generates valid random numbers") {
+
+ val f = fixture
+
+ val numBins = 10
+ // create 10 bins
+ val bins = Array.fill(numBins)(0)
+
+ // populate bins based on modulus of the random number
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+ /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+ * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+ * significance level. However, should the RNG implementation change, the test should still
+ * pass at the same significance level. The chi-squared test done in R gave the following
+ * results:
+ * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699))
+ * Chi-squared test for given probabilities
+ * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+ * 10002286, 9998699)
+ * X-squared = 11.975, df = 9, p-value = 0.2147
+ * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+ * random numbers
+ * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+ * is greater than or equal to that number.
+ */
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
new file mode 100644
index 0000000000..970dade628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveVectorSuite extends FunSuite {
+
+ test("primitive value") {
+ val vector = new PrimitiveVector[Int]
+
+ for (i <- 0 until 1000) {
+ vector += i
+ assert(vector(i) === i)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i)
+ }
+ }
+
+ test("non-primitive value") {
+ val vector = new PrimitiveVector[String]
+
+ for (i <- 0 until 1000) {
+ vector += i.toString
+ assert(vector(i) === i.toString)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i.toString)
+ }
+ }
+
+ test("ideal growth") {
+ val vector = new PrimitiveVector[Long](initialSize = 1)
+ vector += 1
+ for (i <- 1 until 1024) {
+ vector += i
+ assert(vector.size === i + 1)
+ assert(vector.capacity === Integer.highestOneBit(i) * 2)
+ }
+ assert(vector.capacity === 1024)
+ vector += 1024
+ assert(vector.capacity === 2048)
+ }
+
+ test("ideal size") {
+ val vector = new PrimitiveVector[Long](8192)
+ for (i <- 0 until 8192) {
+ vector += i
+ }
+ assert(vector.size === 8192)
+ assert(vector.capacity === 8192)
+ val actualSize = SizeEstimator.estimate(vector)
+ val expectedSize = 8192 * 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array.
+ assert(actualSize < expectedSize * 1.1)
+ }
+
+ test("resizing") {
+ val vector = new PrimitiveVector[Long]
+ for (i <- 0 until 4097) {
+ vector += i
+ }
+ assert(vector.size === 4097)
+ assert(vector.capacity === 8192)
+ vector.trim()
+ assert(vector.size === 4097)
+ assert(vector.capacity === 4097)
+ vector.resize(5000)
+ assert(vector.size === 4097)
+ assert(vector.capacity === 5000)
+ vector.resize(4000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 4000)
+ vector.resize(5000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 5000)
+ for (i <- 0 until 4000) {
+ assert(vector(i) == i)
+ }
+ intercept[IllegalArgumentException] {
+ vector(4000)
+ }
+ }
+}