diff options
Diffstat (limited to 'core')
19 files changed, 120 insertions, 107 deletions
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java index edd0fc56f8..46d61503bc 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -20,19 +20,24 @@ package org.apache.spark.network.netty; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.TimeUnit; + class FileClient { private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private FileClientHandler handler = null; + private final FileClientHandler handler; private Channel channel = null; private Bootstrap bootstrap = null; - private int connectTimeout = 60*1000; // 1 min + private EventLoopGroup group = null; + private final int connectTimeout; + private final int sendTimeout = 60; // 1 min public FileClient(FileClientHandler handler, int connectTimeout) { this.handler = handler; @@ -40,8 +45,9 @@ class FileClient { } public void init() { + group = new OioEventLoopGroup(); bootstrap = new Bootstrap(); - bootstrap.group(new OioEventLoopGroup()) + bootstrap.group(group) .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) @@ -56,6 +62,7 @@ class FileClient { // ChannelFuture cf = channel.closeFuture(); //cf.addListener(new ChannelCloseListener(this)); } catch (InterruptedException e) { + LOG.warn("FileClient interrupted while trying to connect", e); close(); } } @@ -71,16 +78,21 @@ class FileClient { public void sendRequest(String file) { //assert(file == null); //assert(channel == null); - channel.write(file + "\r\n"); + try { + // Should be able to send the message to network link channel. + boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); + if (!bSent) { + throw new RuntimeException("Failed to send"); + } + } catch (InterruptedException e) { + LOG.error("Error", e); + } } public void close() { - if(channel != null) { - channel.close(); - channel = null; - } - if ( bootstrap!=null) { - bootstrap.shutdown(); + if (group != null) { + group.shutdownGracefully(); + group = null; bootstrap = null; } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java index 65ee15d63b..fb61be1c12 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java @@ -17,15 +17,13 @@ package org.apache.spark.network.netty; -import io.netty.buffer.BufType; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.string.StringEncoder; - class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { - private FileClientHandler fhandler; + private final FileClientHandler fhandler; public FileClientChannelInitializer(FileClientHandler handler) { fhandler = handler; @@ -35,7 +33,7 @@ class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { public void initChannel(SocketChannel channel) { // file no more than 2G channel.pipeline() - .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("encoder", new StringEncoder()) .addLast("handler", fhandler); } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java index 8a09210245..63d3d92725 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java @@ -19,11 +19,11 @@ package org.apache.spark.network.netty; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundByteHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import org.apache.spark.storage.BlockId; -abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { +abstract class FileClientHandler extends SimpleChannelInboundHandler<ByteBuf> { private FileHeader currentHeader = null; @@ -37,13 +37,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { public abstract void handleError(BlockId blockId); @Override - public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { - // Use direct buffer if possible. - return ctx.alloc().ioBuffer(); - } - - @Override - public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { // get header if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java index a99af348ce..aea7534459 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -22,13 +22,12 @@ import java.net.InetSocketAddress; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioServerSocketChannel; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; - /** * Server that accept the path of a file an echo back its content. */ @@ -36,7 +35,8 @@ class FileServer { private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private ServerBootstrap bootstrap = null; + private EventLoopGroup bossGroup = null; + private EventLoopGroup workerGroup = null; private ChannelFuture channelFuture = null; private int port = 0; private Thread blockingThread = null; @@ -45,8 +45,11 @@ class FileServer { InetSocketAddress addr = new InetSocketAddress(port); // Configure the server. - bootstrap = new ServerBootstrap(); - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + bossGroup = new OioEventLoopGroup(); + workerGroup = new OioEventLoopGroup(); + + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup) .channel(OioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, 100) .option(ChannelOption.SO_RCVBUF, 1500) @@ -89,13 +92,19 @@ class FileServer { public void stop() { // Close the bound channel. if (channelFuture != null) { - channelFuture.channel().close(); + channelFuture.channel().close().awaitUninterruptibly(); channelFuture = null; } - // Shutdown bootstrap. - if (bootstrap != null) { - bootstrap.shutdown(); - bootstrap = null; + + // Shutdown event groups + if (bossGroup != null) { + bossGroup.shutdownGracefully(); + bossGroup = null; + } + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + workerGroup = null; } // TODO: Shutdown all accepted channels as well ? } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java index 833af1632d..3f15ff898f 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java @@ -23,7 +23,6 @@ import io.netty.handler.codec.DelimiterBasedFrameDecoder; import io.netty.handler.codec.Delimiters; import io.netty.handler.codec.string.StringDecoder; - class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { PathResolver pResolver; @@ -36,7 +35,7 @@ class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { public void initChannel(SocketChannel channel) { channel.pipeline() .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("strDecoder", new StringDecoder()) + .addLast("stringDecoder", new StringDecoder()) .addLast("handler", new FileServerHandler(pResolver)); } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java index 172c6e4b1c..e2d9391b4c 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -21,22 +21,26 @@ import java.io.File; import java.io.FileInputStream; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.DefaultFileRegion; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.FileSegment; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { +class FileServerHandler extends SimpleChannelInboundHandler<String> { - PathResolver pResolver; + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + + private final PathResolver pResolver; public FileServerHandler(PathResolver pResolver){ this.pResolver = pResolver; } @Override - public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { + public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { BlockId blockId = BlockId.apply(blockIdString); FileSegment fileSegment = pResolver.getBlockLocation(blockId); // if getBlockLocation returns null, close the channel @@ -60,10 +64,10 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { int len = new Long(length).intValue(); ctx.write((new FileHeader(len, blockId)).buffer()); try { - ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + ctx.write(new DefaultFileRegion(new FileInputStream(file) .getChannel(), fileSegment.offset(), fileSegment.length())); } catch (Exception e) { - e.printStackTrace(); + LOG.error("Exception: ", e); } } else { ctx.write(new FileHeader(0, blockId).buffer()); @@ -73,7 +77,7 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - cause.printStackTrace(); + LOG.error("Exception: ", cause); ctx.close(); } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8134ce7eb3..fbc7a78bf5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.net.URI -import java.util.Properties +import java.util.{UUID, Properties} import java.util.concurrent.atomic.AtomicInteger import scala.collection.{Map, Set, immutable} @@ -897,22 +897,15 @@ class SparkContext( /** * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. + * be a HDFS path if running on a cluster. */ - def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val path = new Path(dir) - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - if (!useExisting) { - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - } + def setCheckpointDir(directory: String) { + checkpointDir = Option(directory).map { dir => + val path = new Path(dir, UUID.randomUUID().toString) + val fs = path.getFileSystem(hadoopConfiguration) + fs.mkdirs(path) + fs.getFileStatus(path).getPath().toString } - checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d6aeed7661..0680a065a3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -394,20 +394,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. - */ - def setCheckpointDir(dir: String, useExisting: Boolean) { - sc.setCheckpointDir(dir, useExisting) - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists, an exception will be thrown to prevent accidental - * overriding of checkpoint files. + * be a HDFS path if running on a cluster. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 2897c4b841..172ba6b01c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -18,12 +18,12 @@ package org.apache.spark.rdd import java.io.IOException - import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -34,6 +34,8 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { + val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) override def getPartitions: Array[Partition] = { @@ -65,7 +67,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) override def compute(split: Partition, context: TaskContext): Iterator[T] = { val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, context) + CheckpointRDD.readFromFile(file, broadcastedConf, context) } override def checkpoint() { @@ -78,10 +80,14 @@ private[spark] object CheckpointRDD extends Logging { "part-%05d".format(splitId) } - def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + def writeToFile[T]( + path: String, + broadcastedConf: Broadcast[SerializableWritable[Configuration]], + blockSize: Int = -1 + )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get val outputDir = new Path(path) - val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration()) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -118,9 +124,13 @@ private[spark] object CheckpointRDD extends Logging { } } - def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + def readFromFile[T]( + path: Path, + broadcastedConf: Broadcast[SerializableWritable[Configuration]], + context: TaskContext + ): Iterator[T] = { val env = SparkEnv.get - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) + val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getOrElse("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() @@ -143,8 +153,10 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) + val conf = SparkHadoopUtil.get.newConfiguration() + val fs = path.getFileSystem(conf) + val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 3b56e45aa9..642dabaad5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Partition, SparkException, Logging} +import org.apache.spark.{SerializableWritable, Partition, SparkException, Logging} import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -85,14 +85,21 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) // Create the output path for the checkpoint val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) } // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val broadcastedConf = rdd.context.broadcast( + new SerializableWritable(rdd.context.hadoopConfiguration)) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + if (newRDD.partitions.size != rdd.partitions.size) { + throw new SparkException( + "Checkpoint RDD " + newRDD + "("+ newRDD.partitions.size + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + } // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { @@ -101,8 +108,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } + logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } // Get preferred location of a split after checkpointing diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e06e49d9d2..043e01dbfb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -828,7 +828,7 @@ class DAGScheduler( } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) - listenerBus.post(StageCompleted(stageToInfos(stage))) + listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) running -= stage } event.reason match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index be5c95e59e..f8fa5a9f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -297,7 +297,7 @@ class JobLogger(val user: String, val logDirName: String) * When stage is completed, record stage completion status * @param stageCompleted Stage completed event */ - override def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format( stageCompleted.stage.stageId)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index ee63b3c4a1..627995c826 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -27,7 +27,7 @@ sealed trait SparkListenerEvents case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties) extends SparkListenerEvents -case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents +case class SparkListenerStageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents @@ -47,7 +47,7 @@ trait SparkListener { /** * Called when a stage is completed, with information on the completed stage */ - def onStageCompleted(stageCompleted: StageCompleted) { } + def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } /** * Called when a stage is submitted @@ -86,7 +86,7 @@ trait SparkListener { * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener extends SparkListener with Logging { - override def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { import org.apache.spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stage) @@ -119,13 +119,17 @@ object StatsReportListener extends Logging { val probabilities = percentiles.map{_ / 100.0} val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { + def extractDoubleDistribution(stage: SparkListenerStageCompleted, + getMetric: (TaskInfo,TaskMetrics) => Option[Double]) + : Option[Distribution] = { Distribution(stage.stage.taskInfos.flatMap { case ((info,metric)) => getMetric(info, metric)}) } //is there some way to setup the types that I can get rid of this completely? - def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = { + def extractLongDistribution(stage: SparkListenerStageCompleted, + getMetric: (TaskInfo,TaskMetrics) => Option[Long]) + : Option[Distribution] = { extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble}) } @@ -147,12 +151,12 @@ object StatsReportListener extends Logging { } def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showDistribution(heading, extractDoubleDistribution(stage, getMetric), format) } def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showBytesDistribution(heading, extractLongDistribution(stage, getMetric)) } @@ -169,7 +173,7 @@ object StatsReportListener extends Logging { } def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 85687ea330..e7defd768b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class SparkListenerBus() extends Logging { event match {
case stageSubmitted: SparkListenerStageSubmitted =>
sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
- case stageCompleted: StageCompleted =>
+ case stageCompleted: SparkListenerStageCompleted =>
sparkListeners.foreach(_.onStageCompleted(stageCompleted))
case jobStart: SparkListenerJobStart =>
sparkListeners.foreach(_.onJobStart(jobStart))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d752e6f111..b99664ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -114,10 +114,6 @@ private[spark] class TaskSetManager( // Task index, start and finish time for each task attempt (indexed by task ID) val taskInfos = new HashMap[Long, TaskInfo] - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong @@ -558,8 +554,6 @@ private[spark] class TaskSetManager( } def abort(message: String) { - failed = true - causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) removeAllRunningTasks() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index eed3544b70..315014d27d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -61,7 +61,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList override def onJobStart(jobStart: SparkListenerJobStart) {} - override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stage poolToActiveStages(stageIdToPool(stage.stageId)) -= stage activeStages -= stage diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 79913dc718..5e2899c97b 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -851,7 +851,7 @@ public class JavaAPISuite implements Serializable { public void checkpointAndComputation() { File tempDir = Files.createTempDir(); JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + sc.setCheckpointDir(tempDir.getAbsolutePath()); Assert.assertEquals(false, rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint @@ -863,7 +863,7 @@ public class JavaAPISuite implements Serializable { public void checkpointAndRestore() { File tempDir = Files.createTempDir(); JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + sc.setCheckpointDir(tempDir.getAbsolutePath()); Assert.assertEquals(false, rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index dd122615ad..5cc48ee00a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -117,7 +117,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 - override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = onStageCompletedCount += 1 override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 } sc.addSparkListener(joblogger) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d4320e5e14..1a16e438c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -174,7 +174,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = Buffer[StageInfo]() - override def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: SparkListenerStageCompleted) { stageInfos += stage.stage } } |