aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-01-23 17:53:15 -0800
committerJosh Rosen <joshrosen@databricks.com>2015-01-23 17:53:15 -0800
commitcef1f092a628ac20709857b4388bb10e0b5143b0 (patch)
tree1a383501a69bdd6335b8bb8db9ca88352b8e6846 /core
parentea74365b7c5a3ac29cae9ba66f140f1fa5e8d312 (diff)
downloadspark-cef1f092a628ac20709857b4388bb10e0b5143b0.tar.gz
spark-cef1f092a628ac20709857b4388bb10e0b5143b0.tar.bz2
spark-cef1f092a628ac20709857b4388bb10e0b5143b0.zip
[SPARK-5063] More helpful error messages for several invalid operations
This patch adds more helpful error messages for invalid programs that define nested RDDs, broadcast RDDs, perform actions inside of transformations (e.g. calling `count()` from inside of `map()`), and call certain methods on stopped SparkContexts. Currently, these invalid programs lead to confusing NullPointerExceptions at runtime and have been a major source of questions on the mailing list and StackOverflow. In a few cases, I chose to log warnings instead of throwing exceptions in order to avoid any chance that this patch breaks programs that worked "by accident" in earlier Spark releases (e.g. programs that define nested RDDs but never run any jobs with them). In SparkContext, the new `assertNotStopped()` method is used to check whether methods are being invoked on a stopped SparkContext. In some cases, user programs will not crash in spite of calling methods on stopped SparkContexts, so I've only added `assertNotStopped()` calls to methods that always throw exceptions when called on stopped contexts (e.g. by dereferencing a null `dagScheduler` pointer). Author: Josh Rosen <joshrosen@databricks.com> Closes #3884 from JoshRosen/SPARK-5063 and squashes the following commits: a38774b [Josh Rosen] Fix spelling typo a943e00 [Josh Rosen] Convert two exceptions into warnings in order to avoid breaking user programs in some edge-cases. 2d0d7f7 [Josh Rosen] Fix test to reflect 1.2.1 compatibility 3f0ea0c [Josh Rosen] Revert two unintentional formatting changes 8e5da69 [Josh Rosen] Remove assertNotStopped() calls for methods that were sometimes safe to call on stopped SC's in Spark 1.2 8cff41a [Josh Rosen] IllegalStateException fix 6ef68d0 [Josh Rosen] Fix Python line length issues. 9f6a0b8 [Josh Rosen] Add improved error messages to PySpark. 13afd0f [Josh Rosen] SparkException -> IllegalStateException 8d404f3 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-5063 b39e041 [Josh Rosen] Fix BroadcastSuite test which broadcasted an RDD 99cc09f [Josh Rosen] Guard against calling methods on stopped SparkContexts. 34833e8 [Josh Rosen] Add more descriptive error message. 57cc8a1 [Josh Rosen] Add error message when directly broadcasting RDD. 15b2e6b [Josh Rosen] [SPARK-5063] Useful error messages for nested RDDs and actions inside of transformations
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala62
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala12
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala40
4 files changed, 119 insertions, 14 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6a354ed4d1..8175d175b1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1468,7 +1505,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 97012c7033..ab7410a1f7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index b0a70f012f..af3272692d 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}
+ test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
+ sc = new SparkContext("local", "test")
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ sc.broadcast(Seq(1, 2, 3))
+ }
+ assert(thrown.getMessage.toLowerCase.contains("stopped"))
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {
def runCallSiteTest(sc: SparkContext) {
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val broadcast = sc.broadcast(rdd)
+ val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 381ee2d456..e33b4bbbb8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}
+
+ test("nested RDDs are not supported (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
+ nestedRDD.count()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("actions cannot be performed inside of transformations (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ rdd.map(x => x * rdd2.count).collect()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
+ val existingRDD = sc.parallelize(1 to 100)
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ existingRDD.count()
+ }
+ assert(thrown.getMessage.contains("shutdown"))
+ }
+
+ test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
+ sc.stop()
+ def assertFails(block: => Any): Unit = {
+ val thrown = intercept[IllegalStateException] {
+ block
+ }
+ assert(thrown.getMessage.contains("stopped"))
+ }
+ assertFails { sc.parallelize(1 to 100) }
+ assertFails { sc.textFile("/nonexistent-path") }
+ }
}