author    Josh Rosen <joshrosen@databricks.com>    2015-01-23 17:53:15 -0800
committer Josh Rosen <joshrosen@databricks.com>    2015-01-23 17:53:47 -0800
commit    2ea782a9d24388daa13e8b0ce57fa110ecbdef1b (patch)
tree      91f1f81a50e26ed6029072362e7b4ee2cf8cc0ea
parent    5aaf0e0ff5e5082c0064f5f4065cd66a62aa72d6 (diff)
[SPARK-5063] More helpful error messages for several invalid operations
This patch adds more helpful error messages for invalid programs that define nested RDDs, broadcast RDDs, perform actions inside of transformations (e.g. calling `count()` from inside of `map()`), and call certain methods on stopped SparkContexts. Currently, these invalid programs lead to confusing NullPointerExceptions at runtime and have been a major source of questions on the mailing list and StackOverflow.

In a few cases, I chose to log warnings instead of throwing exceptions in order to avoid any chance that this patch breaks programs that worked "by accident" in earlier Spark releases (e.g. programs that define nested RDDs but never run any jobs with them).

In SparkContext, the new `assertNotStopped()` method is used to check whether methods are being invoked on a stopped SparkContext. In some cases, user programs will not crash in spite of calling methods on stopped SparkContexts, so I've only added `assertNotStopped()` calls to methods that always throw exceptions when called on stopped contexts (e.g. by dereferencing a null `dagScheduler` pointer).

Author: Josh Rosen <joshrosen@databricks.com>

Closes #3884 from JoshRosen/SPARK-5063 and squashes the following commits:

a38774b [Josh Rosen] Fix spelling typo
a943e00 [Josh Rosen] Convert two exceptions into warnings in order to avoid breaking user programs in some edge-cases.
2d0d7f7 [Josh Rosen] Fix test to reflect 1.2.1 compatibility
3f0ea0c [Josh Rosen] Revert two unintentional formatting changes
8e5da69 [Josh Rosen] Remove assertNotStopped() calls for methods that were sometimes safe to call on stopped SC's in Spark 1.2
8cff41a [Josh Rosen] IllegalStateException fix
6ef68d0 [Josh Rosen] Fix Python line length issues.
9f6a0b8 [Josh Rosen] Add improved error messages to PySpark.
13afd0f [Josh Rosen] SparkException -> IllegalStateException
8d404f3 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-5063
b39e041 [Josh Rosen] Fix BroadcastSuite test which broadcasted an RDD
99cc09f [Josh Rosen] Guard against calling methods on stopped SparkContexts.
34833e8 [Josh Rosen] Add more descriptive error message.
57cc8a1 [Josh Rosen] Add error message when directly broadcasting RDD.
15b2e6b [Josh Rosen] [SPARK-5063] Useful error messages for nested RDDs and actions inside of transformations

(cherry picked from commit cef1f092a628ac20709857b4388bb10e0b5143b0)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
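For reference, a rough Scala sketch (not part of the patch; `sc` is assumed to be a live SparkContext on the driver) of the invalid patterns this change targets and the behavior they now produce:

    val rdd1 = sc.parallelize(1 to 100)
    val rdd2 = sc.parallelize(1 to 100)

    // Action inside a transformation: was a confusing NullPointerException,
    // now a SparkException whose message points to SPARK-5063.
    rdd1.map(x => rdd2.count() * x).collect()

    // Broadcasting an RDD directly: now logs a warning suggesting collect()
    // and broadcasting the result instead.
    sc.broadcast(rdd1)

    // Calling methods on a stopped SparkContext: now an IllegalStateException
    // ("Cannot call methods on a stopped SparkContext") rather than an NPE.
    sc.stop()
    sc.parallelize(1 to 10)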
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala              | 62
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                   | 19
-rw-r--r--  core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala  | 12
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala              | 40
-rw-r--r--  python/pyspark/context.py                                            |  8
-rw-r--r--  python/pyspark/rdd.py                                                | 11
6 files changed, 138 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c2eff5c449..b50a54126e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -519,6 +527,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -534,6 +543,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -543,6 +553,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -576,6 +587,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -621,6 +633,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -645,6 +658,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -678,6 +692,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -697,6 +712,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -776,6 +792,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -796,6 +813,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -811,6 +829,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -822,9 +841,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -852,6 +872,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -873,6 +894,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -948,6 +970,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1036,6 +1065,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1048,6 +1078,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1065,6 +1096,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1074,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1084,6 +1117,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1091,6 +1125,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1196,16 +1231,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1279,8 +1312,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1367,6 +1400,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1389,6 +1423,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1407,11 +1442,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1458,7 +1495,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index c26425dea0..1814318a8b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -75,10 +75,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
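A note on the mechanism: when a closure such as rdd1.map(x => rdd2.count() * x) is shipped to executors, rdd2's @transient _sc field deserializes as null, so the new private sc accessor is what turns the former NullPointerException into the descriptive SparkException above. The constructor-time check covers the separate case of building an RDD of RDDs; a minimal sketch (not part of the patch):

    // The element type here is itself an RDD, so the elementClassTag check in the
    // RDD constructor fires and logs "Spark does not support nested RDDs (see
    // SPARK-5063)" -- a warning rather than an exception.
    val rdd1 = sc.parallelize(1 to 100)
    val nested = sc.parallelize(Seq(rdd1))   // RDD[RDD[Int]]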
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index b0a70f012f..af3272692d 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}
+ test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
+ sc = new SparkContext("local", "test")
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ sc.broadcast(Seq(1, 2, 3))
+ }
+ assert(thrown.getMessage.toLowerCase.contains("stopped"))
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {
def runCallSiteTest(sc: SparkContext) {
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val broadcast = sc.broadcast(rdd)
+ val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8de634a676..4e802abbaa 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -907,4 +907,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}
+
+ test("nested RDDs are not supported (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
+ nestedRDD.count()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("actions cannot be performed inside of transformations (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ rdd.map(x => x * rdd2.count).collect()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
+ val existingRDD = sc.parallelize(1 to 100)
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ existingRDD.count()
+ }
+ assert(thrown.getMessage.contains("shutdown"))
+ }
+
+ test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
+ sc.stop()
+ def assertFails(block: => Any): Unit = {
+ val thrown = intercept[IllegalStateException] {
+ block
+ }
+ assert(thrown.getMessage.contains("stopped"))
+ }
+ assertFails { sc.parallelize(1 to 100) }
+ assertFails { sc.textFile("/nonexistent-path") }
+ }
}
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b5c2421b88..23ff8ccf61 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -230,6 +230,14 @@ class SparkContext(object):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+ "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+ "not in code that it run on workers. For more information, see SPARK-5063."
+ )
+
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bd2ff00c0f..f8b5f18253 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -141,6 +141,17 @@ class RDD(object):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""