diff options
author | walker <walker@transwarp01.(none)> | 2014-01-07 01:05:18 +0800 |
---|---|---|
committer | walker <walker@transwarp01.(none)> | 2014-01-07 01:05:18 +0800 |
commit | a0c6d96e270fc0bd26bfb7da0e3b2e40d934e1eb (patch) | |
tree | db554652e72b7bb234d5f56acaa3d4068331ce74 /core | |
parent | 0af4b4f3e86791dc47673be832e77e51b8a8ebcc (diff) | |
parent | a2e7e0497484554f86bd71e93705eb0422b1512b (diff) | |
download | spark-a0c6d96e270fc0bd26bfb7da0e3b2e40d934e1eb.tar.gz spark-a0c6d96e270fc0bd26bfb7da0e3b2e40d934e1eb.tar.bz2 spark-a0c6d96e270fc0bd26bfb7da0e3b2e40d934e1eb.zip |
Merge remote branch 'upstream/master'
Diffstat (limited to 'core')
140 files changed, 3418 insertions, 2742 deletions
diff --git a/core/pom.xml b/core/pom.xml index 043f6cf68d..aac0a9d11e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -17,215 +17,219 @@ --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - <modelVersion>4.0.0</modelVersion> - <parent> - <groupId>org.apache.spark</groupId> - <artifactId>spark-parent</artifactId> - <version>0.9.0-incubating-SNAPSHOT</version> - <relativePath>../pom.xml</relativePath> - </parent> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent</artifactId> + <version>0.9.0-incubating-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> - <groupId>org.apache.spark</groupId> - <artifactId>spark-core_2.10</artifactId> - <packaging>jar</packaging> - <name>Spark Project Core</name> - <url>http://spark.incubator.apache.org/</url> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Core</name> + <url>http://spark.incubator.apache.org/</url> - <dependencies> - <dependency> - <groupId>org.apache.hadoop</groupId> - <artifactId>hadoop-client</artifactId> - </dependency> - <dependency> - <groupId>net.java.dev.jets3t</groupId> - <artifactId>jets3t</artifactId> - </dependency> - <dependency> - <groupId>org.apache.avro</groupId> - <artifactId>avro</artifactId> - </dependency> - <dependency> - <groupId>org.apache.avro</groupId> - <artifactId>avro-ipc</artifactId> - </dependency> - <dependency> - <groupId>org.apache.zookeeper</groupId> - <artifactId>zookeeper</artifactId> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-server</artifactId> - </dependency> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - </dependency> - <dependency> - 
<groupId>com.google.code.findbugs</groupId> - <artifactId>jsr305</artifactId> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-api</artifactId> - </dependency> - <dependency> - <groupId>com.ning</groupId> - <artifactId>compress-lzf</artifactId> - </dependency> - <dependency> - <groupId>org.xerial.snappy</groupId> - <artifactId>snappy-java</artifactId> - </dependency> - <dependency> - <groupId>org.ow2.asm</groupId> - <artifactId>asm</artifactId> - </dependency> - <dependency> - <groupId>com.twitter</groupId> - <artifactId>chill_${scala.binary.version}</artifactId> - <version>0.3.1</version> - </dependency> - <dependency> - <groupId>com.twitter</groupId> - <artifactId>chill-java</artifactId> - <version>0.3.1</version> - </dependency> - <dependency> - <groupId>${akka.group}</groupId> - <artifactId>akka-remote_${scala.binary.version}</artifactId> - </dependency> - <dependency> - <groupId>${akka.group}</groupId> - <artifactId>akka-slf4j_${scala.binary.version}</artifactId> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-library</artifactId> - </dependency> - <dependency> - <groupId>net.liftweb</groupId> - <artifactId>lift-json_${scala.binary.version}</artifactId> - </dependency> - <dependency> - <groupId>it.unimi.dsi</groupId> - <artifactId>fastutil</artifactId> - </dependency> - <dependency> - <groupId>colt</groupId> - <artifactId>colt</artifactId> - </dependency> - <dependency> - <groupId>org.apache.mesos</groupId> - <artifactId>mesos</artifactId> - </dependency> - <dependency> - <groupId>io.netty</groupId> - <artifactId>netty-all</artifactId> - </dependency> - <dependency> - <groupId>log4j</groupId> - <artifactId>log4j</artifactId> - </dependency> - <dependency> - <groupId>com.codahale.metrics</groupId> - <artifactId>metrics-core</artifactId> - </dependency> - <dependency> - <groupId>com.codahale.metrics</groupId> - <artifactId>metrics-jvm</artifactId> - </dependency> - <dependency> - 
<groupId>com.codahale.metrics</groupId> - <artifactId>metrics-json</artifactId> - </dependency> - <dependency> - <groupId>com.codahale.metrics</groupId> - <artifactId>metrics-ganglia</artifactId> - </dependency> - <dependency> - <groupId>com.codahale.metrics</groupId> - <artifactId>metrics-graphite</artifactId> - </dependency> - <dependency> - <groupId>org.apache.derby</groupId> - <artifactId>derby</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>commons-io</groupId> - <artifactId>commons-io</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.scalatest</groupId> - <artifactId>scalatest_${scala.binary.version}</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.scalacheck</groupId> - <artifactId>scalacheck_${scala.binary.version}</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.easymock</groupId> - <artifactId>easymock</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>com.novocode</groupId> - <artifactId>junit-interface</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-log4j12</artifactId> - <scope>test</scope> - </dependency> - </dependencies> - <build> - <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> - <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-antrun-plugin</artifactId> - <executions> - <execution> - <phase>test</phase> - <goals> - <goal>run</goal> - </goals> - <configuration> - <exportAntProperties>true</exportAntProperties> - <tasks> - <property name="spark.classpath" refid="maven.test.classpath" /> - <property environment="env" /> - <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry."> - <condition> - <not> - <or> 
- <isset property="env.SCALA_HOME" /> - <isset property="env.SCALA_LIBRARY_PATH" /> - </or> - </not> - </condition> - </fail> - </tasks> - </configuration> - </execution> - </executions> - </plugin> - <plugin> - <groupId>org.scalatest</groupId> - <artifactId>scalatest-maven-plugin</artifactId> - <configuration> - <environmentVariables> - <SPARK_HOME>${basedir}/..</SPARK_HOME> - <SPARK_TESTING>1</SPARK_TESTING> - <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH> - </environmentVariables> - </configuration> - </plugin> - </plugins> - </build> + <dependencies> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + </dependency> + <dependency> + <groupId>net.java.dev.jets3t</groupId> + <artifactId>jets3t</artifactId> + </dependency> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro</artifactId> + </dependency> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro-ipc</artifactId> + </dependency> + <dependency> + <groupId>org.apache.zookeeper</groupId> + <artifactId>zookeeper</artifactId> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-server</artifactId> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>com.google.code.findbugs</groupId> + <artifactId>jsr305</artifactId> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + <dependency> + <groupId>com.ning</groupId> + <artifactId>compress-lzf</artifactId> + </dependency> + <dependency> + <groupId>org.xerial.snappy</groupId> + <artifactId>snappy-java</artifactId> + </dependency> + <dependency> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> + </dependency> + <dependency> + <groupId>com.twitter</groupId> + <artifactId>chill_${scala.binary.version}</artifactId> + <version>0.3.1</version> + </dependency> + <dependency> + 
<groupId>com.twitter</groupId> + <artifactId>chill-java</artifactId> + <version>0.3.1</version> + </dependency> + <dependency> + <groupId>${akka.group}</groupId> + <artifactId>akka-remote_${scala.binary.version}</artifactId> + </dependency> + <dependency> + <groupId>${akka.group}</groupId> + <artifactId>akka-slf4j_${scala.binary.version}</artifactId> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + </dependency> + <dependency> + <groupId>net.liftweb</groupId> + <artifactId>lift-json_${scala.binary.version}</artifactId> + </dependency> + <dependency> + <groupId>it.unimi.dsi</groupId> + <artifactId>fastutil</artifactId> + </dependency> + <dependency> + <groupId>colt</groupId> + <artifactId>colt</artifactId> + </dependency> + <dependency> + <groupId>org.apache.mesos</groupId> + <artifactId>mesos</artifactId> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-all</artifactId> + </dependency> + <dependency> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </dependency> + <dependency> + <groupId>com.clearspring.analytics</groupId> + <artifactId>stream</artifactId> + </dependency> + <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-core</artifactId> + </dependency> + <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-jvm</artifactId> + </dependency> + <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-json</artifactId> + </dependency> + <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-ganglia</artifactId> + </dependency> + <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-graphite</artifactId> + </dependency> + <dependency> + <groupId>org.apache.derby</groupId> + <artifactId>derby</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + <scope>test</scope> + 
</dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalacheck</groupId> + <artifactId>scalacheck_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.easymock</groupId> + <artifactId>easymock</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.novocode</groupId> + <artifactId>junit-interface</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-log4j12</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-antrun-plugin</artifactId> + <executions> + <execution> + <phase>test</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <exportAntProperties>true</exportAntProperties> + <tasks> + <property name="spark.classpath" refid="maven.test.classpath" /> + <property environment="env" /> + <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry."> + <condition> + <not> + <or> + <isset property="env.SCALA_HOME" /> + <isset property="env.SCALA_LIBRARY_PATH" /> + </or> + </not> + </condition> + </fail> + </tasks> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <environmentVariables> + <SPARK_HOME>${basedir}/..</SPARK_HOME> + <SPARK_TESTING>1</SPARK_TESTING> + <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH> + </environmentVariables> + </configuration> + </plugin> + </plugins> + </build> </project> 
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java index edd0fc56f8..46d61503bc 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -20,19 +20,24 @@ package org.apache.spark.network.netty; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.concurrent.TimeUnit; + class FileClient { private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private FileClientHandler handler = null; + private final FileClientHandler handler; private Channel channel = null; private Bootstrap bootstrap = null; - private int connectTimeout = 60*1000; // 1 min + private EventLoopGroup group = null; + private final int connectTimeout; + private final int sendTimeout = 60; // 1 min public FileClient(FileClientHandler handler, int connectTimeout) { this.handler = handler; @@ -40,8 +45,9 @@ class FileClient { } public void init() { + group = new OioEventLoopGroup(); bootstrap = new Bootstrap(); - bootstrap.group(new OioEventLoopGroup()) + bootstrap.group(group) .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) @@ -56,6 +62,7 @@ class FileClient { // ChannelFuture cf = channel.closeFuture(); //cf.addListener(new ChannelCloseListener(this)); } catch (InterruptedException e) { + LOG.warn("FileClient interrupted while trying to connect", e); close(); } } @@ -71,16 +78,21 @@ class FileClient { public void sendRequest(String file) { //assert(file == null); //assert(channel == null); - channel.write(file + "\r\n"); + try { + // Should be able to send the 
message to network link channel. + boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); + if (!bSent) { + throw new RuntimeException("Failed to send"); + } + } catch (InterruptedException e) { + LOG.error("Error", e); + } } public void close() { - if(channel != null) { - channel.close(); - channel = null; - } - if ( bootstrap!=null) { - bootstrap.shutdown(); + if (group != null) { + group.shutdownGracefully(); + group = null; bootstrap = null; } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java index 65ee15d63b..fb61be1c12 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java @@ -17,15 +17,13 @@ package org.apache.spark.network.netty; -import io.netty.buffer.BufType; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.string.StringEncoder; - class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { - private FileClientHandler fhandler; + private final FileClientHandler fhandler; public FileClientChannelInitializer(FileClientHandler handler) { fhandler = handler; @@ -35,7 +33,7 @@ class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { public void initChannel(SocketChannel channel) { // file no more than 2G channel.pipeline() - .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("encoder", new StringEncoder()) .addLast("handler", fhandler); } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java index 8a09210245..63d3d92725 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ 
b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java @@ -19,11 +19,11 @@ package org.apache.spark.network.netty; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundByteHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import org.apache.spark.storage.BlockId; -abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { +abstract class FileClientHandler extends SimpleChannelInboundHandler<ByteBuf> { private FileHeader currentHeader = null; @@ -37,13 +37,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { public abstract void handleError(BlockId blockId); @Override - public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { - // Use direct buffer if possible. - return ctx.alloc().ioBuffer(); - } - - @Override - public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { // get header if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java index a99af348ce..aea7534459 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -22,13 +22,12 @@ import java.net.InetSocketAddress; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioServerSocketChannel; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; - /** * Server that accept the path of a file an echo back its content. 
*/ @@ -36,7 +35,8 @@ class FileServer { private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private ServerBootstrap bootstrap = null; + private EventLoopGroup bossGroup = null; + private EventLoopGroup workerGroup = null; private ChannelFuture channelFuture = null; private int port = 0; private Thread blockingThread = null; @@ -45,8 +45,11 @@ class FileServer { InetSocketAddress addr = new InetSocketAddress(port); // Configure the server. - bootstrap = new ServerBootstrap(); - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + bossGroup = new OioEventLoopGroup(); + workerGroup = new OioEventLoopGroup(); + + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup) .channel(OioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, 100) .option(ChannelOption.SO_RCVBUF, 1500) @@ -89,13 +92,19 @@ class FileServer { public void stop() { // Close the bound channel. if (channelFuture != null) { - channelFuture.channel().close(); + channelFuture.channel().close().awaitUninterruptibly(); channelFuture = null; } - // Shutdown bootstrap. - if (bootstrap != null) { - bootstrap.shutdown(); - bootstrap = null; + + // Shutdown event groups + if (bossGroup != null) { + bossGroup.shutdownGracefully(); + bossGroup = null; + } + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + workerGroup = null; } // TODO: Shutdown all accepted channels as well ? 
} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java index 833af1632d..3f15ff898f 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java @@ -23,7 +23,6 @@ import io.netty.handler.codec.DelimiterBasedFrameDecoder; import io.netty.handler.codec.Delimiters; import io.netty.handler.codec.string.StringDecoder; - class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { PathResolver pResolver; @@ -36,7 +35,7 @@ class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { public void initChannel(SocketChannel channel) { channel.pipeline() .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("strDecoder", new StringDecoder()) + .addLast("stringDecoder", new StringDecoder()) .addLast("handler", new FileServerHandler(pResolver)); } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java index 172c6e4b1c..e2d9391b4c 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -21,22 +21,26 @@ import java.io.File; import java.io.FileInputStream; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.DefaultFileRegion; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.FileSegment; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { +class FileServerHandler extends 
SimpleChannelInboundHandler<String> { - PathResolver pResolver; + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + + private final PathResolver pResolver; public FileServerHandler(PathResolver pResolver){ this.pResolver = pResolver; } @Override - public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { + public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { BlockId blockId = BlockId.apply(blockIdString); FileSegment fileSegment = pResolver.getBlockLocation(blockId); // if getBlockLocation returns null, close the channel @@ -60,10 +64,10 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { int len = new Long(length).intValue(); ctx.write((new FileHeader(len, blockId)).buffer()); try { - ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + ctx.write(new DefaultFileRegion(new FileInputStream(file) .getChannel(), fileSegment.offset(), fileSegment.length())); } catch (Exception e) { - e.printStackTrace(); + LOG.error("Exception: ", e); } } else { ctx.write(new FileHeader(0, blockId).buffer()); @@ -73,7 +77,7 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - cause.printStackTrace(); + LOG.error("Exception: ", cause); ctx.close(); } } diff --git a/core/src/main/resources/org/apache/spark/default-log4j.properties b/core/src/main/resources/org/apache/spark/default-log4j.properties new file mode 100644 index 0000000000..d72dbadc39 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/default-log4j.properties @@ -0,0 +1,8 @@ +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Ignore messages below warning level 
from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 6e922a612a..5f73d234aa 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -41,7 +41,7 @@ class Accumulable[R, T] ( @transient initialValue: R, param: AccumulableParam[R, T]) extends Serializable { - + val id = Accumulators.newId @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers @@ -113,7 +113,7 @@ class Accumulable[R, T] ( def setValue(newValue: R) { this.value = newValue } - + // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -177,7 +177,7 @@ class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Ser def zero(initialValue: R): R = { // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. // Instead we'll serialize it to a buffer and load it back. 
- val ser = new JavaSerializer().newInstance() + val ser = new JavaSerializer(new SparkConf(false)).newInstance() val copy = ser.deserialize[R](ser.serialize(initialValue)) copy.clear() // In case it contained stuff copy @@ -215,7 +215,7 @@ private object Accumulators { val originals = Map[Long, Accumulable[_, _]]() val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() var lastId: Long = 0 - + def newId: Long = synchronized { lastId += 1 return lastId diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index cdfc9dd54e..69a738dc44 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -46,6 +46,7 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { + logInfo("Starting HTTP Server") server = new Server() val connector = new SocketConnector connector.setMaxIdleTime(60*1000) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 6a973ea495..d519fc5a29 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,8 +17,8 @@ package org.apache.spark -import org.slf4j.Logger -import org.slf4j.LoggerFactory +import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.slf4j.{Logger, LoggerFactory} /** * Utility trait for classes that want to log data. 
Creates a SLF4J logger for the class and allows @@ -33,6 +33,7 @@ trait Logging { // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { + initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects if (className.endsWith("$")) { @@ -89,7 +90,37 @@ trait Logging { log.isTraceEnabled } - // Method for ensuring that logging is initialized, to avoid having multiple - // threads do it concurrently (as SLF4J initialization is not thread safe). - protected def initLogging() { log } + private def initializeIfNecessary() { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging() + } + } + } + } + + private def initializeLogging() { + // If Log4j doesn't seem initialized, load a default properties file + val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + val defaultLogProps = "org/apache/spark/default-log4j.properties" + val classLoader = this.getClass.getClassLoader + Option(classLoader.getResource(defaultLogProps)) match { + case Some(url) => PropertyConfigurator.configure(url) + case None => System.err.println(s"Spark was unable to load $defaultLogProps") + } + log.info(s"Using Spark's default log4j profile: $defaultLogProps") + } + Logging.initialized = true + + // Force a call into slf4j to initialize it. 
Avoids this happening from mutliple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +object Logging { + @volatile private var initialized = false + val initLock = new Object() } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ccffcc356c..cdae167aef 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -50,9 +50,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker extends Logging { +private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout + private val timeout = AkkaUtils.askTimeout(conf) // Set to the MapOutputTrackerActor living on the driver var trackerActor: Either[ActorRef, ActorSelection] = _ @@ -65,7 +65,7 @@ private[spark] class MapOutputTracker extends Logging { protected val epochLock = new java.lang.Object private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. 
@@ -129,7 +129,7 @@ private[spark] class MapOutputTracker extends Logging { if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort() + val hostPort = Utils.localHostPort(conf) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = @@ -192,7 +192,8 @@ private[spark] class MapOutputTracker extends Logging { } } -private[spark] class MapOutputTrackerMaster extends MapOutputTracker { +private[spark] class MapOutputTrackerMaster(conf: SparkConf) + extends MapOutputTracker(conf) { // Cache a serialized version of the output statuses for each shuffle to send them out faster private var cacheEpoch = epoch diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index bcec41c439..31b0773bfe 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -52,7 +52,7 @@ object Partitioner { for (r <- bySize if r.partitioner != None) { return r.partitioner.get } - if (System.getProperty("spark.default.parallelism") != null) { + if (rdd.context.conf.contains("spark.default.parallelism")) { return new HashPartitioner(rdd.context.defaultParallelism) } else { return new HashPartitioner(bySize.head.partitions.size) @@ -90,7 +90,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], - private val ascending: Boolean = true) + private val ascending: Boolean = true) extends Partitioner { // An array of upper bounds for the first (partitions - 1) partitions diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala new file mode 100644 index 0000000000..98343e9532 --- /dev/null +++ 
b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -0,0 +1,189 @@ +package org.apache.spark + +import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap + +import com.typesafe.config.ConfigFactory + +/** + * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. + * + * Most of the time, you would create a SparkConf object with `new SparkConf()`, which will load + * values from both the `spark.*` Java system properties and any `spark.conf` on your application's + * classpath (if it has one). In this case, system properties take priority over `spark.conf`, and + * any parameters you set directly on the `SparkConf` object take priority over both of those. + * + * For unit tests, you can also call `new SparkConf(false)` to skip loading external settings and + * get the same configuration no matter what is on the classpath. + * + * All setter methods in this class support chaining. For example, you can write + * `new SparkConf().setMaster("local").setAppName("My app")`. + * + * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. + * + * @param loadDefaults whether to load values from the system properties and classpath + */ +class SparkConf(loadDefaults: Boolean) extends Serializable with Cloneable { + + /** Create a SparkConf that loads defaults from system properties and the classpath */ + def this() = this(true) + + private val settings = new HashMap[String, String]() + + if (loadDefaults) { + ConfigFactory.invalidateCaches() + val typesafeConfig = ConfigFactory.systemProperties() + .withFallback(ConfigFactory.parseResources("spark.conf")) + for (e <- typesafeConfig.entrySet().asScala if e.getKey.startsWith("spark.")) { + settings(e.getKey) = e.getValue.unwrapped.toString + } + } + + /** Set a configuration variable. 
*/ + def set(key: String, value: String): SparkConf = { + if (key == null) { + throw new NullPointerException("null key") + } + if (value == null) { + throw new NullPointerException("null value") + } + settings(key) = value + this + } + + /** + * The master URL to connect to, such as "local" to run locally with one thread, "local[4]" to + * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. + */ + def setMaster(master: String): SparkConf = { + set("spark.master", master) + } + + /** Set a name for your application. Shown in the Spark web UI. */ + def setAppName(name: String): SparkConf = { + set("spark.app.name", name) + } + + /** Set JAR files to distribute to the cluster. */ + def setJars(jars: Seq[String]): SparkConf = { + set("spark.jars", jars.mkString(",")) + } + + /** Set JAR files to distribute to the cluster. (Java-friendly version.) */ + def setJars(jars: Array[String]): SparkConf = { + setJars(jars.toSeq) + } + + /** + * Set an environment variable to be used when launching executors for this application. + * These variables are stored as properties of the form spark.executorEnv.VAR_NAME + * (for example spark.executorEnv.PATH) but this method makes them easier to set. + */ + def setExecutorEnv(variable: String, value: String): SparkConf = { + set("spark.executorEnv." + variable, value) + } + + /** + * Set multiple environment variables to be used when launching executors. + * These variables are stored as properties of the form spark.executorEnv.VAR_NAME + * (for example spark.executorEnv.PATH) but this method makes them easier to set. + */ + def setExecutorEnv(variables: Seq[(String, String)]): SparkConf = { + for ((k, v) <- variables) { + setExecutorEnv(k, v) + } + this + } + + /** + * Set multiple environment variables to be used when launching executors. + * (Java-friendly version.) 
+ */ + def setExecutorEnv(variables: Array[(String, String)]): SparkConf = { + setExecutorEnv(variables.toSeq) + } + + /** + * Set the location where Spark is installed on worker nodes. + */ + def setSparkHome(home: String): SparkConf = { + set("spark.home", home) + } + + /** Set multiple parameters together */ + def setAll(settings: Traversable[(String, String)]) = { + this.settings ++= settings + this + } + + /** Set a parameter if it isn't already configured */ + def setIfMissing(key: String, value: String): SparkConf = { + if (!settings.contains(key)) { + settings(key) = value + } + this + } + + /** Remove a parameter from the configuration */ + def remove(key: String): SparkConf = { + settings.remove(key) + this + } + + /** Get a parameter; throws a NoSuchElementException if it's not set */ + def get(key: String): String = { + settings.getOrElse(key, throw new NoSuchElementException(key)) + } + + /** Get a parameter, falling back to a default if not set */ + def get(key: String, defaultValue: String): String = { + settings.getOrElse(key, defaultValue) + } + + /** Get a parameter as an Option */ + def getOption(key: String): Option[String] = { + settings.get(key) + } + + /** Get all parameters as a list of pairs */ + def getAll: Array[(String, String)] = settings.clone().toArray + + /** Get a parameter as an integer, falling back to a default if not set */ + def getInt(key: String, defaultValue: Int): Int = { + getOption(key).map(_.toInt).getOrElse(defaultValue) + } + + /** Get a parameter as a long, falling back to a default if not set */ + def getLong(key: String, defaultValue: Long): Long = { + getOption(key).map(_.toLong).getOrElse(defaultValue) + } + + /** Get a parameter as a double, falling back to a default if not set */ + def getDouble(key: String, defaultValue: Double): Double = { + getOption(key).map(_.toDouble).getOrElse(defaultValue) + } + + /** Get all executor environment variables set on this SparkConf */ + def getExecutorEnv: Seq[(String, 
String)] = { + val prefix = "spark.executorEnv." + getAll.filter{case (k, v) => k.startsWith(prefix)} + .map{case (k, v) => (k.substring(prefix.length), v)} + } + + /** Does the configuration contain a given parameter? */ + def contains(key: String): Boolean = settings.contains(key) + + /** Copy this object */ + override def clone: SparkConf = { + new SparkConf(false).setAll(settings) + } + + /** + * Return a string listing all keys and values, one per line. This is useful to print the + * configuration out for debugging. + */ + def toDebugString: String = { + settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a0f794edfd..e80e43af6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,36 +19,23 @@ package org.apache.spark import java.io._ import java.net.URI -import java.util.Properties +import java.util.{UUID, Properties} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.Map +import scala.collection.{Map, Set} import scala.collection.generic.Growable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap + +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.ArrayWritable -import org.apache.hadoop.io.BooleanWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.DoubleWritable -import org.apache.hadoop.io.FloatWritable -import org.apache.hadoop.io.IntWritable -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileInputFormat -import 
org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.SequenceFileInputFormat -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, +FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, +TextInputFormat} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} - import org.apache.mesos.MesosNativeLibrary import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} @@ -56,59 +43,102 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, - SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend} + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalScheduler -import org.apache.spark.scheduler.StageInfo +import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, - TimeStampedHashMap, Utils} +import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType, +ClosureCleaner} /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI. - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - * @param environment Environment variables to set on worker nodes. + * @param config a Spark Config object describing the application configuration. Any settings in + * this config overrides the default configs as well as system properties. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can + * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. */ class SparkContext( - val master: String, - val appName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map(), + config: SparkConf, // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) - // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set - // of data-local splits on host - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = - scala.collection.immutable.Map()) + // too. This is typically generated from InputFormatInfo.computePreferredLocations. It contains + // a map from hostname to a list of input format splits on the host. 
+ val preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) extends Logging { - // Ensure logging is initialized before we spawn any threads - initLogging() + /** + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters + */ + def this(master: String, appName: String, conf: SparkConf) = + this(SparkContext.updatedConf(conf, master, appName)) - // Set Spark driver host and port system properties - if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localHostName()) + /** + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI. + * @param sparkHome Location where Spark is installed on cluster nodes. + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes. + */ + def this( + master: String, + appName: String, + sparkHome: String = null, + jars: Seq[String] = Nil, + environment: Map[String, String] = Map(), + preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = + { + this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment), + preferredNodeLocationData) + } + + private[spark] val conf = config.clone() + + /** + * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be + * changed at runtime. 
+ */ + def getConf: SparkConf = conf.clone() + + if (!conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") } - if (System.getProperty("spark.driver.port") == null) { - System.setProperty("spark.driver.port", "0") + if (!conf.contains("spark.app.name")) { + throw new SparkException("An application must be set in your configuration") } + // Set Spark driver host and port system properties + conf.setIfMissing("spark.driver.host", Utils.localHostName()) + conf.setIfMissing("spark.driver.port", "0") + + val jars: Seq[String] = if (conf.contains("spark.jars")) { + conf.get("spark.jars").split(",").filter(_.size != 0) + } else { + null + } + + val master = conf.get("spark.master") + val appName = conf.get("spark.app.name") + val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createFromSystemProperties( + private[spark] val env = SparkEnv.create( + conf, "<driver>", - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port").toInt, - true, - isLocal) + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + isDriver = true, + isLocal = isLocal) SparkEnv.set(env) // Used to store a URL for each static file/jar together with the file's local timestamp @@ -117,7 +147,8 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup) + private[spark] val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) // Initialize the Spark UI private[spark] val ui = new SparkUI(this) @@ -127,23 +158,30 @@ class SparkContext( // Add each JAR given through the constructor if (jars != null) { - jars.foreach { addJar(_) } + jars.foreach(addJar) } + 
private[spark] val executorMemory = conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner - for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { - val value = System.getenv(key) - if (value != null) { - executorEnvs(key) = value - } + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS"); + value <- Option(System.getenv(key))) { + executorEnvs(key) = value } - // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" - if (environment != null) { - executorEnvs ++= environment + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_HOME", "spark.home"), ("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value } + // Since memory can be set with a system property too, use that + executorEnvs("SPARK_MEM") = executorMemory + "m" + executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. val sparkUser = Option { @@ -165,24 +203,24 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. 
*/ val hadoopConfiguration = { val env = SparkEnv.get - val conf = SparkHadoopUtil.get.newConfiguration() + val hadoopConf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - Utils.getSystemProperties.foreach { case (key, value) => + conf.getAll.foreach { case (key, value) => if (key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), value) + hadoopConf.set(key.substring("spark.hadoop.".length), value) } } - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) - conf + val bufferSize = conf.get("spark.buffer.size", "65536") + hadoopConf.set("io.file.buffer.size", bufferSize) + hadoopConf } private[spark] var checkpointDir: Option[String] = None @@ -192,7 +230,7 @@ class SparkContext( override protected def childValue(parent: Properties): Properties = new Properties(parent) } - private[spark] def getLocalProperties(): Properties = localProperties.get() + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { localProperties.set(props) @@ -522,10 +560,8 @@ class 
SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case a job is executed locally. - // Jobs that run through LocalScheduler will already fetch the required dependencies, - // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) + // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } @@ -695,15 +731,27 @@ class SparkContext( * (in that order of preference). If neither of these is set, return None. */ private[spark] def getSparkHome(): Option[String] = { - if (sparkHome != null) { - Some(sparkHome) - } else if (System.getProperty("spark.home") != null) { - Some(System.getProperty("spark.home")) - } else if (System.getenv("SPARK_HOME") != null) { - Some(System.getenv("SPARK_HOME")) - } else { - None - } + conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME"))) + } + + /** + * Support function for API backtraces. + */ + def setCallSite(site: String) { + setLocalProperty("externalCallSite", site) + } + + /** + * Support function for API backtraces. 
+ */ + def clearCallSite() { + setLocalProperty("externalCallSite", null) + } + + private[spark] def getCallSite(): String = { + val callSite = getLocalProperty("externalCallSite") + if (callSite == null) return Utils.formatSparkCallSite + callSite } /** @@ -718,7 +766,7 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - val callSite = Utils.formatSparkCallSite + val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) val start = System.nanoTime @@ -802,7 +850,7 @@ class SparkContext( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { - val callSite = Utils.formatSparkCallSite + val callSite = getCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, @@ -822,7 +870,7 @@ class SparkContext( resultFunc: => R): SimpleFutureAction[R] = { val cleanF = clean(processPartition) - val callSite = Utils.formatSparkCallSite + val callSite = getCallSite val waiter = dagScheduler.submitJob( rdd, (context: TaskContext, iter: Iterator[T]) => cleanF(iter), @@ -858,22 +906,15 @@ class SparkContext( /** * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. + * be a HDFS path if running on a cluster. 
*/ - def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val path = new Path(dir) - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - if (!useExisting) { - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - } + def setCheckpointDir(directory: String) { + checkpointDir = Option(directory).map { dir => + val path = new Path(dir, UUID.randomUUID().toString) + val fs = path.getFileSystem(hadoopConfiguration) + fs.mkdirs(path) + fs.getFileStatus(path).getPath().toString } - checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ @@ -997,7 +1038,7 @@ object SparkContext { /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext + * their JARs to SparkContext. */ def jarOfClass(cls: Class[_]): Seq[String] = { val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") @@ -1014,21 +1055,44 @@ object SparkContext { } } - /** Find the JAR that contains the class of a particular object */ + /** + * Find the JAR that contains the class of a particular object, to make it easy for users + * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in + * your driver program. + */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) - /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ - private[spark] val executorMemoryRequested = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) + /** + * Creates a modified version of a SparkConf with the parameters that can be passed separately + * to SparkContext, to make it easier to write SparkContext's constructors. 
This ignores + * parameters that are passed as the default value of null, instead of throwing an exception + * like SparkConf would. + */ + private def updatedConf( + conf: SparkConf, + master: String, + appName: String, + sparkHome: String = null, + jars: Seq[String] = Nil, + environment: Map[String, String] = Map()): SparkConf = + { + val res = conf.clone() + res.setMaster(master) + res.setAppName(appName) + if (sparkHome != null) { + res.setSparkHome(sparkHome) + } + if (!jars.isEmpty) { + res.setJars(jars) + } + res.setExecutorEnv(environment.toSeq) + res } - // Creates a task scheduler based on a given master URL. Extracted for testing. - private - def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + /** Creates a task scheduler based on a given master URL. Extracted for testing. */ + private def createTaskScheduler(sc: SparkContext, master: String, appName: String) + : TaskScheduler = + { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -1042,18 +1106,30 @@ object SparkContext { // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r + // When running locally, don't try to re-execute tasks on failure. 
+ val MAX_LOCAL_TASK_FAILURES = 1 + master match { case "local" => - new LocalScheduler(1, 0, sc) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, 1) + scheduler.initialize(backend) + scheduler case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, sc) + val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, sc) + val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) scheduler.initialize(backend) @@ -1062,13 +1138,13 @@ object SparkContext { case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + if (sc.executorMemory > memoryPerSlaveInt) { throw new SparkException( "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + memoryPerSlaveInt, sc.executorMemory)) } - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val masterUrls = localCluster.start() @@ -1083,7 +1159,7 @@ object SparkContext { val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! 
@@ -1099,7 +1175,7 @@ object SparkContext { val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(sc).asInstanceOf[ClusterScheduler] + cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { case th: Throwable => { @@ -1109,7 +1185,7 @@ object SparkContext { val backend = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") - val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { case th: Throwable => { @@ -1122,8 +1198,8 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(sc) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val scheduler = new TaskSchedulerImpl(sc) + val coarseGrained = sc.conf.get("spark.mesos.coarse", "false").toBoolean val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) @@ -1134,7 +1210,7 @@ object SparkContext { scheduler case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(sc) + val scheduler = new TaskSchedulerImpl(sc) val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 826f5c2d8c..634a94f0a7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -40,7 +40,7 @@ import com.google.common.collect.MapMaker * objects needs to have the right SparkEnv set. 
You can get the current environment with * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. */ -class SparkEnv ( +class SparkEnv private[spark] ( val executorId: String, val actorSystem: ActorSystem, val serializerManager: SerializerManager, @@ -54,7 +54,8 @@ class SparkEnv ( val connectionManager: ConnectionManager, val httpFileServer: HttpFileServer, val sparkFilesDir: String, - val metricsSystem: MetricsSystem) { + val metricsSystem: MetricsSystem, + val conf: SparkConf) { private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -62,7 +63,7 @@ class SparkEnv ( // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() - def stop() { + private[spark] def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() @@ -78,6 +79,7 @@ class SparkEnv ( //actorSystem.awaitTermination() } + private[spark] def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { val key = (pythonExec, envVars) @@ -106,33 +108,35 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - def getThreadLocal : SparkEnv = { + def getThreadLocal: SparkEnv = { env.get() } - def createFromSystemProperties( + private[spark] def create( + conf: SparkConf, executorId: String, hostname: String, port: Int, isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, + conf = conf) // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), // figure out which port number Akka actually bound to and set spark.driver.port to it. 
if (isDriver && port == 0) { - System.setProperty("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", boundPort.toString) } // set only if unset until now. - if (System.getProperty("spark.hostPort", null) == null) { + if (!conf.contains("spark.hostPort")) { if (!isDriver){ // unexpected Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") } Utils.checkHost(hostname) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + conf.set("spark.hostPort", hostname + ":" + boundPort) } val classLoader = Thread.currentThread.getContextClassLoader @@ -140,25 +144,26 @@ object SparkEnv extends Logging { // Create an instance of the class named by the given Java system property, or by // defaultClassName if the property is not set, and return it as a T def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = System.getProperty(propertyName, defaultClassName) + val name = conf.get(propertyName, defaultClassName) Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( - System.getProperty("spark.serializer", "org.apache.spark.serializer.JavaSerializer")) + conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf) val closureSerializer = serializerManager.get( - System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) + conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"), + conf) def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = { if (isDriver) { logInfo("Registering " + name) Left(actorSystem.actorOf(Props(newActor), name = name)) } else { - val driverHost: String = System.getProperty("spark.driver.host", "localhost") - val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt + val driverHost: String = conf.get("spark.driver.host", 
"localhost") + val driverPort: Int = conf.get("spark.driver.port", "7077").toInt Utils.checkHost(driverHost, "Expected hostname") val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) @@ -168,21 +173,21 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", - new BlockManagerMasterActor(isLocal))) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) + new BlockManagerMasterActor(isLocal, conf)), conf) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isDriver) + val broadcastManager = new BroadcastManager(isDriver, conf) val cacheManager = new CacheManager(blockManager) // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster() + new MapOutputTrackerMaster(conf) } else { - new MapOutputTracker() + new MapOutputTracker(conf) } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", @@ -193,12 +198,12 @@ object SparkEnv extends Logging { val httpFileServer = new HttpFileServer() httpFileServer.initialize() - System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + conf.set("spark.fileserver.uri", httpFileServer.serverUri) val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver") + MetricsSystem.createMetricsSystem("driver", conf) } else { - MetricsSystem.createMetricsSystem("executor") + MetricsSystem.createMetricsSystem("executor", conf) } metricsSystem.start() @@ -212,7 +217,7 @@ object SparkEnv extends Logging { } // Warn about deprecated spark.cache.class property - if (System.getProperty("spark.cache.class") != null) { + if 
(conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + "levels using the RDD.persist() method instead.") } @@ -231,6 +236,7 @@ object SparkEnv extends Logging { connectionManager, httpFileServer, sparkFilesDir, - metricsSystem) + metricsSystem, + conf) } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index c1e5e04b31..faf6dcd618 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure( private[spark] case object TaskResultLost extends TaskEndReason private[spark] case object TaskKilled extends TaskEndReason - -private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 363667fa86..55c87450ac 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -611,6 +611,42 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K * Return an RDD with the values of each tuple. */ def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) + + /** + * Return approximate number of distinct values for each key in this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vise versa. Uses the provided + * Partitioner to partition the output RDD. 
+ */ + def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): JavaRDD[(K, Long)] = { + rdd.countApproxDistinctByKey(relativeSD, partitioner) + } + + /** + * Return approximate number of distinct values for each key this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vise versa. The default value of + * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism + * level. + */ + def countApproxDistinctByKey(relativeSD: Double = 0.05): JavaRDD[(K, Long)] = { + rdd.countApproxDistinctByKey(relativeSD) + } + + + /** + * Return approximate number of distinct values for each key in this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vise versa. HashPartitions the + * output RDD into numPartitions. + * + */ + def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaRDD[(K, Long)] = { + rdd.countApproxDistinctByKey(relativeSD, numPartitions) + } } object JavaPairRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index f344804b4c..924d8af060 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -444,4 +444,15 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] takeOrdered(num, comp) } + + /** + * Return approximate number of distinct elements in the RDD. 
+ * + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vise versa. The default value of + * relativeSD is 0.05. + */ + def countApproxDistinct(relativeSD: Double = 0.05): Long = rdd.countApproxDistinct(relativeSD) + } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index acf328aa6a..e93b10fd7e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -29,17 +29,22 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import com.google.common.base.Optional -import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext} +import org.apache.spark._ import org.apache.spark.SparkContext.IntAccumulatorParam import org.apache.spark.SparkContext.DoubleAccumulatorParam import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import scala.Tuple2 /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and * works with Java collections instead of Scala ones. */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { + /** + * @param conf a [[org.apache.spark.SparkConf]] object specifying Spark parameters + */ + def this(conf: SparkConf) = this(new SparkContext(conf)) /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). @@ -50,6 +55,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). 
* @param appName A name for your application, to display on the cluster web UI + * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters + */ + def this(master: String, appName: String, conf: SparkConf) = + this(conf.setMaster(master).setAppName(appName)) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI * @param sparkHome The SPARK_HOME directory on the slave nodes * @param jarFile JAR file to send to the cluster. This can be a path on the local file system * or an HDFS, HTTP, HTTPS, or FTP URL. @@ -381,20 +394,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. - */ - def setCheckpointDir(dir: String, useExisting: Boolean) { - sc.setCheckpointDir(dir, useExisting) - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists, an exception will be thrown to prevent accidental - * overriding of checkpoint files. + * be a HDFS path if running on a cluster. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) @@ -405,10 +405,36 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] new JavaRDD(sc.checkpointFile(path)) } + + /** + * Return a copy of this JavaSparkContext's configuration. 
The configuration ''cannot'' be + * changed at runtime. + */ + def getConf: SparkConf = sc.getConf + + /** + * Pass-through to SparkContext.setCallSite. For API support only. + */ + def setCallSite(site: String) { + sc.setCallSite(site) + } + + /** + * Pass-through to SparkContext.setCallSite. For API support only. + */ + def clearCallSite() { + sc.clearCallSite() + } } object JavaSparkContext { implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc) implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc + + /** + * Find the JAR from which a given class was loaded, to make it easy for users to pass + * their JARs to SparkContext. + */ + def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ca42c76928..32cc70e8c9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ private[spark] class PythonRDD[T: ClassTag]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val bufferSize = conf.get("spark.buffer.size", "65536").toInt override def getPartitions = parent.partitions @@ -250,7 +250,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Utils.checkHost(serverHost, "Expected hostname") - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val bufferSize = SparkEnv.get.conf.get("spark.buffer.size", "65536").toInt override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 43c18294c5..0fc478a419 100644 --- 
a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -31,8 +31,8 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { override def toString = "Broadcast(" + id + ")" } -private[spark] -class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { +private[spark] +class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -43,14 +43,14 @@ class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable private def initialize() { synchronized { if (!initialized) { - val broadcastFactoryClass = System.getProperty( + val broadcastFactoryClass = conf.get( "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver) + broadcastFactory.initialize(isDriver, conf) initialized = true } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 68bff75b90..fb161ce69d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import org.apache.spark.SparkConf + /** * An interface for all the broadcast implementations in Spark (to allow * multiple broadcast implementations). SparkContext uses a user-specified @@ -24,7 +26,7 @@ package org.apache.spark.broadcast * entire Spark job. 
*/ private[spark] trait BroadcastFactory { - def initialize(isDriver: Boolean): Unit + def initialize(isDriver: Boolean, conf: SparkConf): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 47db720416..db596d5fcc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -24,14 +24,14 @@ import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.{HttpServer, Logging, SparkEnv} +import org.apache.spark.{SparkConf, HttpServer, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - + def value = value_ def blockId = BroadcastBlockId(id) @@ -40,7 +40,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) } - if (!isLocal) { + if (!isLocal) { HttpBroadcast.write(id, value_) } @@ -64,7 +64,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } + def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -81,44 +81,51 
@@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] - private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private var cleaner: MetadataCleaner = null - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private lazy val compressionCodec = CompressionCodec.createCodec() + private var compressionCodec: CompressionCodec = null - def initialize(isDriver: Boolean) { + def initialize(isDriver: Boolean, conf: SparkConf) { synchronized { if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - compress = System.getProperty("spark.broadcast.compress", "true").toBoolean + bufferSize = conf.get("spark.buffer.size", "65536").toInt + compress = conf.get("spark.broadcast.compress", "true").toBoolean if (isDriver) { - createServer() + createServer(conf) + conf.set("spark.httpBroadcast.uri", serverUri) } - serverUri = System.getProperty("spark.httpBroadcast.uri") + serverUri = conf.get("spark.httpBroadcast.uri") + cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) + compressionCodec = CompressionCodec.createCodec(conf) initialized = true } } } - + def stop() { synchronized { if (server != null) { server.stop() server = null } + if (cleaner != null) { + cleaner.cancel() + cleaner = null + } + compressionCodec = null initialized = false - cleaner.cancel() } } - private def createServer() { - broadcastDir = Utils.createTempDir(Utils.getLocalDir) + private def createServer(conf: SparkConf) { + broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri - 
System.setProperty("spark.httpBroadcast.uri", serverUri) logInfo("Broadcast server started at " + serverUri) } @@ -143,7 +150,7 @@ private object HttpBroadcast extends Logging { val in = { val httpConnection = new URL(url).openConnection() httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream() + val inputStream = httpConnection.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 073a0a5029..9530938278 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -83,13 +83,13 @@ extends Broadcast[T](id) with Logging with Serializable { case None => val start = System.nanoTime logInfo("Started reading broadcast variable " + id) - + // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() if (receiveBroadcast(id)) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. 
@@ -122,14 +122,14 @@ extends Broadcast[T](id) with Logging with Serializable { while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => + case Some(x) => val tInfo = x.asInstanceOf[TorrentInfo] totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes arrayOfBlocks = new Array[TorrentBlock](totalBlocks) hasBlocks = 0 - - case None => + + case None => Thread.sleep(500) } } @@ -145,13 +145,13 @@ extends Broadcast[T](id) with Logging with Serializable { val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => + case Some(x) => arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) - - case None => + + case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) } } @@ -166,21 +166,22 @@ private object TorrentBroadcast extends Logging { private var initialized = false - - def initialize(_isDriver: Boolean) { + private var conf: SparkConf = null + def initialize(_isDriver: Boolean, conf: SparkConf) { + TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests synchronized { if (!initialized) { initialized = true } } } - + def stop() { initialized = false } - val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024 - + lazy val BLOCK_SIZE = conf.get("spark.broadcast.blockSize", "4096").toInt * 1024 + def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) @@ -209,7 +210,7 @@ extends Logging { } def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, + totalBytes: Int, totalBlocks: Int): T = { var retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { @@ -222,23 +223,23 @@ 
extends Logging { } private[spark] case class TorrentBlock( - blockID: Int, - byteArray: Array[Byte]) + blockID: Int, + byteArray: Array[Byte]) extends Serializable private[spark] case class TorrentInfo( @transient arrayOfBlocks : Array[TorrentBlock], - totalBlocks: Int, - totalBytes: Int) + totalBlocks: Int, + totalBytes: Int) extends Serializable { - - @transient var hasBlocks = 0 + + @transient var hasBlocks = 0 } private[spark] class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) } + + def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 0aa8852649..4dfb19ed8a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -190,7 +190,7 @@ private[spark] object FaultToleranceTest extends App with Logging { /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ def createClient() = { if (sc != null) { sc.stop() } - // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this + // Counter-hack: Because of a hack in SparkEnv#create() that changes this // property, we need to reset it. System.setProperty("spark.driver.port", "0") sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) @@ -417,4 +417,4 @@ private[spark] object Docker extends Logging { "docker ps -l -q".!(ProcessLogger(line => id = line)) new DockerId(id) } -}
\ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 59d12a3e6f..ffc0cb0903 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -22,7 +22,7 @@ import akka.actor.ActorSystem import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master import org.apache.spark.util.Utils -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import scala.collection.mutable.ArrayBuffer @@ -43,7 +43,8 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0) + val conf = new SparkConf(false) + val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) @@ -55,7 +56,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I workerActorSystems += workerSystem } - return masters + masters } def stop() { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index fc1537f796..27dc42bf7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -34,10 +34,10 @@ class SparkHadoopUtil { UserGroupInformation.setConfiguration(conf) def runAsUser(user: String)(func: () => Unit) { - // if we are already running as the user intended there is no reason to do the doAs. 
It + // if we are already running as the user intended there is no reason to do the doAs. It // will actually break secure HDFS access as it doesn't fill in the credentials. Also if - // the user is UNKNOWN then we shouldn't be creating a remote unknown user - // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only + // the user is UNKNOWN then we shouldn't be creating a remote unknown user + // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only // in SparkContext. val currentUser = Option(System.getProperty("user.name")). getOrElse(SparkContext.SPARK_UNKNOWN_USER) @@ -67,11 +67,15 @@ class SparkHadoopUtil { } object SparkHadoopUtil { + private val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] + Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .newInstance() + .asInstanceOf[SparkHadoopUtil] } catch { case th: Throwable => throw new SparkException("Unable to load YARN support", th) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala index 953755e40d..481026eaa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala @@ -19,20 +19,19 @@ package org.apache.spark.deploy.client import java.util.concurrent.TimeoutException -import scala.concurrent.duration._ import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ import akka.pattern.ask -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} +import akka.remote.{AssociationErrorEvent, 
DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.util.AkkaUtils - /** * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, * and a listener for cluster events, and calls back the listener when various events occur. @@ -43,7 +42,8 @@ private[spark] class Client( actorSystem: ActorSystem, masterUrls: Array[String], appDescription: ApplicationDescription, - listener: ClientListener) + listener: ClientListener, + conf: SparkConf) extends Logging { val REGISTRATION_TIMEOUT = 20.seconds @@ -111,6 +111,12 @@ private[spark] class Client( } } + private def isPossibleMaster(remoteUrl: Address) = { + masterUrls.map(s => Master.toAkkaUrl(s)) + .map(u => AddressFromURIString(u).hostPort) + .contains(remoteUrl.hostPort) + } + override def receive = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ @@ -146,6 +152,9 @@ private[spark] class Client( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + logWarning(s"Could not connect to $address: $cause") + case StopClient => markDead() sender ! 
true @@ -178,7 +187,7 @@ private[spark] class Client( def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout + val timeout = AkkaUtils.askTimeout(conf) val future = actor.ask(StopClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 5b62d3ba6c..ef649fd80c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.client import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.{Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { @@ -45,11 +45,13 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) + val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + conf = new SparkConf) val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), + "dummy-spark-home", "ignored") val listener = new TestListener - val client = new Client(actorSystem, Array(url), desc, listener) + val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index eebd0794b8..7b696cfcca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -29,7 +29,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ @@ -38,14 +38,16 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { - import context.dispatcher + import context.dispatcher // to use Akka's scheduler.schedule() + + val conf = new SparkConf val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt - val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "") - val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE") + val WORKER_TIMEOUT = conf.get("spark.worker.timeout", "60").toLong * 1000 + val RETAINED_APPLICATIONS = conf.get("spark.deploy.retainedApplications", "200").toInt + val REAPER_ITERATIONS = conf.get("spark.dead.worker.persistence", "15").toInt + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") + val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] @@ -63,8 +65,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = 
MetricsSystem.createMetricsSystem("master") - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications") + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf) + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf) val masterSource = new MasterSource(this) val webUi = new MasterWebUI(this, webUiPort) @@ -86,7 +88,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean + val spreadOutApps = conf.get("spark.deploy.spreadOut", "true").toBoolean override def preStart() { logInfo("Starting Spark master at " + masterUrl) @@ -103,7 +105,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act persistenceEngine = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - new ZooKeeperPersistenceEngine(SerializationExtension(context.system)) + new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf) case "FILESYSTEM" => logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system)) @@ -113,7 +115,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act leaderElectionAgent = RECOVERY_MODE match { case "ZOOKEEPER" => - context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl)) + context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf)) case _ => context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) } @@ -495,7 +497,7 @@ private[spark] class Master(host: String, 
port: Int, webUiPort: Int) extends Act removeWorker(worker) } else { if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) - workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it + workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } } @@ -507,8 +509,9 @@ private[spark] object Master { val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r def main(argStrings: Array[String]) { - val args = new MasterArguments(argStrings) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort) + val conf = new SparkConf + val args = new MasterArguments(argStrings, conf) + val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -522,10 +525,12 @@ private[spark] object Master { } } - def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) + : (ActorSystem, Int, Int) = + { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) - val timeout = AkkaUtils.askTimeout + val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] (actorSystem, boundPort, resp.webUIBoundPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 9d89b455fb..e7f3224091 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -18,16 +18,17 @@ 
package org.apache.spark.deploy.master import org.apache.spark.util.{Utils, IntParam} +import org.apache.spark.SparkConf /** * Command-line parser for the master. */ -private[spark] class MasterArguments(args: Array[String]) { +private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 - - // Check for settings in environment variables + + // Check for settings in environment variables if (System.getenv("SPARK_MASTER_HOST") != null) { host = System.getenv("SPARK_MASTER_HOST") } @@ -37,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String]) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } - if (System.getProperty("master.ui.port") != null) { - webUiPort = System.getProperty("master.ui.port").toInt + if (conf.contains("master.ui.port")) { + webUiPort = conf.get("master.ui.port").toInt } parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala index 6cc7fd2ff4..999090ad74 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala @@ -23,7 +23,7 @@ import org.apache.zookeeper._ import org.apache.zookeeper.Watcher.Event.KeeperState import org.apache.zookeeper.data.Stat -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} /** * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry @@ -35,8 +35,9 @@ import org.apache.spark.Logging * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many * times or a semantic exception is thrown (e.g., "node already exists"). 
*/ -private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging { - val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "") +private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher, + conf: SparkConf) extends Logging { + val ZK_URL = conf.get("spark.deploy.zookeeper.url", "") val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE val ZK_TIMEOUT_MILLIS = 30000 diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 7d535b08de..77c23fb9fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -21,16 +21,17 @@ import akka.actor.ActorRef import org.apache.zookeeper._ import org.apache.zookeeper.Watcher.Event.EventType -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.master.MasterMessages._ -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String) +private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, + masterUrl: String, conf: SparkConf) extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging { - val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" + val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" private val watcher = new ZooKeeperWatcher() - private val zk = new SparkZooKeeperSession(this) + private val zk = new SparkZooKeeperSession(this, conf) private var status = LeadershipStatus.NOT_LEADER private var myLeaderFile: String = _ private var leaderUrl: String = _ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala 
b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 825344b3bb..52000d4f9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,19 +17,19 @@ package org.apache.spark.deploy.master -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.zookeeper._ import akka.serialization.Serialization -class ZooKeeperPersistenceEngine(serialization: Serialization) +class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) extends PersistenceEngine with SparkZooKeeperWatcher with Logging { - val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" - val zk = new SparkZooKeeperSession(this) + val zk = new SparkZooKeeperSession(this, conf) zk.connect() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 9ab594b682..ead35662fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { - val timeout = AkkaUtils.askTimeout + val timeout = AkkaUtils.askTimeout(master.conf) val host = Utils.localHostName() val port = requestedPort diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 87531b6719..fcaf4e92b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -25,23 
+25,14 @@ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import akka.actor._ -import akka.remote.{ DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.{SparkException, Logging} +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.deploy.DeployMessages.WorkerStateResponse -import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed -import org.apache.spark.deploy.DeployMessages.KillExecutor -import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.deploy.DeployMessages.Heartbeat -import org.apache.spark.deploy.DeployMessages.RegisteredWorker -import org.apache.spark.deploy.DeployMessages.LaunchExecutor -import org.apache.spark.deploy.DeployMessages.RegisterWorker +import org.apache.spark.util.{AkkaUtils, Utils} /** * @param masterUrls Each url should look like spark://host:port. 
@@ -53,7 +44,8 @@ private[spark] class Worker( cores: Int, memory: Int, masterUrls: Array[String], - workDirPath: String = null) + workDirPath: String = null, + val conf: SparkConf) extends Actor with Logging { import context.dispatcher @@ -63,7 +55,7 @@ private[spark] class Worker( val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 + val HEARTBEAT_MILLIS = conf.get("spark.worker.timeout", "60").toLong * 1000 / 4 val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -92,7 +84,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker") + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf) val workerSource = new WorkerSource(this) def coresFree: Int = cores - coresUsed @@ -275,6 +267,7 @@ private[spark] class Worker( } private[spark] object Worker { + def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, @@ -283,13 +276,16 @@ private[spark] object Worker { } def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None) - : (ActorSystem, Int) = { + masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None) + : (ActorSystem, Int) = + { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, + conf = conf) 
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, workDir), name = "Worker") + masterUrls, workDir, conf), name = "Worker") (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 40d6bdb3fd..c382034c99 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -22,7 +22,7 @@ import java.io.File import javax.servlet.http.HttpServletRequest import org.eclipse.jetty.server.{Handler, Server} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.ui.{JettyUtils, UIUtils} import org.apache.spark.ui.JettyUtils._ @@ -34,10 +34,10 @@ import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) extends Logging { - val timeout = AkkaUtils.askTimeout + val timeout = AkkaUtils.askTimeout(worker.conf) val host = Utils.localHostName() val port = requestedPort.getOrElse( - System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) + worker.conf.get("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) var server: Option[Server] = None var boundPort: Option[Int] = None @@ -140,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I <body> {linkToMaster} <div> - <div style="float:left;width:40%">{backButton}</div> + <div style="float:left; margin-right:10px">{backButton}</div> <div style="float:left;">{range}</div> - <div style="float:right;">{nextButton}</div> + <div style="float:right; margin-left:10px">{nextButton}</div> </div> <br /> - <div style="height:500px;overflow:auto;padding:5px;"> + <div style="height:500px; overflow:auto; padding:5px;"> 
<pre>{logText}</pre> </div> </body> diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index debbdd4c44..53a2b94a52 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} @@ -98,10 +98,10 @@ private[spark] object CoarseGrainedExecutorBackend { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true) + indestructible = true, conf = new SparkConf) // set it val sparkHostPort = hostname + ":" + boundPort - System.setProperty("spark.hostPort", sparkHostPort) +// conf.set("spark.hostPort", sparkHostPort) actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 0f19d7a96b..e51d274d33 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -48,8 +48,6 @@ private[spark] class Executor( private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) - initLogging() - // No ip or host:port - just hostname Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") 
// must not have port specified. @@ -58,16 +56,17 @@ private[spark] class Executor( // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) - // Set spark.* system properties from executor arg - for ((key, value) <- properties) { - System.setProperty(key, value) - } + // Set spark.* properties from executor arg + val conf = new SparkConf(false) + conf.setAll(properties) // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. This will be used later when SparkEnv // created. - if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) { - System.setProperty("spark.local.dir", getYarnLocalDirs()) + if (java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) + { + conf.set("spark.local.dir", getYarnLocalDirs()) } // Create our ClassLoader and set it on this thread @@ -108,7 +107,7 @@ private[spark] class Executor( // Initialize Spark environment (using system properties read above) private val env = { if (!isLocal) { - val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, + val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, isDriver = false, isLocal = false) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) @@ -142,11 +141,6 @@ private[spark] class Executor( val tr = runningTasks.get(taskId) if (tr != null) { tr.kill() - // We remove the task also in the finally block in TaskRunner.run. - // The reason we need to remove it here is because killTask might be called before the task - // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is - // idempotent. 
- runningTasks.remove(taskId) } } @@ -168,6 +162,8 @@ private[spark] class Executor( class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { + object TaskKilledException extends Exception + @volatile private var killed = false @volatile private var task: Task[Any] = _ @@ -201,9 +197,11 @@ private[spark] class Executor( // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. if (killed) { - logInfo("Executor killed task " + taskId) - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - return + // Throw an exception rather than returning, because returning within a try{} block + // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl + // exception will be caught by the catch block, leading to an incorrect ExceptionFailure + // for the task. + throw TaskKilledException } attemptedTask = Some(task) @@ -217,9 +215,7 @@ private[spark] class Executor( // If the task has been killed, let's fail it. 
if (task.killed) { - logInfo("Executor killed task " + taskId) - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - return + throw TaskKilledException } val resultSer = SparkEnv.get.serializer.newInstance() @@ -261,6 +257,11 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) } + case TaskKilledException => { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + } + case t: Throwable => { val serviceTime = (System.currentTimeMillis() - taskStart).toInt val metrics = attemptedTask.flatMap(t => t.metrics) @@ -303,7 +304,7 @@ private[spark] class Executor( * new classes defined by the REPL as the user types code */ private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { - val classUri = System.getProperty("spark.repl.class.uri") + val classUri = conf.get("spark.repl.class.uri", null) if (classUri != null) { logInfo("Using REPL class URI: " + classUri) try { @@ -331,12 +332,12 @@ private[spark] class Executor( // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 570a979b56..a1e98845f6 100644 --- 
a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -22,6 +22,7 @@ import java.io.{InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} +import org.apache.spark.{SparkEnv, SparkConf} /** @@ -37,15 +38,15 @@ trait CompressionCodec { private[spark] object CompressionCodec { - - def createCodec(): CompressionCodec = { - createCodec(System.getProperty( + def createCodec(conf: SparkConf): CompressionCodec = { + createCodec(conf, conf.get( "spark.io.compression.codec", classOf[LZFCompressionCodec].getName)) } - def createCodec(codecName: String): CompressionCodec = { - Class.forName(codecName, true, Thread.currentThread.getContextClassLoader) - .newInstance().asInstanceOf[CompressionCodec] + def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { + val ctor = Class.forName(codecName, true, Thread.currentThread.getContextClassLoader) + .getConstructor(classOf[SparkConf]) + ctor.newInstance(conf).asInstanceOf[CompressionCodec] } } @@ -53,7 +54,7 @@ private[spark] object CompressionCodec { /** * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. */ -class LZFCompressionCodec extends CompressionCodec { +class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { new LZFOutputStream(s).setFinishBlockOnFlush(true) @@ -67,10 +68,10 @@ class LZFCompressionCodec extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by spark.io.compression.snappy.block.size. 
*/ -class SnappyCompressionCodec extends CompressionCodec { +class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt + val blockSize = conf.get("spark.io.compression.snappy.block.size", "32768").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index caab748d60..6f9f29969e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -26,7 +26,6 @@ import scala.util.matching.Regex import org.apache.spark.Logging private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - initLogging() val DEFAULT_PREFIX = "*" val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index bec0c83be8..9930537b34 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source @@ -62,10 +62,10 @@ import org.apache.spark.metrics.source.Source * * [options] is the specific property of this source or sink. 
*/ -private[spark] class MetricsSystem private (val instance: String) extends Logging { - initLogging() +private[spark] class MetricsSystem private (val instance: String, + conf: SparkConf) extends Logging { - val confFile = System.getProperty("spark.metrics.conf") + val confFile = conf.get("spark.metrics.conf", null) val metricsConfig = new MetricsConfig(Option(confFile)) val sinks = new mutable.ArrayBuffer[Sink] @@ -159,5 +159,6 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance) + def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = + new MetricsSystem(instance, conf) } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 703bc6a9ca..46c40d0a2a 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -37,7 +37,7 @@ import scala.concurrent.duration._ import org.apache.spark.util.Utils -private[spark] class ConnectionManager(port: Int) extends Logging { +private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging { class MessageStatus( val message: Message, @@ -54,22 +54,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.handler.threads.min","20").toInt, - System.getProperty("spark.core.connection.handler.threads.max","60").toInt, - System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + conf.get("spark.core.connection.handler.threads.min", "20").toInt, + conf.get("spark.core.connection.handler.threads.max", "60").toInt, + conf.get("spark.core.connection.handler.threads.keepalive", "60").toInt, 
TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val handleReadWriteExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.io.threads.min","4").toInt, - System.getProperty("spark.core.connection.io.threads.max","32").toInt, - System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + conf.get("spark.core.connection.io.threads.min", "4").toInt, + conf.get("spark.core.connection.io.threads.max", "32").toInt, + conf.get("spark.core.connection.io.threads.keepalive", "60").toInt, TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap private val handleConnectExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.connect.threads.min","1").toInt, - System.getProperty("spark.core.connection.connect.threads.max","8").toInt, - System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + conf.get("spark.core.connection.connect.threads.min", "1").toInt, + conf.get("spark.core.connection.connect.threads.max", "8").toInt, + conf.get("spark.core.connection.connect.threads.keepalive", "60").toInt, TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() @@ -594,7 +594,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999) + val manager = new ConnectionManager(9999, new SparkConf) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 781715108b..1c9d6030d6 100644 --- 
a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -19,19 +19,19 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress +import org.apache.spark.SparkConf private[spark] object ReceiverTest { - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) + val manager = new ConnectionManager(9999, new SparkConf) println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ - val buffer = ByteBuffer.wrap("response".getBytes()) + val buffer = ByteBuffer.wrap("response".getBytes) Some(Message.createBufferMessage(buffer, msg.id)) }) - Thread.currentThread.join() + Thread.currentThread.join() } } diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 777574980f..dcbd183c88 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -19,29 +19,29 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress +import org.apache.spark.SparkConf private[spark] object SenderTest { - def main(args: Array[String]) { - + if (args.length < 2) { println("Usage: SenderTest <target host> <target port>") System.exit(1) } - + val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val manager = new ConnectionManager(0) + val manager = new ConnectionManager(0, new SparkConf) println("Started connection manager with id = " + manager.id) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + 
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - - val size = 100 * 1024 * 1024 + + val size = 100 * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -50,7 +50,7 @@ private[spark] object SenderTest { val count = 100 (0 until count).foreach(i => { val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis + val startTime = System.currentTimeMillis /*println("Started timer at " + startTime)*/ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { case Some(response) => diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index b1e1576dad..b729eb11c5 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -23,20 +23,20 @@ import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import io.netty.util.CharsetUtil -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.network.ConnectionManagerId import scala.collection.JavaConverters._ import org.apache.spark.storage.BlockId -private[spark] class ShuffleCopier extends Logging { +private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { def getBlock(host: String, port: Int, blockId: BlockId, resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val connectTimeout = conf.get("spark.shuffle.netty.connect.timeout", "60000").toInt val fc = new FileClient(handler, connectTimeout) try { @@ -104,10 +104,10 @@ 
private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { + val tasks = (for (i <- Range(0, threads)) yield { Executors.callable(new Runnable() { def run() { - val copier = new ShuffleCopier() + val copier = new ShuffleCopier(new SparkConf) copier.getBlock(host, port, blockId, echoResultCollectCallBack) } }) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index a712ef1c27..6d4f46125f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -18,12 +18,12 @@ package org.apache.spark.rdd import java.io.IOException - import scala.reflect.ClassTag - -import org.apache.hadoop.fs.Path import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -34,6 +34,8 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { + val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) override def getPartitions: Array[Partition] = { @@ -65,7 +67,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) override def compute(split: Partition, context: TaskContext): Iterator[T] = { val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, context) + CheckpointRDD.readFromFile(file, broadcastedConf, context) } override def checkpoint() { @@ -74,15 +76,18 @@ class CheckpointRDD[T: ClassTag](sc: 
SparkContext, val checkpointPath: String) } private[spark] object CheckpointRDD extends Logging { - def splitIdToFile(splitId: Int): String = { "part-%05d".format(splitId) } - def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { + def writeToFile[T]( + path: String, + broadcastedConf: Broadcast[SerializableWritable[Configuration]], + blockSize: Int = -1 + )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get val outputDir = new Path(path) - val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration()) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -92,7 +97,7 @@ private[spark] object CheckpointRDD extends Logging { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") } - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt val fileOutputStream = if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) @@ -119,10 +124,14 @@ private[spark] object CheckpointRDD extends Logging { } } - def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + def readFromFile[T]( + path: Path, + broadcastedConf: Broadcast[SerializableWritable[Configuration]], + context: TaskContext + ): Iterator[T] = { val env = SparkEnv.get - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fs = path.getFileSystem(broadcastedConf.value.value) + val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) @@ -144,8 +153,10 @@ private[spark] object CheckpointRDD extends Logging { val sc = 
new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) + val conf = SparkHadoopUtil.get.newConfiguration() + val fs = path.getFileSystem(conf) + val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 911a002884..4ba4696fef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -114,7 +114,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: map.changeValue(k, update) } - val ser = SparkEnv.get.serializerManager.get(serializerClass) + val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 48168e152e..04a8d05988 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -40,12 +40,15 @@ import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.{RecordWriter => 
NewRecordWriter} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.Aggregator import org.apache.spark.Partitioner import org.apache.spark.Partitioner.defaultPartitioner +import org.apache.spark.util.SerializableHyperLogLog /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -208,6 +211,45 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** + * Return approximate number of distinct values for each key in this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vice versa. Uses the provided + * Partitioner to partition the output RDD. + */ + def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = { + val createHLL = (v: V) => new SerializableHyperLogLog(new HyperLogLog(relativeSD)).add(v) + val mergeValueHLL = (hll: SerializableHyperLogLog, v: V) => hll.add(v) + val mergeHLL = (h1: SerializableHyperLogLog, h2: SerializableHyperLogLog) => h1.merge(h2) + + combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.value.cardinality()) + } + + /** + * Return approximate number of distinct values for each key in this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vice versa. HashPartitions the + * output RDD into numPartitions.
+ * + */ + def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): RDD[(K, Long)] = { + countApproxDistinctByKey(relativeSD, new HashPartitioner(numPartitions)) + } + + /** + * Return approximate number of distinct values for each key in this RDD. + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vice versa. The default value of + * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism + * level. + */ + def countApproxDistinctByKey(relativeSD: Double = 0.05): RDD[(K, Long)] = { + countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) + } + + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala new file mode 100644 index 0000000000..4c625d062e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag +import java.io.{ObjectOutputStream, IOException} +import org.apache.spark.{TaskContext, OneToOneDependency, SparkContext, Partition} + + +/** + * Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of corresponding partitions + * of parent RDDs. + */ +private[spark] +class PartitionerAwareUnionRDDPartition( + @transient val rdds: Seq[RDD[_]], + val idx: Int + ) extends Partition { + var parents = rdds.map(_.partitions(idx)).toArray + + override val index = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent partition at the time of task serialization + parents = rdds.map(_.partitions(index)).toArray + oos.defaultWriteObject() + } +} + +/** + * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and + * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each + * will be unified to a single RDD with p partitions and the same partitioner. The preferred + * location for each partition of the unified RDD will be the most common preferred location + * of the corresponding partitions of the parent RDDs. For example, location of partition 0 + * of the unified RDD will be where most of partition 0 of the parent RDDs are located. 
+ */ +private[spark] +class PartitionerAwareUnionRDD[T: ClassTag]( + sc: SparkContext, + var rdds: Seq[RDD[T]] + ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { + require(rdds.length > 0) + require(rdds.flatMap(_.partitioner).toSet.size == 1, + "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) + + override val partitioner = rdds.head.partitioner + + override def getPartitions: Array[Partition] = { + val numPartitions = partitioner.get.numPartitions + (0 until numPartitions).map(index => { + new PartitionerAwareUnionRDDPartition(rdds, index) + }).toArray + } + + // Get the location where most of the partitions of parent RDDs are located + override def getPreferredLocations(s: Partition): Seq[String] = { + logDebug("Finding preferred location for " + this + ", partition " + s.index) + val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents + val locations = rdds.zip(parentPartitions).flatMap { + case (rdd, part) => { + val parentLocations = currPrefLocs(rdd, part) + logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations) + parentLocations + } + } + val location = if (locations.isEmpty) { + None + } else { + // Find the location that maximum number of parent partitions prefer + Some(locations.groupBy(x => x).maxBy(_._2.length)._1) + } + logDebug("Selected location for " + this + ", partition " + s.index + " = " + location) + location.toSeq + } + + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents + rdds.zip(parentPartitions).iterator.flatMap { + case (rdd, p) => rdd.iterator(p, context) + } + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } + + // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones) + private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = { + 
rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ea45566ad1..3f41b66279 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.Partitioner._ import org.apache.spark.api.java.JavaRDD @@ -41,7 +42,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{Utils, BoundedPriorityQueue} +import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog} import org.apache.spark.SparkContext._ import org.apache.spark._ @@ -81,6 +82,7 @@ abstract class RDD[T: ClassTag]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) + private[spark] def conf = sc.conf // ======================================================================= // Methods that should be implemented by subclasses of RDD // ======================================================================= @@ -789,6 +791,19 @@ abstract class RDD[T: ClassTag]( } /** + * Return approximate number of distinct elements in the RDD. + * + * The accuracy of approximation can be controlled through the relative standard deviation + * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in + * more accurate counts but increase the memory footprint and vice versa. The default value of + * relativeSD is 0.05.
+ */ + def countApproxDistinct(relativeSD: Double = 0.05): Long = { + val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD)) + aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality() + } + + /** * Take the first num elements of the RDD. It works by first scanning one partition, and use the * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. @@ -938,7 +953,7 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - @transient private[spark] val origin = Utils.formatSparkCallSite + @transient private[spark] val origin = sc.getCallSite private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 3b56e45aa9..bc688110f4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Partition, SparkException, Logging} +import org.apache.spark.{SerializableWritable, Partition, SparkException, Logging} import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -40,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration { * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations * of the checkpointed RDD. 
*/ -private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) +private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ @@ -85,14 +85,21 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) // Create the output path for the checkpoint val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) } // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val broadcastedConf = rdd.context.broadcast( + new SerializableWritable(rdd.context.hadoopConfiguration)) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + if (newRDD.partitions.size != rdd.partitions.size) { + throw new SparkException( + "Checkpoint RDD " + newRDD + "("+ newRDD.partitions.size + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + } // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { @@ -101,8 +108,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } + logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } // Get preferred location of a split after checkpointing diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 3682c84598..0ccb309d0d 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -59,7 +59,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, - SparkEnv.get.serializerManager.get(serializerClass)) + SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index aab30b1bb4..4f90c7d3d6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -93,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass) + val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 963d15b76d..043e01dbfb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -152,13 +152,15 @@ class DAGScheduler( val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - 
val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf) /** * Starts the event processing actor. The actor has two responsibilities: @@ -239,7 +241,8 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) + val stage = + newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -248,7 +251,8 @@ class DAGScheduler( /** * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage + * directly. 
*/ private def newStage( rdd: RDD[_], @@ -358,7 +362,8 @@ class DAGScheduler( stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id val parents = getParentStages(s.rdd, jobId) - val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + val parentsWithoutThisJobId = parents.filter(p => + !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } @@ -366,8 +371,9 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that - * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + * Removes job and any stages that are not needed by any other job. Returns the set of ids for + * stages that were removed. The associated tasks for those stages need to be cancelled if we + * got here via job cancellation. 
*/ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) @@ -378,7 +384,8 @@ class DAGScheduler( stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { case (stageId, jobSet) => if (!jobSet.contains(jobId)) { - logError("Job %d not registered for stage %d even though that stage was registered for the job" + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" .format(jobId, stageId)) } else { def removeStage(stageId: Int) { @@ -389,7 +396,8 @@ class DAGScheduler( running -= s } stageToInfos -= s - shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId => + shuffleToMapStage.remove(shuffleId)) if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { logDebug("Removing pending status for stage %d".format(stageId)) } @@ -407,7 +415,8 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId - logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) } jobSet -= jobId @@ -459,7 +468,8 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -494,7 +504,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventProcessActor ! 
JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) + eventProcessActor ! JobSubmitted( + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -529,8 +540,8 @@ class DAGScheduler( case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => var finalStage: Stage = null try { - // New stage creation at times and if its not protected, the scheduler thread is killed. - // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD + // whose underlying HDFS files have been deleted. finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) } catch { case e: Exception => @@ -563,7 +574,8 @@ class DAGScheduler( case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val activeInGroup = activeJobs.filter(activeJob => + groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } @@ -585,7 +597,8 @@ class DAGScheduler( stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage) ) { - if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && + !stageInfo.emittedTaskSizeWarning) { stageInfo.emittedTaskSizeWarning = true logWarning(("Stage %d (%s) contains a task of very large " + "size (%d KB). 
The maximum recommended task size is %d KB.").format( @@ -815,7 +828,7 @@ class DAGScheduler( } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) - listenerBus.post(StageCompleted(stageToInfos(stage))) + listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) running -= stage } event.reason match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 5077b2b48b..2bc43a9186 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.executor.ExecutorExitCode diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 1791ee660d..90eb8a747f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -32,7 +32,7 @@ import scala.collection.JavaConversions._ /** * Parses and holds information about inputFormat (and files) specified as a parameter. */ -class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], val path: String) extends Logging { var mapreduceInputFormat: Boolean = false @@ -40,7 +40,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl validate() - override def toString(): String = { + override def toString: String = { "InputFormatInfo " + super.toString + " .. 
inputFormatClazz " + inputFormatClazz + ", path : " + path } @@ -125,7 +125,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } private def findPreferredLocations(): Set[SplitInfo] = { - logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + ", inputFormatClazz : " + inputFormatClazz) if (mapreduceInputFormat) { return prefLocsFromMapreduceInputFormat() @@ -143,14 +143,14 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl object InputFormatInfo { /** Computes the preferred locations based on input(s) and returned a location to block map. - Typical use of this method for allocation would follow some algo like this - (which is what we currently do in YARN branch) : + Typical use of this method for allocation would follow some algo like this: + a) For each host, count number of splits hosted on that host. b) Decrement the currently allocated containers on that host. c) Compute rack info for each host and update rack -> count map based on (b). d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node - (even if data locality on that is very high) : this is to prevent fragility of job if a single + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single (or small set of) hosts go down. go to (a) until required nodes are allocated. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 60927831a1..f8fa5a9f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -297,7 +297,7 @@ class JobLogger(val user: String, val logDirName: String) * When stage is completed, record stage completion status * @param stageCompleted Stage completed event */ - override def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format( stageCompleted.stage.stageId)) } @@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String) task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) case _ => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 596f9adde9..1791242215 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -117,8 +117,4 @@ private[spark] class Pool( parent.decreaseRunningTasks(taskNum) } } - - override def hasPendingTasks(): Boolean = { - schedulableQueue.exists(_.hasPendingTasks()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 310ec62ca8..28f3ba53b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -32,7 +32,9 @@ 
private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues) + // TODO: This object shouldn't have global variables + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index 1c7ea2dccc..d573e125a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,5 +42,4 @@ private[spark] trait Schedulable { def executorLost(executorId: String, host: String): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] - def hasPendingTasks(): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 356fe56bf3..3cf995ea74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} import java.util.{NoSuchElementException, Properties} -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import scala.xml.XML @@ -49,10 +49,10 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) } } -private[spark] class FairSchedulableBuilder(val rootPool: Pool) +private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) 
extends SchedulableBuilder with Logging { - val schedulerAllocFile = Option(System.getProperty("spark.scheduler.allocation.file")) + val schedulerAllocFile = conf.getOption("spark.scheduler.allocation.file") val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml" val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool" val DEFAULT_POOL_NAME = "default" diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 5367218faa..02bdbba825 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.SparkContext /** - * A backend interface for cluster scheduling systems that allows plugging in different ones under + * A backend interface for scheduling systems that allows plugging in different ones under * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. 
*/ @@ -31,7 +31,4 @@ private[spark] trait SchedulerBackend { def defaultParallelism(): Int def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException - - // Memory used by each executor (in megabytes) - protected val executorMemory: Int = SparkContext.executorMemoryRequested } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 0f2deb4bcb..a37ead5632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -37,7 +37,9 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues) + // TODO: This object shouldn't have global variables + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -152,7 +154,7 @@ private[spark] class ShuffleMapTask( try { // Obtain all the block writers for shuffle blocks. - val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf) shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3841b5616d..627995c826 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -27,7 +27,7 @@ sealed trait SparkListenerEvents case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties) extends SparkListenerEvents -case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents +case class SparkListenerStageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents @@ -47,7 +47,7 @@ trait SparkListener { /** * Called when a stage is completed, with information on the completed stage */ - def onStageCompleted(stageCompleted: StageCompleted) { } + def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } /** * Called when a stage is submitted @@ -63,7 +63,7 @@ trait SparkListener { * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). 
*/ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends @@ -86,7 +86,7 @@ trait SparkListener { * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener extends SparkListener with Logging { - override def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { import org.apache.spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stage) @@ -119,20 +119,24 @@ object StatsReportListener extends Logging { val probabilities = percentiles.map{_ / 100.0} val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { + def extractDoubleDistribution(stage: SparkListenerStageCompleted, + getMetric: (TaskInfo,TaskMetrics) => Option[Double]) + : Option[Distribution] = { Distribution(stage.stage.taskInfos.flatMap { case ((info,metric)) => getMetric(info, metric)}) } //is there some way to setup the types that I can get rid of this completely? 
- def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = { + def extractLongDistribution(stage: SparkListenerStageCompleted, + getMetric: (TaskInfo,TaskMetrics) => Option[Long]) + : Option[Distribution] = { extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble}) } def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { val stats = d.statCounter - logInfo(heading + stats) val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(heading + stats) logInfo(percentilesHeader) logInfo("\t" + quantiles.mkString("\t")) } @@ -147,12 +151,12 @@ object StatsReportListener extends Logging { } def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showDistribution(heading, extractDoubleDistribution(stage, getMetric), format) } def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showBytesDistribution(heading, extractLongDistribution(stage, getMetric)) } @@ -169,12 +173,10 @@ object StatsReportListener extends Logging { } def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { + (implicit stage: SparkListenerStageCompleted) { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } - - val seconds = 1000L val minutes = seconds * 60 val hours = minutes * 60 @@ -198,7 +200,6 @@ object StatsReportListener extends Logging { } - case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index d5824e7954..e7defd768b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class SparkListenerBus() extends Logging { event match {
case stageSubmitted: SparkListenerStageSubmitted =>
sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
- case stageCompleted: StageCompleted =>
+ case stageCompleted: SparkListenerStageCompleted =>
sparkListeners.foreach(_.onStageCompleted(stageCompleted))
case jobStart: SparkListenerJobStart =>
sparkListeners.foreach(_.onJobStart(jobStart))
@@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging { return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index e68c527713..e22b1e53e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -15,23 +15,23 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit} import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. */ -private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends Logging { - private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt + + private val THREADS = sparkEnv.conf.get("spark.resultGetter.threads", "4").toInt private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( THREADS, "Result resolver thread") @@ -42,7 +42,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche } def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { getTaskResultExecutor.execute(new Runnable { override def run() { try { @@ -78,7 +78,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche }) } - def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState, + def 
enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason: Option[TaskEndReason] = None getTaskResultExecutor.execute(new Runnable { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 10e0478108..17b6d97e90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -20,11 +20,12 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. - * Each TaskScheduler schedulers task for a single SparkContext. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler. + * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler. + * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * for a single SparkContext. These schedulers get sets of tasks submitted to them from the + * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running + * them, retrying if there are failures, and mitigating stragglers. They return events to the + * DAGScheduler. 
*/ private[spark] trait TaskScheduler { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 66ab8ea4cd..0c8ed62759 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong @@ -28,36 +28,42 @@ import scala.concurrent.duration._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call - * initialize() and start(), then submit task sets through the runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. + * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. + * It can also work with a local setup by using a LocalBackend and setting isLocal to true. * It handles common logic, like determining a scheduling order across jobs, waking up to launch * speculative tasks, etc. * + * Clients should first call initialize() and start(), then submit task sets through the + * runTasks method. + * * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some * SchedulerBackends sycnchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. 
*/ -private[spark] class ClusterScheduler(val sc: SparkContext) - extends TaskScheduler - with Logging +private[spark] class TaskSchedulerImpl( + val sc: SparkContext, + val maxTaskFailures: Int, + isLocal: Boolean = false) + extends TaskScheduler with Logging { + def this(sc: SparkContext) = this(sc, sc.conf.get("spark.task.maxFailures", "4").toInt) + + val conf = sc.conf + // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + val SPECULATION_INTERVAL = conf.get("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + val STARVATION_TIMEOUT = conf.get("spark.starvation.timeout", "15000").toLong - // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized + // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. - val activeTaskSets = new HashMap[String, ClusterTaskSetManager] + val activeTaskSets = new HashMap[String, TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @@ -90,7 +96,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var rootPool: Pool = null // default scheduler is FIFO val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) + conf.get("spark.scheduler.mode", "FIFO")) // This is a var so that we can reset it for testing purposes. 
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -108,7 +114,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) case SchedulingMode.FIFO => new FIFOSchedulableBuilder(rootPool) case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) + new FairSchedulableBuilder(rootPool, conf) } } schedulableBuilder.buildPools() @@ -119,7 +125,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (System.getProperty("spark.speculation", "false").toBoolean) { + if (!isLocal && conf.get("spark.speculation", "false").toBoolean) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, @@ -133,12 +139,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new ClusterTaskSetManager(this, taskSet) + val manager = new TaskSetManager(this, taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() - if (!hasReceivedTask) { + if (!isLocal && !hasReceivedTask) { starvationTimer.scheduleAtFixedRate(new TimerTask() { override def run() { if (!hasLaunchedTask) { @@ -279,7 +285,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") + logInfo("Ignoring update with state %s from TID %s because its task set is gone" + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -292,19 +299,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + def 
handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } def handleSuccessfulTask( - taskSetManager: ClusterTaskSetManager, + taskSetManager: TaskSetManager, tid: Long, taskResult: DirectTaskResult[_]) = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: ClusterTaskSetManager, + taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, reason: Option[TaskEndReason]) = synchronized { @@ -322,7 +329,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { - manager.error(message) + manager.abort(message) } catch { case e: Exception => logError("Exception in error callback", e) } @@ -352,7 +359,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def defaultParallelism() = backend.defaultParallelism() - // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false @@ -364,13 +370,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - // Check for pending tasks in all our active jobs. - def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None @@ -429,7 +428,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } -object ClusterScheduler { +private[spark] object TaskSchedulerImpl { /** * Used to balance containers across hosts. 
* diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 90f6bcefac..6dd1469d8f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -17,32 +17,693 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer +import java.io.NotSerializableException +import java.util.Arrays +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock} + /** - * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of - * each task and is responsible for retries on failure and locality. The main interfaces to it - * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and - * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of + * each task, retries tasks if they fail (up to a limited number of times), and + * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces + * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, + * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * + * THREADING: This class is designed to only be called from code with a lock on the + * TaskScheduler (e.g. its event handlers). It should not be called from other threads. 
* - * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler - * (e.g. its event handlers). It should not be called from other threads. + * @param sched the ClusterScheduler associated with the TaskSetManager + * @param taskSet the TaskSet to manage scheduling for + * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * task set will be aborted */ -private[spark] trait TaskSetManager extends Schedulable { - def schedulableQueue = null - - def schedulingMode = SchedulingMode.NONE - - def taskSet: TaskSet +private[spark] class TaskSetManager( + sched: TaskSchedulerImpl, + val taskSet: TaskSet, + val maxTaskFailures: Int, + clock: Clock = SystemClock) + extends Schedulable with Logging +{ + val conf = sched.sc.conf + + // CPUs to request per task + val CPUS_PER_TASK = conf.get("spark.task.cpus", "1").toInt + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = conf.get("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = conf.get("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val successful = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksSuccessful = 0 + + var weight = 1 + var minShare = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent: Pool = null + + var runningTasks = 0 + private val runningTasksSet = new HashSet[Long] + + // Set of pending tasks for each executor. These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. 
This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each host. Similar to pendingTasksForExecutor, + // but at host level. + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // Set of pending tasks for each rack -- similar to the above. + private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + + // Set containing pending tasks with no locality preferences. + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // Set containing all pending tasks (also used as a stack, as above). + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + conf.get("spark.logging.exceptionPrintInterval", "10000").toLong + + // Map of recent exceptions (identified by string representation and top stack frame) to + // duplicate count (how many times the same exception has appeared) and time the full exception + // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
+ val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker epoch and set it on all tasks + val epoch = sched.mapOutputTracker.getEpoch + logDebug("Epoch for " + taskSet + ": " + epoch) + for (t <- tasks) { + t.epoch = epoch + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling + val myLocalityLevels = computeValidLocalityLevels() + val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + + // Delay scheduling variables: we keep track of our current locality level and the time we + // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. + // We then move down if we manage to launch a "more local" task. + var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + + override def schedulableQueue = null + + override def schedulingMode = SchedulingMode.NONE + + /** + * Add a task to all the pending-task lists that it should be on. If readding is set, we are + * re-adding the task so only include it in each list if it's not already there. 
+ */ + private def addPendingTask(index: Int, readding: Boolean = false) { + // Utility method that adds `index` to a list only if readding=false or it's not already there + def addTo(list: ArrayBuffer[Int]) { + if (!readding || !list.contains(index)) { + list += index + } + } + + var hadAliveLocations = false + for (loc <- tasks(index).preferredLocations) { + for (execId <- loc.executorId) { + if (sched.isExecutorAlive(execId)) { + addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + hadAliveLocations = true + } + } + if (sched.hasExecutorsAliveOnHost(loc.host)) { + addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + for (rack <- sched.getRackForHost(loc.host)) { + addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + } + hadAliveLocations = true + } + } + + if (!hadAliveLocations) { + // Even though the task might've had preferred locations, all of those hosts or executors + // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. 
+      addTo(pendingTasksWithNoPrefs)
+    }
+
+    if (!readding) {
+      allPendingTasks += index  // No point scanning this whole list to find the old task there
+    }
+  }
+
+  /**
+   * Return the pending tasks list for a given executor ID, or an empty list if
+   * there is no map entry for that executor
+   */
+  private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+    pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending tasks list for a given host, or an empty list if
+   * there is no map entry for that host
+   */
+  private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+    pendingTasksForHost.getOrElse(host, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending rack-local task list for a given rack, or an empty list if
+   * there is no map entry for that rack
+   */
+  private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+    pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+  }
+
+  /**
+   * Dequeue a pending task from the given list and return its index.
+   * Return None if the list is empty.
+   * This method also cleans up any tasks in the list that have already
+   * been launched, since we want that to happen lazily.
+   */
+  private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+    while (!list.isEmpty) {
+      val index = list.last
+      list.trimEnd(1)
+      if (copiesRunning(index) == 0 && !successful(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  /**
+   * Check whether the given task already has an attempt recorded on the given host.
+   *
+   * Looks through taskAttempts(taskIndex), which resourceOffer prepends to on every launch, and
+   * returns true if any recorded attempt ran on `host`. Callers negate this result to skip hosts
+   * that have already been tried, so the predicate itself must NOT be negated.
+   *
+   * @param taskIndex index of the task within this TaskSet
+   * @param host host name being offered
+   * @return true if some attempt of the task was launched on `host`, false otherwise
+   */
+  private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+    taskAttempts(taskIndex).exists(_.host == host)
+  }
+
+  /**
+   * Return a speculative task for a given executor if any are available. The task should not have
+   * an attempt running on this host, in case the host is slow. In addition, the task should meet
+   * the given locality constraint.
+ */ + private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set + + if (!speculatableTasks.isEmpty) { + // Check for process-local or preference-less tasks; note that tasks can be process-local + // on multiple nodes when we replicate cached blocks, as in Spark Streaming + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val prefs = tasks(index).preferredLocations + val executors = prefs.flatMap(_.executorId) + if (prefs.size == 0 || executors.contains(execId)) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + + // Check for node-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val locations = tasks(index).preferredLocations.map(_.host) + if (locations.contains(host)) { + speculatableTasks -= index + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + } + // Check for rack-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for (rack <- sched.getRackForHost(host)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + if (racks.contains(rack)) { + speculatableTasks -= index + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + } + } + + // Check for non-local tasks + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { + speculatableTasks -= index + return Some((index, TaskLocality.ANY)) + } + } + } + + return None + } + + /** + * Dequeue a pending task for a given node and return its index and locality level. + * Only search for tasks matching the given locality constraint. 
+ */ + private def findTask(execId: String, host: String, locality: TaskLocality.Value) + : Option[(Int, TaskLocality.Value)] = + { + for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + for (index <- findTaskFromList(getPendingTasksForHost(host))) { + return Some((index, TaskLocality.NODE_LOCAL)) + } + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + for { + rack <- sched.getRackForHost(host) + index <- findTaskFromList(getPendingTasksForRack(rack)) + } { + return Some((index, TaskLocality.RACK_LOCAL)) + } + } + + // Look for no-pref tasks after rack-local tasks since they can run anywhere. + for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + for (index <- findTaskFromList(allPendingTasks)) { + return Some((index, TaskLocality.ANY)) + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(execId, host, locality) + } + + /** + * Respond to an offer of a single executor from the scheduler by finding a task + */ def resourceOffer( execId: String, host: String, availableCpus: Int, maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] + : Option[TaskDescription] = + { + if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { + val curTime = clock.getTime() + + var allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + } + + findTask(execId, host, allowedLocality) match { + case Some((index, taskLocality)) => { + // Found a task; do some bookkeeping and return a task description + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred 
launch + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, taskLocality)) + // Do various bookkeeping + copiesRunning(index) += 1 + val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + // Update our locality level for delay scheduling + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + // Serialize and return the task + val startTime = clock.getTime() + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = clock.getTime() - startTime + addRunningTask(taskId) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + if (taskAttempts(index).size == 1) + taskStarted(task,info) + return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + } + case _ => + } + } + return None + } + + /** + * Get the level we can launch tasks according to delay scheduling, based on current wait time. + */ + private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { + while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && + currentLocalityIndex < myLocalityLevels.length - 1) + { + // Jump to the next locality level, and remove our waiting time for the current one since + // we don't want to count it again on the next one + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + } + myLocalityLevels(currentLocalityIndex) + } + + /** + * Find the index in myLocalityLevels for a given locality. 
This is also designed to work with
+   * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+   * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+   */
+  def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+    var index = 0
+    while (locality > myLocalityLevels(index)) {
+      index += 1
+    }
+    index
+  }
+
+  /** Forwards a task-started notification for `task` / `info` to the DAGScheduler. */
+  private def taskStarted(task: Task[_], info: TaskInfo) {
+    sched.dagScheduler.taskStarted(task, info)
+  }
+
+  /**
+   * Marks the attempt with the given task ID as fetching its result and notifies the
+   * DAGScheduler.
+   */
+  def handleTaskGettingResult(tid: Long) = {
+    val info = taskInfos(tid)
+    info.markGettingResult()
+    sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+  }
+
+  /**
+   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+   *
+   * Only the first successful attempt of a task index is counted; later successes for the same
+   * index (e.g. from a speculative copy) are logged and ignored.
+   *
+   * @param tid task ID of the attempt that finished
+   * @param result the deserialized result of that attempt
+   */
+  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markSuccessful()
+    removeRunningTask(tid)
+    if (!successful(index)) {
+      // Report progress including the task that just finished; tasksSuccessful itself is only
+      // incremented below, after the DAGScheduler has been notified.
+      logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+        tid, info.duration, info.host, tasksSuccessful + 1, numTasks))
+      sched.dagScheduler.taskEnded(
+        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+      // Mark successful and stop if all the tasks have succeeded.
+      tasksSuccessful += 1
+      successful(index) = true
+      if (tasksSuccessful == numTasks) {
+        sched.taskSetFinished(this)
+      }
+    } else {
+      logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+        index + " has already completed successfully")
+    }
+  }
+
+  /**
+   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+   * DAG Scheduler.
+ */ + def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { + val info = taskInfos(tid) + if (info.failed) { + return + } + removeRunningTask(tid) + val index = info.index + info.markFailed() + var failureReason = "unknown" + if (!successful(index)) { + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + reason.foreach { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.dagScheduler.taskEnded( + tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + if (ef.className == classOf[NotSerializableException].getName()) { + // If the task result wasn't serializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format( + taskSet.id, index, ef.description)) + abort("Task %s:%s had a not serializable result: %s".format( + taskSet.id, index, ef.description)) + return + } + val key = ef.description + failureReason = "Exception failure: %s".format(ef.description) + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case TaskResultLost => + failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + logWarning(failureReason) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + + case _ => {} + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + if (state != TaskState.KILLED) { + numFailures(index) += 1 + if (numFailures(index) >= maxTaskFailures) { + logError("Task %s:%d failed %d times; aborting job".format( + taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed %d times (most recent failure: %s)".format( + taskSet.id, index, maxTaskFailures, failureReason)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def abort(message: String) { + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.dagScheduler.taskSetFailed(taskSet, message) 
+    removeAllRunningTasks()
+    sched.taskSetFinished(this)
+  }
+
+  /** If the given task ID is not in the set of running tasks, adds it.
+   *
+   * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+   */
+  def addRunningTask(tid: Long) {
+    if (runningTasksSet.add(tid) && parent != null) {
+      parent.increaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  /** If the given task ID is in the set of running tasks, removes it. */
+  def removeRunningTask(tid: Long) {
+    if (runningTasksSet.remove(tid) && parent != null) {
+      parent.decreaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  /**
+   * Clears the set of running tasks and, if this manager has a parent pool, subtracts the
+   * number of tasks that were running from the parent's count.
+   */
+  private[scheduler] def removeAllRunningTasks() {
+    val numRunningTasks = runningTasksSet.size
+    runningTasksSet.clear()
+    if (parent != null) {
+      parent.decreaseRunningTasks(numRunningTasks)
+    }
+    runningTasks = 0
+  }
+
+  /** A TaskSetManager has no child schedulables, so lookups by name always return null. */
+  override def getSchedulableByName(name: String): Schedulable = {
+    return null
+  }
+
+  /** No-op: a TaskSetManager does not hold child schedulables. */
+  override def addSchedulable(schedulable: Schedulable) {}
+
+  /** No-op: a TaskSetManager does not hold child schedulables. */
+  override def removeSchedulable(schedulable: Schedulable) {}
+
+  /**
+   * Return the scheduling queue for this schedulable: a new buffer containing this manager
+   * exactly once. (The buffer is created already holding `this`, so it must not be appended
+   * again, otherwise the manager would show up twice in the returned queue.)
+   */
+  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    ArrayBuffer[TaskSetManager](this)
+  }
+
+  /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+  override def executorLost(execId: String, host: String) {
+    logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+    // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+    // task that used to have locations on only this host might now go to the no-prefs list. Note
+    // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+    // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) { + addPendingTask(index, readding=true) + } + for (index <- getPendingTasksForHost(host)) { + addPendingTask(index, readding=true) + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (successful(index)) { + successful(index) = false + copiesRunning(index) -= 1 + tasksSuccessful -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + handleFailedTask(tid, TaskState.KILLED, None) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the TaskScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. 
+ if (numTasks == 1 || tasksSuccessful == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { + val time = clock.getTime() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.host, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = conf.get("spark.locality.wait", "3000") + level match { + case TaskLocality.PROCESS_LOCAL => + conf.get("spark.locality.wait.process", defaultWait).toLong + case TaskLocality.NODE_LOCAL => + conf.get("spark.locality.wait.node", defaultWait).toLong + case TaskLocality.RACK_LOCAL => + conf.get("spark.locality.wait.rack", defaultWait).toLong + case TaskLocality.ANY => + 0L + } + } - def error(message: String) + /** + * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been + * added to queues using addPendingTask. 
+ */ + private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + val levels = new ArrayBuffer[TaskLocality.TaskLocality] + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + levels += PROCESS_LOCAL + } + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + levels += NODE_LOCAL + } + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + levels += RACK_LOCAL + } + levels += ANY + logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) + levels.toArray + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 938f62883a..ba6bab3f91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler /** * Represents free resources available on an executor. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala deleted file mode 100644 index bf494aa64d..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ /dev/null @@ -1,713 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import java.io.NotSerializableException -import java.util.Arrays - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} - - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of - * the status of each task, retries tasks if they fail (up to a limited number of times), and - * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces - * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, - * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the - * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. 
- */ -private[spark] class ClusterTaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet, - clock: Clock = SystemClock) - extends TaskSetManager - with Logging -{ - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val successful = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 - - var weight = 1 - var minShare = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent: Pool = null - - var runningTasks = 0 - private val runningTasksSet = new HashSet[Long] - - // Set of pending tasks for each executor. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each host. 
Similar to pendingTasksForExecutor, - // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each rack -- similar to the above. - private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] - - // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - - // Map of recent exceptions (identified by string representation and top stack frame) to - // duplicate count (how many times the same exception has appeared) and time the full exception - // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch - logDebug("Epoch for " + taskSet + ": " + epoch) - for (t <- tasks) { - t.epoch = epoch - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. 
- for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level - - // Delay scheduling variables: we keep track of our current locality level and the time we - // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. - // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level - - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { - list += index - } - } - - var hadAliveLocations = false - for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } - } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } - hadAliveLocations = true - } - } - - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. 
- addTo(pendingTasksWithNoPrefs) - } - - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } - } - - /** - * Return the pending tasks list for a given executor ID, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { - pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) - } - - /** - * Return the pending tasks list for a given host, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - /** - * Return the pending rack-local task list for a given rack, or an empty list if - * there is no map entry for that rack - */ - private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { - pendingTasksForRack.getOrElse(rack, ArrayBuffer()) - } - - /** - * Dequeue a pending task from the given list and return its index. - * Return None if the list is empty. - * This method also cleans up any tasks in the list that have already - * been launched, since we want that to happen lazily. - */ - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !successful(index)) { - return Some(index) - } - } - return None - } - - /** Check whether a task is currently running an attempt on a given host */ - private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { - !taskAttempts(taskIndex).exists(_.host == host) - } - - /** - * Return a speculative task for a given executor if any are available. The task should not have - * an attempt running on this host, in case the host is slow. In addition, the task should meet - * the given locality constraint. 
- */ - private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set - - if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local - // on multiple nodes when we replicate cached blocks, as in Spark Streaming - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { - speculatableTasks -= index - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - } - - // Check for node-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val locations = tasks(index).preferredLocations.map(_.host) - if (locations.contains(host)) { - speculatableTasks -= index - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - } - - // Check for rack-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) - if (racks.contains(rack)) { - speculatableTasks -= index - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - } - } - - // Check for non-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - speculatableTasks -= index - return Some((index, TaskLocality.ANY)) - } - } - } - - return None - } - - /** - * Dequeue a pending task for a given node and return its index and locality level. - * Only search for tasks matching the given locality constraint. 
- */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- findTaskFromList(getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for { - rack <- sched.getRackForHost(host) - index <- findTaskFromList(getPendingTasksForRack(rack)) - } { - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - - // Look for no-pref tasks after rack-local tasks since they can run anywhere. - for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- findTaskFromList(allPendingTasks)) { - return Some((index, TaskLocality.ANY)) - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) - } - - /** - * Respond to an offer of a single executor from the scheduler by finding a task - */ - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = clock.getTime() - - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks - } - - findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - 
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - // Serialize and return the task - val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = clock.getTime() - startTime - addRunningTask(taskId) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - info.serializedSize = serializedTask.limit - if (taskAttempts(index).size == 1) - taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) - } - case _ => - } - } - return None - } - - /** - * Get the level we can launch tasks according to delay scheduling, based on current wait time. - */ - private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - } - myLocalityLevels(currentLocalityIndex) - } - - /** - * Find the index in myLocalityLevels for a given locality. 
This is also designed to work with - * localities that are not in myLocalityLevels (in case we somehow get those) by returning the - * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. - */ - def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { - var index = 0 - while (locality > myLocalityLevels(index)) { - index += 1 - } - index - } - - private def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def handleTaskGettingResult(tid: Long) = { - val info = taskInfos(tid) - info.markGettingResult() - sched.dagScheduler.taskGettingResult(tasks(info.index), info) - } - - /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. - */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { - val info = taskInfos(tid) - val index = info.index - info.markSuccessful() - removeRunningTask(tid) - if (!successful(index)) { - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - - // Mark successful and stop if all the tasks have succeeded. - tasksSuccessful += 1 - successful(index) = true - if (tasksSuccessful == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignorning task-finished event for TID " + tid + " because task " + - index + " has already completed successfully") - } - } - - /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the - * DAG Scheduler. 
- */ - def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { - val info = taskInfos(tid) - if (info.failed) { - return - } - removeRunningTask(tid) - val index = info.index - info.markFailed() - if (!successful(index)) { - logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - reason.foreach { - case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case TaskKilled => - logWarning("Task %d was killed.".format(tid)) - sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) - return - - case ef: ExceptionFailure => - sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - if (ef.className == classOf[NotSerializableException].getName()) { - // If the task result wasn't serializable, there's no point in trying to re-execute it. 
- logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) - return - } - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case TaskResultLost => - logWarning("Lost result for TID %s on host %s".format(tid, info.host)) - sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - - case _ => {} - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - if (state != TaskState.KILLED) { - numFailures(index) += 1 - if (numFailures(index) >= MAX_TASK_FAILURES) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - 
sched.dagScheduler.taskSetFailed(taskSet, message) - removeAllRunningTasks() - sched.taskSetFinished(this) - } - - /** If the given task ID is not in the set of running tasks, adds it. - * - * Used to keep track of the number of running tasks, for enforcing scheduling policies. - */ - def addRunningTask(tid: Long) { - if (runningTasksSet.add(tid) && parent != null) { - parent.increaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - /** If the given task ID is in the set of running tasks, removes it. */ - def removeRunningTask(tid: Long) { - if (runningTasksSet.remove(tid) && parent != null) { - parent.decreaseRunningTasks(1) - } - runningTasks = runningTasksSet.size - } - - private[cluster] def removeAllRunningTasks() { - val numRunningTasks = runningTasksSet.size - runningTasksSet.clear() - if (parent != null) { - parent.decreaseRunningTasks(numRunningTasks) - } - runningTasks = 0 - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable: Schedulable) {} - - override def removeSchedulable(schedulable: Schedulable) {} - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because findTaskFromList will skip already-running tasks. 
- for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (successful(index)) { - successful(index) = false - copiesRunning(index) -= 1 - tasksSuccessful -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.KILLED, None) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. 
- if (numTasks == 1 || tasksSuccessful == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksSuccessful < numTasks - } - - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = System.getProperty("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - System.getProperty("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - System.getProperty("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - System.getProperty("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L - } - } - - /** - * Compute the locality levels used in this TaskSet. 
Assumes that all tasks have already been - * added to queues using addPendingTask. - */ - private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} - val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { - levels += PROCESS_LOCAL - } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { - levels += NODE_LOCAL - } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { - levels += RACK_LOCAL - } - levels += ANY - logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) - levels.toArray - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7e22c843bf..2f5bcafe40 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -27,8 +27,10 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.{Logging, SparkException, TaskState} -import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription, + WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -41,13 +43,13 @@ import org.apache.spark.util.{AkkaUtils, Utils} * (spark.deploy.*). 
*/ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - - private val timeout = AkkaUtils.askTimeout + val conf = scheduler.sc.conf + private val timeout = AkkaUtils.askTimeout(conf) class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { private val executorActor = new HashMap[String, ActorRef] @@ -61,7 +63,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work - val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong + val reviveInterval = conf.get("spark.scheduler.revive.interval", "1000").toLong import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -117,7 +119,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac removeExecutor(executorId, reason) sender ! 
true - case DisassociatedEvent(_, address, _) => + case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) } @@ -162,14 +164,12 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac override def start() { val properties = new ArrayBuffer[(String, String)] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) + for ((key, value) <- scheduler.sc.conf.getAll) { if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } } + //TODO (prashant) send conf instead of properties driverActor = actorSystem.actorOf( Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } @@ -208,8 +208,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac driverActor ! KillTask(taskId, executorId) } - override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) - .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) + override def defaultParallelism(): Int = { + conf.getOption("spark.default.parallelism").map(_.toInt).getOrElse( + math.max(totalCoreCount.get(), 2)) + } // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index e8fecec4a6..b44d1e43c8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} + import 
org.apache.spark.{Logging, SparkContext} +import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) @@ -31,13 +33,13 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) - val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt + val maxCores = conf.get("spark.simr.executor.cores", "1").toInt override def start() { super.start() val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val conf = new Configuration() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7127a72d6d..9858717d13 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,14 +17,16 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.HashMap + import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.client.{Client, ClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} -import scala.collection.mutable.HashMap +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String], appName: String) @@ -36,23 +38,23 @@ 
private[spark] class SparkDeploySchedulerBackend( var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt override def start() { super.start() // The endpoint for executors to talk to us val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command( "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, + val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) - client = new Client(sc.env.actorSystem, masters, appDesc, this) + client = new Client(sc.env.actorSystem, masters, appDesc, this, conf) client.start() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 84fe3094cc..d46fceba89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import 
org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu * remove this. */ private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) @@ -61,7 +62,7 @@ private[spark] class CoarseMesosSchedulerBackend( var driver: SchedulerDriver = null // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -76,7 +77,7 @@ private[spark] class CoarseMesosSchedulerBackend( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + val extraCoresPerSlave = conf.get("spark.mesos.extra.cores", "0").toInt var nextMesosTaskId = 0 @@ -121,12 +122,12 @@ private[spark] class CoarseMesosSchedulerBackend( val command = CommandInfo.newBuilder() .setEnvironment(environment) val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port"), + conf.get("spark.driver.host"), + conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val uri = System.getProperty("spark.executor.uri") + val uri = conf.get("spark.executor.uri", null) if (uri == null) { - val runScript = new File(sparkHome, "spark-class").getCanonicalPath + val runScript = new File(sparkHome, "./bin/spark-class").getCanonicalPath 
command.setValue( "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) @@ -135,7 +136,7 @@ private[spark] class CoarseMesosSchedulerBackend( // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d" + "cd %s*; ./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d" .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } @@ -176,7 +177,7 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 && + if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { // Launch an executor on the slave @@ -192,7 +193,7 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", executorMemory)) + .addResources(createResource("mem", sc.executorMemory)) .build() d.launchTasks(offer.getId, Collections.singletonList(task), filters) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 50cbc2ca92..ae8d527352 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,9 +30,8 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{Logging, SparkException, SparkContext, TaskState} -import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason} -import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, + TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.util.Utils /** @@ -41,7 +40,7 @@ import org.apache.spark.util.Utils * from multiple apps can run on different cores) and in time (a core can switch ownership). */ private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, + scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, appName: String) @@ -101,20 +100,20 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = System.getProperty("spark.executor.uri") + val uri = sc.conf.get("spark.executor.uri", null) if (uri == null) { - command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath) + command.setValue(new File(sparkHome, "/sbin/spark-executor").getCanonicalPath) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./spark-executor".format(basename)) + command.setValue("cd %s*; ./sbin/spark-executor".format(basename)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) + .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build()) .build() ExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -199,7 +198,7 @@ private[spark] class MesosSchedulerBackend( def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") val slaveId = o.getSlaveId.getValue - mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) + mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId) } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { @@ -341,5 +340,5 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt + override def defaultParallelism() = sc.conf.get("spark.default.parallelism", "8").toInt } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala new file mode 100644 index 0000000000..897d47a9ad --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.local + +import java.nio.ByteBuffer + +import akka.actor.{Actor, ActorRef, Props} + +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} + +private case class ReviveOffers() + +private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private case class KillTask(taskId: Long) + +/** + * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on + * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * and the ClusterScheduler. 
+ */ +private[spark] class LocalActor( + scheduler: TaskSchedulerImpl, + executorBackend: LocalBackend, + private val totalCores: Int) extends Actor with Logging { + + private var freeCores = totalCores + + private val localExecutorId = "localhost" + private val localExecutorHostname = "localhost" + + val executor = new Executor( + localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) + + def receive = { + case ReviveOffers => + reviveOffers() + + case StatusUpdate(taskId, state, serializedData) => + scheduler.statusUpdate(taskId, state, serializedData) + if (TaskState.isFinished(state)) { + freeCores += 1 + reviveOffers() + } + + case KillTask(taskId) => + executor.killTask(taskId) + } + + def reviveOffers() { + val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + for (task <- scheduler.resourceOffers(offers).flatten) { + freeCores -= 1 + executor.launchTask(executorBackend, task.taskId, task.serializedTask) + } + } +} + +/** + * LocalBackend is used when running a local version of Spark where the executor, backend, and + * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks + * on a single Executor (created by the LocalBackend) running locally. + */ +private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) + extends SchedulerBackend with ExecutorBackend { + + var localActor: ActorRef = null + + override def start() { + localActor = SparkEnv.get.actorSystem.actorOf( + Props(new LocalActor(scheduler, this, totalCores)), + "LocalBackendActor") + } + + override def stop() { + } + + override def reviveOffers() { + localActor ! ReviveOffers + } + + override def defaultParallelism() = totalCores + + override def killTask(taskId: Long, executorId: String) { + localActor ! KillTask(taskId) + } + + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + localActor ! 
StatusUpdate(taskId, state, serializedData) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala deleted file mode 100644 index 01e95162c0..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ - -import org.apache.spark._ -import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode - - -/** - * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. 
- */ - -private[local] -case class LocalReviveOffers() - -private[local] -case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) - -private[local] -case class KillTask(taskId: Long) - -private[spark] -class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) - extends Actor with Logging { - - val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true) - - def receive = { - case LocalReviveOffers => - launchTask(localScheduler.resourceOffer(freeCores)) - - case LocalStatusUpdate(taskId, state, serializeData) => - if (TaskState.isFinished(state)) { - freeCores += 1 - launchTask(localScheduler.resourceOffer(freeCores)) - } - - case KillTask(taskId) => - executor.killTask(taskId) - } - - private def launchTask(tasks: Seq[TaskDescription]) { - for (task <- tasks) { - freeCores -= 1 - executor.launchTask(localScheduler, task.taskId, task.serializedTask) - } - } -} - -private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext) - extends TaskScheduler - with ExecutorBackend - with Logging { - - val env = SparkEnv.get - val attemptId = new AtomicInteger - var dagScheduler: DAGScheduler = null - - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. 
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.scheduler.mode", "FIFO")) - val activeTaskSets = new HashMap[String, LocalTaskSetManager] - val taskIdToTaskSetId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - var localActor: ActorRef = null - - override def start() { - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - - localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") - } - - override def setDAGScheduler(dagScheduler: DAGScheduler) { - this.dagScheduler = dagScheduler - } - - override def submitTasks(taskSet: TaskSet) { - synchronized { - val manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers - } - } - - override def cancelTasks(stageId: Int): Unit = synchronized { - logInfo("Cancelling stage " + stageId) - logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. 
In this case, - // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - localActor ! KillTask(tid) - } - } - logInfo("Stage %d was cancelled".format(stageId)) - taskSetFinished(tsm) - } - } - - def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - synchronized { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format( - manager.parent.name, manager.name, manager.runningTasks)) - } - - var launchTask = false - for (manager <- sortedTaskSetQueue) { - do { - launchTask = false - manager.resourceOffer(null, null, freeCpuCores, null) match { - case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true - case None => {} - } - } while(launchTask) - } - return tasks - } - } - - def taskSetFinished(manager: TaskSetManager) { - synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id - } - } - - override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - if (TaskState.isFinished(state)) { - synchronized { - taskIdToTaskSetId.get(taskId) match { - case Some(taskSetId) => - val taskSetManager = activeTaskSets.get(taskSetId) - taskSetManager.foreach { tsm => - taskSetTaskIds(taskSetId) -= taskId - - state match { - case TaskState.FINISHED => - tsm.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - tsm.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - tsm.error("Task %d was killed".format(taskId)) 
- case _ => {} - } - } - case None => - logInfo("Ignoring update from TID " + taskId + " because its task set is gone") - } - } - localActor ! LocalStatusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - } - - override def defaultParallelism() = threads -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala deleted file mode 100644 index 53bf78267e..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.local - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState} -import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task, - TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager} - - -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - - var parent: Pool = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_" + taskSet.stageId.toString - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val MAX_TASK_FAILURES = sched.maxFailures - - def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def removeSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - // nothing - } - - override def checkSpeculatableTasks() = true - - override def 
getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def hasPendingTasks() = true - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - SparkEnv.set(sched.env) - logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( - availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", - TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. 
- val bytes = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, index, bytes)) - case None => {} - } - } - return None - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => { - throw new SparkException("Expect only DirectTaskResults when using LocalScheduler") - } - } - result.metrics.resultSize = serializedData.limit() - sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, - result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( - serializedData, getClass.getClassLoader) - sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - reason.className, reason.description, locs.mkString("\n"))) - if 
(numFailures(index) > MAX_TASK_FAILURES) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, MAX_TASK_FAILURES, reason.description) - decreaseRunningTasks(runningTasks) - sched.dagScheduler.taskSetFailed(taskSet, errorMessage) - // need to delete failed Taskset from schedule queue - sched.taskSetFinished(this) - } - } - } - - override def error(message: String) { - sched.dagScheduler.taskSetFailed(taskSet, message) - sched.taskSetFinished(this) - } -} diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 4de81617b1..5d3d43623d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -21,6 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.SparkConf private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) @@ -77,6 +78,6 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { /** * A Spark serializer that uses Java's built-in serialization. 
*/ -class JavaSerializer extends Serializer { +class JavaSerializer(conf: SparkConf) extends Serializer { def newInstance(): SerializerInstance = new JavaSerializerInstance } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e748c2275d..a24a3b04b8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -25,18 +25,18 @@ import com.esotericsoftware.kryo.{KryoException, Kryo} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} -import org.apache.spark.{SerializableWritable, Logging} +import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ +import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. */ -class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { - +class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { private val bufferSize = { - System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + conf.get("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 } def newKryoOutput() = new KryoOutput(bufferSize) @@ -48,7 +48,7 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. 
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) + kryo.setReferences(conf.get("spark.kryo.referenceTracking", "true").toBoolean) for (cls <- KryoSerializer.toRegister) kryo.register(cls) @@ -58,13 +58,13 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging // Allow the user to register their own classes by setting spark.kryo.registrator try { - Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => + for (regCls <- conf.getOption("spark.kryo.registrator")) { logDebug("Running user registrator: " + regCls) val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] reg.registerClasses(kryo) } } catch { - case _: Exception => println("Failed to register spark.kryo.registrator") + case e: Exception => logError("Failed to run spark.kryo.registrator", e) } // Register Chill's classes; we do this after our ranges and the user's own classes to let diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 160cca4d6c..9a5e3cb77e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -29,6 +29,9 @@ import org.apache.spark.util.{NextIterator, ByteBufferInputStream} * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual serialization and are * guaranteed to only be called from one thread at a time. + * + * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a + * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes precedence. 
*/ trait Serializer { def newInstance(): SerializerInstance diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2955986fec..36a37af4f8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.SparkConf /** @@ -26,18 +27,19 @@ import java.util.concurrent.ConcurrentHashMap * creating a new one. */ private[spark] class SerializerManager { + // TODO: Consider moving this into SparkConf itself to remove the global singleton. private val serializers = new ConcurrentHashMap[String, Serializer] private var _default: Serializer = _ def default = _default - def setDefault(clsName: String): Serializer = { - _default = get(clsName) + def setDefault(clsName: String, conf: SparkConf): Serializer = { + _default = get(clsName, conf) _default } - def get(clsName: String): Serializer = { + def get(clsName: String, conf: SparkConf): Serializer = { if (clsName == null) { default } else { @@ -51,8 +53,19 @@ private[spark] class SerializerManager { serializer = serializers.get(clsName) if (serializer == null) { val clsLoader = Thread.currentThread.getContextClassLoader - serializer = - Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + val cls = Class.forName(clsName, true, clsLoader) + + // First try with the constructor that takes SparkConf. If we can't find one, + // use a no-arg constructor instead. 
+ try { + val constructor = cls.getConstructor(classOf[SparkConf]) + serializer = constructor.newInstance(conf).asInstanceOf[Serializer] + } catch { + case _: NoSuchMethodException => + val constructor = cls.getConstructor() + serializer = constructor.newInstance().asInstanceOf[Serializer] + } + serializers.put(clsName, serializer) } serializer diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index e51c5b30a3..47478631a1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -312,7 +312,7 @@ object BlockFetcherIterator { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.host)) val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier + val cpier = new ShuffleCopier(blockManager.conf) cpier.getBlocks(cmId, req.blocks, putResult) logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) } @@ -327,7 +327,7 @@ object BlockFetcherIterator { fetchRequestsSync.put(request) } - copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + copiers = startCopiers(conf.get("spark.shuffle.copier.threads", "6").toInt) logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 19a025a329..6d2cda97b0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -30,7 +30,7 @@ import scala.concurrent.duration._ import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} -import 
org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.{SparkConf, Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -43,12 +43,13 @@ private[spark] class BlockManager( actorSystem: ActorSystem, val master: BlockManagerMaster, val defaultSerializer: Serializer, - maxMemory: Long) + maxMemory: Long, + val conf: SparkConf) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -57,12 +58,12 @@ private[spark] class BlockManager( // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { - val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt + val useNetty = conf.get("spark.shuffle.use.netty", "false").toBoolean + val nettyPortConfig = conf.get("spark.shuffle.sender.port", "0").toInt if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0) + val connectionManager = new ConnectionManager(0, conf) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( @@ -71,18 +72,18 @@ private[spark] class BlockManager( // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = - System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + conf.get("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 // Whether to compress broadcast variables that are stored - val 
compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + val compressBroadcast = conf.get("spark.broadcast.compress", "true").toBoolean // Whether to compress shuffle output that are stored - val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean + val compressShuffle = conf.get("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized - val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean + val compressRdds = conf.get("spark.rdd.compress", "false").toBoolean - val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties + val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val hostPort = Utils.localHostPort() + val hostPort = Utils.localHostPort(conf) val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -100,8 +101,11 @@ private[spark] class BlockManager( var heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks) - private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks) + private val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) + private val broadcastCleaner = new MetadataCleaner( + MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) + initialize() // The compression codec to use. Note that the "lazy" val is necessary because we want to delay @@ -109,14 +113,14 @@ private[spark] class BlockManager( // program could be using a user-defined codec in a third party jar, which is loaded in // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been // loaded yet. 
- private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec() + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) /** * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + serializer: Serializer, conf: SparkConf) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) } /** @@ -126,7 +130,7 @@ private[spark] class BlockManager( private def initialize() { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting) { + if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { heartBeat() } @@ -439,7 +443,7 @@ private[spark] class BlockManager( : BlockFetcherIterator = { val iter = - if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { + if (conf.get("spark.shuffle.use.netty", "false").toBoolean) { new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) @@ -465,7 +469,8 @@ private[spark] class BlockManager( def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) + val syncWrites = conf.get("spark.shuffle.sync", "false").toBoolean + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) } /** @@ -856,19 +861,18 @@ private[spark] class BlockManager( 
private[spark] object BlockManager extends Logging { - val ID_GENERATOR = new IdGenerator - def getMaxMemoryFromSystemProperties: Long = { - val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.get("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong } - def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 + def getHeartBeatFrequency(conf: SparkConf): Long = + conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 - def getDisableHeartBeatsForTesting: Boolean = - System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = + conf.get("spark.test.disableBlockManagerHeartBeat", "false").toBoolean /** * Attempt to clean up a ByteBuffer if it is memory-mapped. 
This uses an *unsafe* Sun API that diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index e1d68ef592..b5afe8cd23 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -23,19 +23,20 @@ import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ import akka.pattern.ask -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{SparkConf, Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] -class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection]) extends Logging { +class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection], + conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt + val AKKA_RETRY_ATTEMPTS: Int = conf.get("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = conf.get("spark.akka.retry.wait", "3000").toInt val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = AkkaUtils.askTimeout + val timeout = AkkaUtils.askTimeout(conf) /** Remove a dead executor from the driver actor. This is only called on the driver side. 
*/ def removeExecutor(execId: String) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 1e44b1ae15..52a424db22 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} import akka.pattern.ask -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{SparkConf, Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -36,7 +36,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} * all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { +class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Actor with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = @@ -48,20 +48,18 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Mapping from block id to the set of block managers that have the block. 
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout + private val akkaTimeout = AkkaUtils.askTimeout(conf) - initLogging() + val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs", + "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong - val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong - - val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", + val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { + if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { import context.dispatcher timeoutCheckingTask = context.system.scheduler.schedule( 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 0c66addf9d..21f003609b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -30,7 +30,6 @@ import org.apache.spark.util.Utils * TODO: Use event model. 
*/ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - initLogging() blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) @@ -101,8 +100,6 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null - initLogging() - def startBlockManagerWorker(manager: BlockManager) { blockManagerWorker = new BlockManagerWorker(manager) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index 6ce9127c74..a06f50a0ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -37,8 +37,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM def length = blockMessages.length - initLogging() - def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis val newBlockMessages = new ArrayBuffer[BlockMessage]() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index b4451fc7b8..61e63c60d5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -74,7 +74,8 @@ class DiskBlockObjectWriter( file: File, serializer: Serializer, bufferSize: Int, - compressStream: OutputStream => OutputStream) + compressStream: OutputStream => OutputStream, + syncWrites: Boolean) extends BlockObjectWriter(blockId) with Logging { @@ -97,8 +98,6 @@ class DiskBlockObjectWriter( override def flush() = out.flush() } - private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean - /** The file channel, used for repositioning / truncating the file. 
*/ private var channel: FileChannel = null private var bs: OutputStream = null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index fcd2e97982..55dcb3742c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + private val subDirsPerLocalDir = shuffleManager.conf.get("spark.diskStore.subDirectories", "64").toInt // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index e828e1d1c5..39dc7bb19a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -59,32 +59,40 @@ private[spark] trait ShuffleWriterGroup { */ private[spark] class ShuffleBlockManager(blockManager: BlockManager) { + def conf = blockManager.conf + // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. 
val consolidateShuffleFiles = - System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean + conf.get("spark.shuffle.consolidateFiles", "false").toBoolean - private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + private val bufferSize = conf.get("spark.shuffle.file.buffer.kb", "100").toInt * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. */ - private class ShuffleState() { + private class ShuffleState(val numBuckets: Int) { val nextFileId = new AtomicInteger(0) val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + + /** + * The mapIds of all map tasks completed on this Executor for this shuffle. + * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. + */ + val completedMapTasks = new ConcurrentLinkedQueue[Int]() } type ShuffleId = Int private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - private - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup) + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState()) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null @@ -109,6 +117,8 @@ class ShuffleBlockManager(blockManager: BlockManager) { fileGroup.recordMapOutput(mapId, offsets) } recycleFileGroup(fileGroup) + } else { + shuffleState.completedMapTasks.add(mapId) } } @@ -154,7 +164,18 @@ class ShuffleBlockManager(blockManager: BlockManager) { } 
private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + }) } } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala index d52b3d8284..40734aab49 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -56,7 +56,7 @@ object StoragePerfTester { def writeOutputBytes(mapId: Int, total: AtomicLong) = { val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer()) + new KryoSerializer(sc.conf)) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index a8db37ded1..dca98c6c05 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,6 +22,7 @@ import akka.actor._ import java.util.concurrent.ArrayBlockingQueue import util.Random import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.{SparkConf, SparkContext} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -91,11 +92,12 @@ private[spark] object ThreadingTest { def main(args: Array[String]) { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") - val 
serializer = new KryoSerializer + val conf = new SparkConf() + val serializer = new KryoSerializer(conf) val blockManagerMaster = new BlockManagerMaster( - Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf)))), conf) val blockManager = new BlockManager( - "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024) + "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f1d86c0221..50dfdbdf5a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. 
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils /** Top level user interface for Spark */ private[spark] class SparkUI(sc: SparkContext) extends Logging { val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName()) - val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt + val port = sc.conf.get("spark.ui.port", SparkUI.DEFAULT_PORT).toInt var boundPort: Option[Int] = None var server: Option[Server] = None diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fcd1b518d0..6ba15187d9 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui import scala.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.scheduler.SchedulingMode @@ -27,25 +27,26 @@ import org.apache.spark.scheduler.SchedulingMode /** * Continuously generates jobs that expose various features of the WebUI (internal testing tool). 
* - * Usage: ./run spark.ui.UIWorkloadGenerator [master] + * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] */ private[spark] object UIWorkloadGenerator { + val NUM_PARTITIONS = 100 val INTER_JOB_WAIT_MS = 5000 def main(args: Array[String]) { if (args.length < 2) { - println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + println("usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") System.exit(1) } - val master = args(0) - val schedulingMode = SchedulingMode.withName(args(1)) - val appName = "Spark UI Tester" + val conf = new SparkConf().setMaster(args(0)).setAppName("Spark UI tester") + + val schedulingMode = SchedulingMode.withName(args(1)) if (schedulingMode == SchedulingMode.FAIR) { - System.setProperty("spark.scheduler.mode", "FAIR") + conf.set("spark.scheduler.mode", "FAIR") } - val sc = new SparkContext(master, appName) + val sc = new SparkContext(conf) def setProperties(s: String) = { if(schedulingMode == SchedulingMode.FAIR) { @@ -55,11 +56,11 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = (new Random()).nextFloat() + def nextFloat() = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), - ("Cache and Count", baseData.map(x => x).cache.count), + ("Cache and Count", baseData.map(x => x).cache().count), ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), ("Entirely failed phase", baseData.map(x => throw new Exception).count), ("Partially failed phase", { diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index c5bf2acc9e..88f41be8d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -48,12 +48,15 @@ 
private[spark] class EnvironmentUI(sc: SparkContext) { def jvmTable = UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) - val properties = System.getProperties.iterator.toSeq - val classPathProperty = properties.find { case (k, v) => - k.contains("java.class.path") + val sparkProperties = sc.conf.getAll.sorted + + val systemProperties = System.getProperties.iterator.toSeq + val classPathProperty = systemProperties.find { case (k, v) => + k == "java.class.path" }.getOrElse(("", "")) - val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted - val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted + val otherProperties = systemProperties.filter { case (k, v) => + k != "java.class.path" && !k.startsWith("spark.") + }.sorted val propertyHeaders = Seq("Name", "Value") def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr> @@ -63,7 +66,7 @@ private[spark] class EnvironmentUI(sc: SparkContext) { UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) val classPathEntries = classPathProperty._2 - .split(System.getProperty("path.separator", ":")) + .split(sc.conf.get("path.separator", ":")) .filterNot(e => e.isEmpty) .map(e => (e, "System Classpath")) val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index e596690bc3..a31a7e1d58 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", - "Active tasks", "Failed tasks", 
"Complete tasks", "Total tasks") + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read", + "Shuffle Write") def execRow(kv: Seq[String]) = { <tr> @@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { <td>{kv(7)}</td> <td>{kv(8)}</td> <td>{kv(9)}</td> + <td>{Utils.msDurationToString(kv(10).toLong)}</td> + <td>{Utils.bytesToString(kv(11).toLong)}</td> + <td>{Utils.bytesToString(kv(12).toLong)}</td> </tr> } @@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks + val totalDuration = listener.executorToDuration.getOrElse(execId, 0) + val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0) + val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0) Seq( execId, @@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { activeTasks.toString, failedTasks.toString, completedTasks.toString, - totalTasks.toString + totalTasks.toString, + totalDuration.toString, + totalShuffleRead.toString, + totalShuffleWrite.toString ) } @@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() val executorToTasksComplete = HashMap[String, Int]() val executorToTasksFailed = HashMap[String, Int]() + val executorToDuration = HashMap[String, Long]() + val executorToShuffleRead = HashMap[String, Long]() + val executorToShuffleWrite = HashMap[String, Long]() override def onTaskStart(taskStart: SparkListenerTaskStart) { val eid = taskStart.taskInfo.executorId @@ -140,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { val eid = taskEnd.taskInfo.executorId val activeTasks = 
executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration + executorToDuration.put(eid, newDuration) + activeTasks -= taskEnd.taskInfo val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { @@ -150,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 (None, Option(taskEnd.taskMetrics)) } + + // update shuffle read/write + if (null != taskEnd.taskMetrics) { + taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead => + executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) + + shuffleRead.remoteBytesRead)) + + taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite => + executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) + + shuffleWrite.shuffleBytesWritten)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala new file mode 100644 index 0000000000..3c53e88380 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +/** class for reporting aggregated metrics for each executors in stageUI */ +private[spark] class ExecutorSummary { + var taskTime : Long = 0 + var failedTasks : Int = 0 + var succeededTasks : Int = 0 + var shuffleRead : Long = 0 + var shuffleWrite : Long = 0 +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala new file mode 100644 index 0000000000..0dd876480a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.jobs + +import scala.xml.Node + +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.util.Utils +import scala.collection.mutable + +/** Page showing executor summary */ +private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + executorTable() + } + } + + /** Special table which merges two header cells. */ + private def executorTable[T](): Seq[Node] = { + <table class="table table-bordered table-striped table-condensed sortable"> + <thead> + <th>Executor ID</th> + <th>Address</th> + <th>Task Time</th> + <th>Total Tasks</th> + <th>Failed Tasks</th> + <th>Succeeded Tasks</th> + <th>Shuffle Read</th> + <th>Shuffle Write</th> + </thead> + <tbody> + {createExecutorTable()} + </tbody> + </table> + } + + private def createExecutorTable() : Seq[Node] = { + // make a executor-id -> address map + val executorIdToAddress = mutable.HashMap[String, String]() + val storageStatusList = parent.sc.getExecutorStorageStatus + for (statusId <- 0 until storageStatusList.size) { + val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId + val address = blockManagerId.hostPort + val executorId = blockManagerId.executorId + executorIdToAddress.put(executorId, address) + } + + val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) + executorIdToSummary match { + case Some(x) => { + x.toSeq.sortBy(_._1).map{ + case (k,v) => { + <tr> + <td>{k}</td> + <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td> + <td>{parent.formatDuration(v.taskTime)}</td> + <td>{v.failedTasks + v.succeededTasks}</td> + <td>{v.failedTasks}</td> + <td>{v.succeededTasks}</td> + <td>{Utils.bytesToString(v.shuffleRead)}</td> + <td>{Utils.bytesToString(v.shuffleWrite)}</td> + </tr> 
+ } + } + } + case _ => { Seq[Node]() } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 6b854740d6..b7b87250b9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler._ */ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { // How many stages to remember - val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt + val RETAINED_STAGES = sc.conf.get("spark.ui.retained_stages", "1000").toInt val DEFAULT_POOL_NAME = "default" val stageIdToPool = new HashMap[Int, String]() @@ -57,10 +57,11 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stageIdToTasksFailed = HashMap[Int, Int]() val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]() override def onJobStart(jobStart: SparkListenerJobStart) {} - override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stage poolToActiveStages(stageIdToPool(stage.stageId)) -= stage activeStages -= stage @@ -105,7 +106,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]()) stages += stage } - + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val sid = taskStart.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) @@ -124,8 +125,38 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends 
SparkList override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId + + // create executor summary map if necessary + val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid, + op = new HashMap[String, ExecutorSummary]()) + executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId, + op = new ExecutorSummary()) + + val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId) + executorSummary match { + case Some(y) => { + // first update failed-task, succeed-task + taskEnd.reason match { + case Success => + y.succeededTasks += 1 + case _ => + y.failedTasks += 1 + } + + // update duration + y.taskTime += taskEnd.taskInfo.duration + + Option(taskEnd.taskMetrics).foreach { taskMetrics => + taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } + taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } + } + } + case _ => {} + } + val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { case e: ExceptionFailure => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 996e1b4d1a..8dcfeacb60 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -66,7 +66,7 @@ private[spark] class StagePage(parent: JobProgressUI) { <div> <ul class="unstyled"> <li> - <strong>Total duration across all tasks: </strong> + <strong>Total task time across all tasks: </strong> {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) @@ -166,11 +166,12 @@ private[spark] class StagePage(parent: JobProgressUI) { def quantileRow(data: Seq[String]): Seq[Node] = <tr> 
{data.map(d => <td>{d}</td>)} </tr> Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } - + val executorTable = new ExecutorTable(parent, stageId) val content = summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ + <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++ <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9ad6de3c6d..463d85dfd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr {if (isFairScheduler) {<th>Pool Name</th>} else {}} <th>Description</th> <th>Submitted</th> - <th>Duration</th> + <th>Task Time</th> <th>Tasks: Succeeded/Total</th> <th>Shuffle Read</th> <th>Shuffle Write</th> diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 1c8b51b8bc..7df7e3d8e5 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -21,6 +21,9 @@ import scala.concurrent.duration.{Duration, FiniteDuration} import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.SparkConf /** * Various utility classes for working with Akka. @@ -37,22 +40,29 @@ private[spark] object AkkaUtils { * If indestructible is set to true, the Actor System will continue running in the event * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. 
*/ - def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false) - : (ActorSystem, Int) = { + def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, + conf: SparkConf): (ActorSystem, Int) = { + + val akkaThreads = conf.get("spark.akka.threads", "4").toInt + val akkaBatchSize = conf.get("spark.akka.batchSize", "15").toInt - val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt - val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt + val akkaTimeout = conf.get("spark.akka.timeout", "100").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "100").toInt + val akkaFrameSize = conf.get("spark.akka.frameSize", "10").toInt + val akkaLogLifecycleEvents = conf.get("spark.akka.logLifecycleEvents", "false").toBoolean + val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" + if (!akkaLogLifecycleEvents) { + // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent. 
+ // See: https://www.assembla.com/spaces/akka/tickets/3787#/ + Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL)) + } - val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = - if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + val logAkkaConfig = if (conf.get("spark.akka.logAkkaConfig", "false").toBoolean) "on" else "off" - val akkaHeartBeatPauses = System.getProperty("spark.akka.heartbeat.pauses", "600").toInt + val akkaHeartBeatPauses = conf.get("spark.akka.heartbeat.pauses", "600").toInt val akkaFailureDetector = - System.getProperty("spark.akka.failure-detector.threshold", "300.0").toDouble - val akkaHeartBeatInterval = System.getProperty("spark.akka.heartbeat.interval", "1000").toInt + conf.get("spark.akka.failure-detector.threshold", "300.0").toDouble + val akkaHeartBeatInterval = conf.get("spark.akka.heartbeat.interval", "1000").toInt val akkaConf = ConfigFactory.parseString( s""" @@ -72,7 +82,10 @@ private[spark] object AkkaUtils { |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB |akka.remote.netty.tcp.execution-pool-size = $akkaThreads |akka.actor.default-dispatcher.throughput = $akkaBatchSize + |akka.log-config-on-start = $logAkkaConfig |akka.remote.log-remote-lifecycle-events = $lifecycleEvents + |akka.log-dead-letters = $lifecycleEvents + |akka.log-dead-letters-during-shutdown = $lifecycleEvents """.stripMargin) val actorSystem = if (indestructible) { @@ -87,7 +100,7 @@ private[spark] object AkkaUtils { } /** Returns the default Spark timeout to use for Akka ask operations. 
*/ - def askTimeout: FiniteDuration = { - Duration.create(System.getProperty("spark.akka.askTimeout", "30").toLong, "seconds") + def askTimeout(conf: SparkConf): FiniteDuration = { + Duration.create(conf.get("spark.akka.askTimeout", "30").toLong, "seconds") } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 7b41ef89f1..aa7f52cafb 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -18,16 +18,21 @@ package org.apache.spark.util import java.util.{TimerTask, Timer} -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, SparkContext, Logging} /** * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ -class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging { +class MetadataCleaner( + cleanerType: MetadataCleanerType.MetadataCleanerType, + cleanupFunc: (Long) => Unit, + conf: SparkConf) + extends Logging +{ val name = cleanerType.toString - private val delaySeconds = MetadataCleaner.getDelaySeconds + private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType) private val periodSeconds = math.max(10, delaySeconds / 10) private val timer = new Timer(name + " cleanup timer", true) @@ -65,22 +70,28 @@ object MetadataCleanerType extends Enumeration { def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString } +// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the +// initialization of StreamingContext. It's okay for users trying to configure stuff themselves. 
object MetadataCleaner { + def getDelaySeconds(conf: SparkConf) = { + conf.get("spark.cleaner.ttl", "3500").toInt + } - // using only sys props for now : so that workers can also get to it while preserving earlier behavior. - def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt - - def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { - System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt + def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = + { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) + .toInt } - def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) { - System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString) + def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) + { + conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } - def setDelaySeconds(delay: Int, resetAll: Boolean = true) { + def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { // override for all ? - System.setProperty("spark.cleaner.ttl", delay.toString) + conf.set("spark.cleaner.ttl", delay.toString) if (resetAll) { for (cleanerType <- MetadataCleanerType.values) { System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala new file mode 100644 index 0000000000..8b4e7c104c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.{Externalizable, ObjectOutput, ObjectInput} +import com.clearspring.analytics.stream.cardinality.{ICardinality, HyperLogLog} + +/** + * A wrapper around [[com.clearspring.analytics.stream.cardinality.HyperLogLog]] that is serializable. + */ +private[spark] +class SerializableHyperLogLog(var value: ICardinality) extends Externalizable { + + def this() = this(null) // For deserialization + + def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value)) + + def add[T](elem: T) = { + this.value.offer(elem) + this + } + + def readExternal(in: ObjectInput) { + val byteLength = in.readInt() + val bytes = new Array[Byte](byteLength) + in.readFully(bytes) + value = HyperLogLog.Builder.build(bytes) + } + + def writeExternal(out: ObjectOutput) { + val bytes = value.getBytes() + out.writeInt(bytes.length) + out.write(bytes) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index a25b37a2a9..bddb3bb735 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -30,10 +30,10 @@ import java.lang.management.ManagementFactory 
import scala.collection.mutable.ArrayBuffer import it.unimi.dsi.fastutil.ints.IntOpenHashSet -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, SparkConf, SparkContext, Logging} /** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in + * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in * memory-aware caches. * * Based on the following JavaWorld article: @@ -89,9 +89,11 @@ private[spark] object SizeEstimator extends Logging { classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) } - private def getIsCompressedOops : Boolean = { + private def getIsCompressedOops: Boolean = { + // This is only used by tests to override the detection of compressed oops. The test + // actually uses a system property instead of a SparkConf, so we'll stick with that. if (System.getProperty("spark.test.useCompressedOops") != null) { - return System.getProperty("spark.test.useCompressedOops").toBoolean + return System.getProperty("spark.test.useCompressedOops").toBoolean } try { @@ -103,7 +105,7 @@ private[spark] object SizeEstimator extends Logging { val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) - val bean = ManagementFactory.newPlatformMXBeanProxy(server, + val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) // TODO: We could use reflection on the VMOption returned ? 
return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") @@ -251,7 +253,7 @@ private[spark] object SizeEstimator extends Logging { if (info != null) { return info } - + val parent = getClassInfo(cls.getSuperclass) var shellSize = parent.shellSize var pointerFields = parent.pointerFields diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index dbff571de9..181ae2fd45 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { def toMap: immutable.Map[A, B] = iterator.toMap /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` + * Removes old key-value pairs that have timestamp earlier than `threshTime`, + * calling the supplied function on each such entry before removing. 
*/ - def clearOldValues(threshTime: Long) { + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { + while (iterator.hasNext) { val entry = iterator.next() if (entry.getValue._2 < threshTime) { + f(entry.getKey, entry.getValue._1) logDebug("Removing key " + entry.getKey) iterator.remove() } } } + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + clearOldValues(threshTime, (_, _) => ()) + } + private def currentTime: Long = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3f7858d2de..5f1253100b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -36,14 +36,13 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging} /** * Various utility methods used by Spark. */ private[spark] object Utils extends Logging { - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -239,9 +238,9 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. 
*/ - def fetchFile(url: String, targetDir: File) { + def fetchFile(url: String, targetDir: File, conf: SparkConf) { val filename = url.split("/").last - val tempDir = getLocalDir + val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -311,8 +310,8 @@ private[spark] object Utils extends Logging { * return a single directory, even though the spark.local.dir property might be a list of * multiple paths. */ - def getLocalDir: String = { - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + def getLocalDir(conf: SparkConf): String = { + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) } /** @@ -397,13 +396,12 @@ private[spark] object Utils extends Logging { InetAddress.getByName(address).getHostName } - def localHostPort(): String = { - val retval = System.getProperty("spark.hostPort", null) + def localHostPort(conf: SparkConf): String = { + val retval = conf.get("spark.hostPort", null) if (retval == null) { logErrorWithStack("spark.hostPort not set but invoking localHostPort") return localHostName() } - retval } @@ -415,9 +413,12 @@ private[spark] object Utils extends Logging { assert(hostPort.indexOf(':') != -1, message) } - // Used by DEBUG code : remove when all testing done def logErrorWithStack(msg: String) { - try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + try { + throw new Exception + } catch { + case ex: Exception => logError(msg, ex) + } } // Typically, this will be of order of number of nodes in cluster @@ -837,7 +838,7 @@ private[spark] object Utils extends Logging { } } - /** + /** * Timing method based on iterations that permit JVM JIT optimization. 
* @param numIters number of iterations * @param f function to be executed diff --git a/core/src/test/resources/spark.conf b/core/src/test/resources/spark.conf new file mode 100644 index 0000000000..aa4e751235 --- /dev/null +++ b/core/src/test/resources/spark.conf @@ -0,0 +1,8 @@ +# A simple spark.conf file used only in our unit tests + +spark.test.intTestProperty = 1 + +spark.test { + stringTestProperty = "hi" + listTestProperty = ["a", "b"] +} diff --git a/core/src/test/resources/uncommons-maths-1.2.2.jar b/core/src/test/resources/uncommons-maths-1.2.2.jar Binary files differdeleted file mode 100644 index e126001c1c..0000000000 --- a/core/src/test/resources/uncommons-maths-1.2.2.jar +++ /dev/null diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f25d921d3f..ec13b329b2 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -26,8 +26,6 @@ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { - initLogging() - var checkpointDir: File = _ val partitioner = new HashPartitioner(2) @@ -57,15 +55,15 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) + testRDD(_.map(x => x.toString)) + testRDD(_.flatMap(x => 1 to x)) + testRDD(_.filter(_ 
% 2 == 0)) + testRDD(_.sample(false, 0.5, 0)) + testRDD(_.glom()) + testRDD(_.mapPartitions(_.map(_.toString))) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testRDD(_.pipe(Seq("cat"))) } test("ParallelCollection") { @@ -97,7 +95,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("ShuffledRDD") { - testCheckpointing(rdd => { + testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) }) @@ -105,25 +103,17 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("UnionRDD") { def otherRDD = sc.makeRDD(1 to 10, 1) - - // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. - // Current implementation of UnionRDD has transient reference to parent RDDs, - // so only the partitions will reduce in serialized size, not the RDD. - testCheckpointing(_.union(otherRDD), false, true) - testParentCheckpointing(_.union(otherRDD), false, true) + testRDD(_.union(otherRDD)) + testRDDPartitions(_.union(otherRDD)) } test("CartesianRDD") { def otherRDD = sc.makeRDD(1 to 10, 1) - testCheckpointing(new CartesianRDD(sc, _, otherRDD)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + testRDD(new CartesianRDD(sc, _, otherRDD)) + testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. 
+ // the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) ones.checkpoint() // checkpoint that MappedRDD @@ -134,23 +124,20 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val splitAfterCheckpoint = serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) assert( - (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && - (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), - "CartesianRDD.parents not updated after parent RDD checkpointed" + (splitAfterCheckpoint.s1.getClass != splitBeforeCheckpoint.s1.getClass) && + (splitAfterCheckpoint.s2.getClass != splitBeforeCheckpoint.s2.getClass), + "CartesianRDD.s1 and CartesianRDD.s2 not updated after parent RDD is checkpointed" ) } test("CoalescedRDD") { - testCheckpointing(_.coalesce(2)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(_.coalesce(2), true, false) + testRDD(_.coalesce(2)) + testRDDPartitions(_.coalesce(2)) - // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. - // Note that this test is very specific to the current implementation of CoalescedRDDPartitions + // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) + // after the parent RDD has been checkpointed and parent partitions have been changed. + // Note that this test is very specific to the current implementation of + // CoalescedRDDPartitions. 
val ones = sc.makeRDD(1 to 100, 10).map(x => x) ones.checkpoint() // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) @@ -160,33 +147,78 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val splitAfterCheckpoint = serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) assert( - splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, - "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" + splitAfterCheckpoint.parents.head.getClass != splitBeforeCheckpoint.parents.head.getClass, + "CoalescedRDDPartition.parents not updated after parent RDD is checkpointed" ) } test("CoGroupedRDD") { - val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() - testCheckpointing(rdd => { + val longLineageRDD1 = generateFatPairRDD() + testRDD(rdd => { CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }, false, true) + }) - val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() - testParentCheckpointing(rdd => { + val longLineageRDD2 = generateFatPairRDD() + testRDDPartitions(rdd => { CheckpointSuite.cogroup( longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }, false, true) + }) } test("ZippedRDD") { - testCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - - // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed - // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) + testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x))) + testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x))) + + // Test that the ZippedPartition updates parent partitions + // after the parent RDD has been checkpointed and parent partitions have been changed. 
+ // Note that this test is very specific to the current implementation of ZippedRDD. + val rdd = generateFatRDD() + val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x)) + zippedRDD.rdd1.checkpoint() + zippedRDD.rdd2.checkpoint() + val partitionBeforeCheckpoint = + serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]]) + zippedRDD.count() + val partitionAfterCheckpoint = + serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]]) + assert( + partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass && + partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass, + "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed" + ) + } + + test("PartitionerAwareUnionRDD") { + testRDD(rdd => { + new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( + generateFatPairRDD(), + rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + )) + }) + + testRDDPartitions(rdd => { + new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( + generateFatPairRDD(), + rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + )) + }) + + // Test that the PartitionerAwareUnionRDD updates parent partitions + // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent + // partitions have been changed. Note that this test is very specific to the current + // implementation of PartitionerAwareUnionRDD. 
+ val pairRDD = generateFatPairRDD() + pairRDD.checkpoint() + val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) + val partitionBeforeCheckpoint = serializeDeserialize( + unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) + pairRDD.count() + val partitionAfterCheckpoint = serializeDeserialize( + unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) + assert( + partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass, + "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed" + ) } test("CheckpointRDD with zero partitions") { @@ -200,29 +232,32 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } /** - * Test checkpointing of the final RDD generated by the given operation. By default, - * this method tests whether the size of serialized RDD has reduced after checkpointing or not. - * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or - * not, but this is not done by default as usually the partitions do not refer to any RDD and - * therefore never store the lineage. + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduced after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
*/ - def testCheckpointing[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean = true, - testRDDPartitionSize: Boolean = false - ) { + def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U]) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() + val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName val numPartitions = operatedRDD.partitions.length + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the checkpoint file has been created assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) @@ -230,6 +265,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != 
partitionsBeforeCheckpoint.toList) + // Test whether the partitions have been changed to the new Hadoop partitions assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) @@ -239,122 +277,72 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) - // Test whether serialized size of the RDD has reduced. If the RDD - // does not have any dependency to another RDD (e.g., ParallelCollection, - // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. - if (testRDDSize) { - logInfo("Size of " + rddType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) - // Test whether serialized size of the partitions has reduced. If the partitions - // do not have any non-transient reference to another RDD or another RDD's partitions, it - // does not refer to a lineage and therefore may not reduce in size after checkpointing. - // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions - // must be forgotten after checkpointing (to remove all reference to parent RDDs) and - // replaced with the HadooPartitions of the checkpointed RDD. 
- if (testRDDPartitionSize) { - logInfo("Size of " + rddType + " partitions " - + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing " + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } } /** * Test whether checkpointing of the parent of the generated RDD also * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * this RDD will remember the partitions and therefore potentially the whole lineage. + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). + * */ - def testParentCheckpointing[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean, - testRDDPartitionSize: Boolean - ) { + def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U]) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() + val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.head.rdd + val parentRDDs = operatedRDD.dependencies.map(_.rdd) val rddType = operatedRDD.getClass.getSimpleName - val parentRDDType = parentRDD.getClass.getSimpleName - // Get the partitions and dependencies of the parent in case they're lazily computed - parentRDD.dependencies - parentRDD.partitions + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = 
getSerializedSizes(operatedRDD) - parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one - val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) - // Test whether serialized size of the RDD has reduced because of its parent being - // checkpointed. If this RDD or its parent RDD do not have any dependency - // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may - // not reduce in size after checkpointing. - if (testRDDSize) { - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - // Test whether serialized size of the partitions has reduced because of its parent being - // checkpointed. If the partitions do not have any non-transient reference to another RDD - // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce - // in size after checkpointing. 
However, if the partitions do refer to the *partitions* of a parent - // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's - // partitions must have changed after checkpointing. - if (testRDDPartitionSize) { - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } - + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) } /** - * Generate an RDD with a long lineage of one-to-one dependencies. + * Generate an RDD such that both the RDD and its partitions have large size. */ - def generateLongLineageRDD(): RDD[Int] = { - var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 50) { - rdd = rdd.map(x => x + 1) - } - rdd + def generateFatRDD(): RDD[Int] = { + new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) } /** - * Generate an RDD with a long lineage specifically for CoGroupedRDD. - * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage - * and narrow dependency with this RDD. This method generate such an RDD by a sequence - * of cogroups and mapValues which creates a long lineage of narrow dependencies. + * Generate a pair RDD (with partitioner) such that both the RDD and its partitions + * have large size.
*/ - def generateLongLineageRDDForCoGroupedRDD() = { - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) - - def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) - - var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) - for(i <- 1 to 10) { - cogrouped = cogrouped.mapValues(add).cogroup(ones) - } - cogrouped.mapValues(add) + def generateFatPairRDD() = { + new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } /** @@ -362,8 +350,26 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, - Utils.serialize(rdd.partitions).length) + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) } /** @@ -375,8 +381,49 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } + + /** + * Recursively force the initialization of all the members of an RDD and its parents.
+ */ + def initializeRdd(rdd: RDD[_]) { + rdd.partitions // forces the + rdd.dependencies.map(_.rdd).foreach(initializeRdd(_)) + } } +/** RDD partition that has large serialized size. */ +class FatPartition(val partition: Partition) extends Partition { + val bigData = new Array[Byte](10000) + def index: Int = partition.index +} + +/** RDD that has large serialized size. */ +class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) { + val bigData = new Array[Byte](100000) + + protected def getPartitions: Array[Partition] = { + parent.partitions.map(p => new FatPartition(p)) + } + + def compute(split: Partition, context: TaskContext): Iterator[Int] = { + parent.compute(split.asInstanceOf[FatPartition].partition, context) + } +} + +/** Pair RDD that has large serialized size. */ +class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) { + val bigData = new Array[Byte](100000) + + protected def getPartitions: Array[Partition] = { + parent.partitions.map(p => new FatPartition(p)) + } + + @transient override val partitioner = Some(_partitioner) + + def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = { + parent.compute(split.asInstanceOf[FatPartition].partition, context).map(x => (x, x)) + } +} object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 6d1695eae7..fb89537258 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -30,13 +30,15 @@ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { - assert(System.getenv("SPARK_HOME") != null) + val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get // Regression test for SPARK-530: "Spark driver process doesn't exit 
after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => failAfter(60 seconds) { - Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(System.getenv("SPARK_HOME"))) + Utils.executeAndGetOutput( + Seq("./bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) } } } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index af448fcb37..befdc1589f 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -42,7 +42,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3, 3).map { x => FailureSuiteState.synchronized { FailureSuiteState.tasksRun += 1 @@ -62,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { // Run a map-reduce job in which a reduce task deterministically fails once. 
test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") + sc = new SparkContext("local[1,2]", "test") val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { case (k, v) => FailureSuiteState.synchronized { diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index c210dd5c3b..a2eb9a4e84 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -17,33 +17,49 @@ package org.apache.spark +import java.io._ +import java.util.jar.{JarEntry, JarOutputStream} + +import SparkContext._ import com.google.common.io.Files import org.scalatest.FunSuite -import java.io.{File, PrintWriter, FileReader, BufferedReader} -import SparkContext._ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ - @transient var testJarFile: File = _ - - override def beforeEach() { - super.beforeEach() - // Create a sample text file - val tmpdir = new File(Files.createTempDir(), "test") - tmpdir.mkdir() - tmpFile = new File(tmpdir, "FileServerSuite.txt") - val pw = new PrintWriter(tmpFile) + @transient var tmpJarUrl: String = _ + + override def beforeAll() { + super.beforeAll() + val tmpDir = new File(Files.createTempDir(), "test") + tmpDir.mkdir() + + val textFile = new File(tmpDir, "FileServerSuite.txt") + val pw = new PrintWriter(textFile) pw.println("100") pw.close() - } + + val jarFile = new File(tmpDir, "test.jar") + val jarStream = new FileOutputStream(jarFile) + val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) - override def afterEach() { - super.afterEach() - // Clean up downloaded file - if (tmpFile.exists) { - tmpFile.delete() + val jarEntry = new JarEntry(textFile.getName) + jar.putNextEntry(jarEntry) + + val in = new FileInputStream(textFile) + val buffer = new Array[Byte](10240) + var nRead = 0 + while (nRead <= 0) { + 
nRead = in.read(buffer, 0, buffer.length) + jar.write(buffer, 0, nRead) } + + in.close() + jar.close() + jarStream.close() + + tmpFile = textFile + tmpJarUrl = jarFile.toURI.toURL.toString } test("Distributing files locally") { @@ -77,18 +93,13 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test ("Dynamically adding JARS locally") { sc = new SparkContext("local[4]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) + sc.addJar(tmpJarUrl) + val testData = Array((1, 1)) + sc.parallelize(testData).foreach { x => + if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { + throw new SparkException("jar not added") + } + } } test("Distributing files on a standalone cluster") { @@ -107,33 +118,24 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test ("Dynamically adding JARS on a standalone cluster") { sc = new SparkContext("local-cluster[1,1,512]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, 
x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) + sc.addJar(tmpJarUrl) + val testData = Array((1,1)) + sc.parallelize(testData).foreach { x => + if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { + throw new SparkException("jar not added") + } + } } test ("Dynamically adding JARS on a standalone cluster using local: URL") { sc = new SparkContext("local-cluster[1,1,512]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile.replace("file", "local")) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) + sc.addJar(tmpJarUrl.replace("file", "local")) + val testData = Array((1,1)) + sc.parallelize(testData).foreach { x => + if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { + throw new SparkException("jar not added") + } + } } + } diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 79913dc718..23ec6c3b31 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -851,7 +851,7 @@ public class JavaAPISuite implements Serializable { public void checkpointAndComputation() { File tempDir = Files.createTempDir(); JavaRDD<Integer> 
rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + sc.setCheckpointDir(tempDir.getAbsolutePath()); Assert.assertEquals(false, rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint @@ -863,7 +863,7 @@ public class JavaAPISuite implements Serializable { public void checkpointAndRestore() { File tempDir = Files.createTempDir(); JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + sc.setCheckpointDir(tempDir.getAbsolutePath()); Assert.assertEquals(false, rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint @@ -930,4 +930,36 @@ public class JavaAPISuite implements Serializable { parts[1]); } + @Test + public void countApproxDistinct() { + List<Integer> arrayData = new ArrayList<Integer>(); + int size = 100; + for (int i = 0; i < 100000; i++) { + arrayData.add(i % size); + } + JavaRDD<Integer> simpleRdd = sc.parallelize(arrayData, 10); + Assert.assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.2) - size) / (size * 1.0)) < 0.2); + Assert.assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.05) - size) / (size * 1.0)) <= 0.05); + Assert.assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.01) - size) / (size * 1.0)) <= 0.01); + } + + @Test + public void countApproxDistinctByKey() { + double relativeSD = 0.001; + + List<Tuple2<Integer, Integer>> arrayData = new ArrayList<Tuple2<Integer, Integer>>(); + for (int i = 10; i < 100; i++) + for (int j = 0; j < i; j++) + arrayData.add(new Tuple2<Integer, Integer>(i, j)); + + JavaPairRDD<Integer, Integer> pairRdd = sc.parallelizePairs(arrayData); + List<Tuple2<Integer, Object>> res = pairRdd.countApproxDistinctByKey(relativeSD).collect(); + for (Tuple2<Integer, Object> resItem : res) { + double count = (double)resItem._1(); + Long resCount = (Long)resItem._2(); + Double error = Math.abs((resCount - count) / count); + 
Assert.assertTrue(error < relativeSD); + } + + } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 271dc905bc..10b8b441fd 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { - + private val conf = new SparkConf test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -48,14 +48,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master start and stop") { val actorSystem = ActorSystem("test") - val tracker = new MapOutputTrackerMaster() + val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.stop() } test("master register and fetch") { val actorSystem = ActorSystem("test") - val tracker = new MapOutputTrackerMaster() + val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) @@ -74,7 +74,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master register and unregister and fetch") { val actorSystem = ActorSystem("test") - val tracker = new MapOutputTrackerMaster() + val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) @@ -96,16 +96,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote 
fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) - val masterTracker = new MapOutputTrackerMaster() + val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = Left(actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) - val slaveTracker = new MapOutputTracker() + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) + val slaveTracker = new MapOutputTracker(conf) slaveTracker.trackerActor = Right(slaveSystem.actorSelection( "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker")) diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 288aa14eeb..c650ef4ed5 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -27,8 +27,10 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => def sc: SparkContext = _sc + var conf = new SparkConf(false) + override def beforeAll() { - _sc = new SparkContext("local", "test") + _sc = new SparkContext("local", "test", conf) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala new file mode 100644 index 0000000000..ef5936dd2f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -0,0 +1,110 @@ +package org.apache.spark + +import org.scalatest.FunSuite + +class SparkConfSuite extends FunSuite with 
LocalSparkContext { + // This test uses the spark.conf in core/src/test/resources, which has a few test properties + test("loading from spark.conf") { + val conf = new SparkConf() + assert(conf.get("spark.test.intTestProperty") === "1") + assert(conf.get("spark.test.stringTestProperty") === "hi") + // NOTE: we don't use list properties yet, but when we do, we'll have to deal with this syntax + assert(conf.get("spark.test.listTestProperty") === "[a, b]") + } + + // This test uses the spark.conf in core/src/test/resources, which has a few test properties + test("system properties override spark.conf") { + try { + System.setProperty("spark.test.intTestProperty", "2") + val conf = new SparkConf() + assert(conf.get("spark.test.intTestProperty") === "2") + assert(conf.get("spark.test.stringTestProperty") === "hi") + } finally { + System.clearProperty("spark.test.intTestProperty") + } + } + + test("initializing without loading defaults") { + try { + System.setProperty("spark.test.intTestProperty", "2") + val conf = new SparkConf(false) + assert(!conf.contains("spark.test.intTestProperty")) + assert(!conf.contains("spark.test.stringTestProperty")) + } finally { + System.clearProperty("spark.test.intTestProperty") + } + } + + test("named set methods") { + val conf = new SparkConf(false) + + conf.setMaster("local[3]") + conf.setAppName("My app") + conf.setSparkHome("/path") + conf.setJars(Seq("a.jar", "b.jar")) + conf.setExecutorEnv("VAR1", "value1") + conf.setExecutorEnv(Seq(("VAR2", "value2"), ("VAR3", "value3"))) + + assert(conf.get("spark.master") === "local[3]") + assert(conf.get("spark.app.name") === "My app") + assert(conf.get("spark.home") === "/path") + assert(conf.get("spark.jars") === "a.jar,b.jar") + assert(conf.get("spark.executorEnv.VAR1") === "value1") + assert(conf.get("spark.executorEnv.VAR2") === "value2") + assert(conf.get("spark.executorEnv.VAR3") === "value3") + + // Test the Java-friendly versions of these too + conf.setJars(Array("c.jar", "d.jar")) + 
conf.setExecutorEnv(Array(("VAR4", "value4"), ("VAR5", "value5"))) + assert(conf.get("spark.jars") === "c.jar,d.jar") + assert(conf.get("spark.executorEnv.VAR4") === "value4") + assert(conf.get("spark.executorEnv.VAR5") === "value5") + } + + test("basic get and set") { + val conf = new SparkConf(false) + assert(conf.getAll.toSet === Set()) + conf.set("k1", "v1") + conf.setAll(Seq(("k2", "v2"), ("k3", "v3"))) + assert(conf.getAll.toSet === Set(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))) + conf.set("k1", "v4") + conf.setAll(Seq(("k2", "v5"), ("k3", "v6"))) + assert(conf.getAll.toSet === Set(("k1", "v4"), ("k2", "v5"), ("k3", "v6"))) + assert(conf.contains("k1"), "conf did not contain k1") + assert(!conf.contains("k4"), "conf contained k4") + assert(conf.get("k1") === "v4") + intercept[Exception] { conf.get("k4") } + assert(conf.get("k4", "not found") === "not found") + assert(conf.getOption("k1") === Some("v4")) + assert(conf.getOption("k4") === None) + } + + test("creating SparkContext without master and app name") { + val conf = new SparkConf(false) + intercept[SparkException] { sc = new SparkContext(conf) } + } + + test("creating SparkContext without master") { + val conf = new SparkConf(false).setAppName("My app") + intercept[SparkException] { sc = new SparkContext(conf) } + } + + test("creating SparkContext without app name") { + val conf = new SparkConf(false).setMaster("local") + intercept[SparkException] { sc = new SparkContext(conf) } + } + + test("creating SparkContext with both master and app name") { + val conf = new SparkConf(false).setMaster("local").setAppName("My app") + sc = new SparkContext(conf) + assert(sc.master === "local") + assert(sc.appName === "My app") + } + + test("SparkContext property overriding") { + val conf = new SparkConf(false).setMaster("local").setAppName("My app") + sc = new SparkContext("local[2]", "My other app", conf) + assert(sc.master === "local[2]") + assert(sc.appName === "My other app") + } +} diff --git 
a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 151af0d213..f28d5c7b13 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,20 +19,21 @@ package org.apache.spark import org.scalatest.{FunSuite, PrivateMethodTester} -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskScheduler} +import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { - def createTaskScheduler(master: String): TaskScheduler = { + def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. 
sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) - SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + sched.asInstanceOf[TaskSchedulerImpl] } test("bad-master") { @@ -43,55 +44,49 @@ class SparkContextSchedulerCreationSuite } test("local") { - createTaskScheduler("local") match { - case s: LocalScheduler => - assert(s.threads === 1) - assert(s.maxFailures === 0) + val sched = createTaskScheduler("local") + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 1) case _ => fail() } } test("local-n") { - createTaskScheduler("local[5]") match { - case s: LocalScheduler => - assert(s.threads === 5) - assert(s.maxFailures === 0) + val sched = createTaskScheduler("local[5]") + assert(sched.maxTaskFailures === 1) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 5) case _ => fail() } } test("local-n-failures") { - createTaskScheduler("local[4, 2]") match { - case s: LocalScheduler => - assert(s.threads === 4) - assert(s.maxFailures === 2) + val sched = createTaskScheduler("local[4, 2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 4) case _ => fail() } } test("simr") { - createTaskScheduler("simr://uri") match { - case s: ClusterScheduler => - assert(s.backend.isInstanceOf[SimrSchedulerBackend]) + createTaskScheduler("simr://uri").backend match { + case s: SimrSchedulerBackend => // OK case _ => fail() } } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]") match { - case s: ClusterScheduler => - assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend]) + createTaskScheduler("local-cluster[3, 14, 512]").backend match { + case s: SparkDeploySchedulerBackend => // OK case _ => fail() } } def testYarn(master: String, expectedClassName: String) { try { - 
createTaskScheduler(master) match { - case s: ClusterScheduler => - assert(s.getClass === Class.forName(expectedClassName)) - case _ => fail() - } + val sched = createTaskScheduler(master) + assert(sched.getClass === Class.forName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) @@ -110,11 +105,8 @@ class SparkContextSchedulerCreationSuite def testMesos(master: String, expectedClass: Class[_]) { try { - createTaskScheduler(master) match { - case s: ClusterScheduler => - assert(s.backend.getClass === expectedClass) - case _ => fail() - } + val sched = createTaskScheduler(master) + assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => assert(e.getMessage.contains("no mesos in")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 4cb4ddc9cd..f58b1ee05a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -18,13 +18,15 @@ package org.apache.spark.deploy.worker import java.io.File + import org.scalatest.FunSuite + import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription} class ExecutorRunnerTest extends FunSuite { test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.env("SPARK_HOME") + val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.env.get("spark.home")).get val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()), sparkHome, "appUiUrl") val appId = "12345-worker321-9876" diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index ab81bfbe55..8d7546085f 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import org.scalatest.FunSuite +import org.apache.spark.SparkConf class CompressionCodecSuite extends FunSuite { + val conf = new SparkConf(false) def testCodec(codec: CompressionCodec) { // Write 1000 integers to the output stream, compressed. @@ -43,19 +45,19 @@ class CompressionCodecSuite extends FunSuite { } test("default compression codec") { - val codec = CompressionCodec.createCodec() + val codec = CompressionCodec.createCodec(conf) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } test("lzf compression codec") { - val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) + val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } test("snappy compression codec") { - val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) + val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 7181333adf..71a2c6c498 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -19,17 +19,19 @@ package org.apache.spark.metrics import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.deploy.master.MasterSource +import org.apache.spark.SparkConf class MetricsSystemSuite extends FunSuite with BeforeAndAfter { var filePath: String = _ + var conf: SparkConf = null before { filePath = 
getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() - System.setProperty("spark.metrics.conf", filePath) + conf = new SparkConf(false).set("spark.metrics.conf", filePath) } test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default") + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) val sources = metricsSystem.sources val sinks = metricsSystem.sinks @@ -39,7 +41,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter { } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test") + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) val sources = metricsSystem.sources val sinks = metricsSystem.sinks diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 57d3382ed0..5da538a1dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.rdd import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet +import scala.util.Random import org.scalatest.FunSuite @@ -109,6 +110,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { assert(deps.size === 2) // ShuffledRDD, ParallelCollection. } + test("countApproxDistinctByKey") { + def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + + /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is + * only a statistical bound, the tests can fail for large values of relativeSD. We will be using + * relatively tight error bounds to check correctness of functionality rather than checking + * whether the approximation conforms with the requested bound. 
+ */ + val relativeSD = 0.001 + + // For each value i, there are i tuples with first element equal to i. + // Therefore, the expected count for key i would be i. + val stacked = (1 to 100).flatMap(i => (1 to i).map(j => (i, j))) + val rdd1 = sc.parallelize(stacked) + val counted1 = rdd1.countApproxDistinctByKey(relativeSD).collect() + counted1.foreach{ + case(k, count) => assert(error(count, k) < relativeSD) + } + + val rnd = new Random() + + // The expected count for key num would be num + val randStacked = (1 to 100).flatMap { i => + val num = rnd.nextInt % 500 + (1 to num).map(j => (num, j)) + } + val rdd2 = sc.parallelize(randStacked) + val counted2 = rdd2.countApproxDistinctByKey(relativeSD, 4).collect() + counted2.foreach{ + case(k, count) => assert(error(count, k) < relativeSD) + } + } + test("join") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index d8dcd6d14c..559ea051d3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -63,6 +63,19 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("countApproxDistinct") { + + def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + + val size = 100 + val uniformDistro = for (i <- 1 to 100000) yield i % size + val simpleRdd = sc.makeRDD(uniformDistro) + assert(error(simpleRdd.countApproxDistinct(0.2), size) < 0.2) + assert(error(simpleRdd.countApproxDistinct(0.05), size) < 0.05) + assert(error(simpleRdd.countApproxDistinct(0.01), size) < 0.01) + assert(error(simpleRdd.countApproxDistinct(0.001), size) < 0.001) + } + test("SparkContext.union") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) @@ -71,6 +84,33 @@ class RDDSuite 
extends FunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("partitioner aware union") { + import SparkContext._ + def makeRDDWithPartitioner(seq: Seq[Int]) = { + sc.makeRDD(seq, 1) + .map(x => (x, null)) + .partitionBy(new HashPartitioner(2)) + .mapPartitions(_.map(_._1), true) + } + + val nums1 = makeRDDWithPartitioner(1 to 4) + val nums2 = makeRDDWithPartitioner(5 to 8) + assert(nums1.partitioner == nums2.partitioner) + assert(new PartitionerAwareUnionRDD(sc, Seq(nums1)).collect().toSet === Set(1, 2, 3, 4)) + + val union = new PartitionerAwareUnionRDD(sc, Seq(nums1, nums2)) + assert(union.collect().toSet === Set(1, 2, 3, 4, 5, 6, 7, 8)) + val nums1Parts = nums1.collectPartitions() + val nums2Parts = nums2.collectPartitions() + val unionParts = union.collectPartitions() + assert(nums1Parts.length === 2) + assert(nums2Parts.length === 2) + assert(unionParts.length === 2) + assert((nums1Parts(0) ++ nums2Parts(0)).toList === unionParts(0).toList) + assert((nums1Parts(1) ++ nums2Parts(1)).toList === unionParts(1).toList) + assert(union.partitioner === nums1.partitioner) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 95d3553d91..7bf2020fe3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -15,14 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster._ import scala.collection.mutable.ArrayBuffer import java.util.Properties @@ -31,9 +29,9 @@ class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, - clusterScheduler: ClusterScheduler, + clusterScheduler: TaskSchedulerImpl, taskSet: TaskSet) - extends ClusterTaskSetManager(clusterScheduler, taskSet) { + extends TaskSetManager(clusterScheduler, taskSet, 0) { parent = null weight = 1 @@ -106,7 +104,7 @@ class FakeTaskSetManager( class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = { new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } @@ -133,7 +131,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("FIFO Scheduler Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -160,7 +158,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Fair Scheduler Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task @@ -169,7 +167,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging val xmlPath = 
getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.scheduler.allocation.file", xmlPath) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val schedulableBuilder = new FairSchedulableBuilder(rootPool) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) schedulableBuilder.buildPools() assert(rootPool.getSchedulableByName("default") != null) @@ -217,7 +215,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging test("Nested Pool Test") { sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) + val clusterScheduler = new TaskSchedulerImpl(sc) var tasks = ArrayBuffer[Task[_]]() val task = new FakeTask(0) tasks += task diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 706d84a58b..2aa259daf3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,21 +17,14 @@ package org.apache.spark.scheduler -import scala.collection.mutable.{Map, HashMap} - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import org.apache.spark.LocalSparkContext -import org.apache.spark.MapOutputTrackerMaster -import org.apache.spark.SparkContext -import org.apache.spark.Partition -import org.apache.spark.TaskContext -import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} -import org.apache.spark.{FetchFailed, Success, TaskEndReason} +import scala.Tuple2 +import scala.collection.mutable.{HashMap, Map} + +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} +import org.scalatest.{BeforeAndAfter, FunSuite} /** * Tests for DAGScheduler. 
These tests directly call the event processing functions in DAGScheduler @@ -46,7 +39,7 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} * and capturing the resulting TaskSets from the mock TaskScheduler. */ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - + val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() val taskScheduler = new TaskScheduler() { @@ -74,7 +67,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null) { + val blockManagerMaster = new BlockManagerMaster(null, conf) { override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). @@ -99,7 +92,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSets.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster() + mapOutputTracker = new MapOutputTrackerMaster(conf) scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) { override def runLocally(job: ActiveJob) { // don't bother with the thread while unit testing diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0f01515179..0b90c4e74c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -15,10 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import org.apache.spark.TaskContext -import org.apache.spark.scheduler.{TaskLocation, Task} class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { override def runTask(context: TaskContext): Int = 0 diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 002368ff55..5cc48ee00a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -95,7 +95,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) - val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) joblogger.getJobIDtoPrintWriter.size should be (1) @@ -117,7 +117,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 - override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = onStageCompletedCount += 1 override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 } sc.addSparkListener(joblogger) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 2e41438a52..1a16e438c4 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -19,23 +19,26 @@ package org.apache.spark.scheduler import scala.collection.mutable.{Buffer, HashSet} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfterAll { + with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 + before { + sc = new SparkContext("local", "SparkListenerSuite") + } + override def afterAll { System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { - sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() called when result fetched remotely") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using 
the LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("onTaskGettingResult() not called when result sent directly") { - // Need to use local cluster mode here, because results are not ever returned through the - // block manager when using the LocalScheduler. - sc = new SparkContext("local-cluster[1,1,512]", "test") - val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -181,7 +174,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = Buffer[StageInfo]() - override def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: SparkListenerStageCompleted) { stageInfos += stage.stage } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 27c2d53361..4b52d9651e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -15,14 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import java.nio.ByteBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} import org.apache.spark.storage.TaskResultBlockId /** @@ -31,12 +30,12 @@ import org.apache.spark.storage.TaskResultBlockId * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) +class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false override def enqueueSuccessfulTask( - taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { if (!removedResult) { // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. @@ -44,13 +43,13 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSched case IndirectTaskResult(blockId) => sparkEnv.blockManager.master.removeBlock(blockId) case directResult: DirectTaskResult[_] => - taskSetManager.abort("Internal error: expect only indirect results") + taskSetManager.abort("Internal error: expect only indirect results") } serializedData.rewind() removedResult = true } super.enqueueSuccessfulTask(taskSetManager, tid, serializedData) - } + } } /** @@ -65,22 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA System.setProperty("spark.akka.frameSize", "1") } - before { - // Use local-cluster mode because results are returned differently when running with the - // LocalScheduler. 
- sc = new SparkContext("local-cluster[1,1,512]", "test") - } - override def afterAll { System.clearProperty("spark.akka.frameSize") } test("handling results smaller than Akka frame size") { + sc = new SparkContext("local", "test") val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } - test("handling results larger than Akka frame size") { + test("handling results larger than Akka frame size") { + sc = new SparkContext("local", "test") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) @@ -92,10 +87,13 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } test("task retried if result missing from block manager") { + // Set the maximum number of task failures to > 0, so that the task set isn't aborted + // after the result is missing. + sc = new SparkContext("local[1,2]", "test") // If this test hangs, it's probably because no resource offers were made after the task // failed. - val scheduler: ClusterScheduler = sc.taskScheduler match { - case clusterScheduler: ClusterScheduler => + val scheduler: TaskSchedulerImpl = sc.taskScheduler match { + case clusterScheduler: TaskSchedulerImpl => clusterScheduler case _ => assert(false, "Expect local cluster to use ClusterScheduler") diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index bb28a31a99..1eec6726f4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.scheduler.cluster +package org.apache.spark.scheduler import scala.collection.mutable.ArrayBuffer import scala.collection.mutable @@ -23,7 +23,6 @@ import scala.collection.mutable import org.scalatest.FunSuite import org.apache.spark._ -import org.apache.spark.scheduler._ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer import org.apache.spark.util.{Utils, FakeClock} @@ -56,10 +55,10 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler * A mock ClusterScheduler implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost - * to work, and these are required for locality in ClusterTaskSetManager. + * to work, and these are required for locality in TaskSetManager. */ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) + extends TaskSchedulerImpl(sc) { val startedTasks = new ArrayBuffer[Long] val endedTasks = new mutable.HashMap[Long, TaskEndReason] @@ -79,16 +78,19 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) } -class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + private val conf = new SparkConf + + val LOCALITY_WAIT = conf.get("spark.locality.wait", "3000").toLong + val MAX_TASK_FAILURES = 4 test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeClusterScheduler(sc, ("exec1", 
"host1")) val taskSet = createTaskSet(1) - val manager = new ClusterTaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer a host with no CPUs assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) @@ -114,7 +116,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo sc = new SparkContext("local", "test") val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(3) - val manager = new ClusterTaskSetManager(sched, taskSet) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // First three offers should all find tasks for (i <- 0 until 3) { @@ -151,7 +153,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq() // Last task has no locality prefs ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -197,7 +199,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq(TaskLocation("host2")) ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -234,7 +236,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo Seq(TaskLocation("host3")) ) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -262,7 +264,7 @@ class 
ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) @@ -279,17 +281,17 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) val taskSet = createTaskSet(1) val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. - (1 to manager.MAX_TASK_FAILURES).foreach { index => + (1 to manager.maxTaskFailures).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) assert(offerResult.get.index === 0) manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) - if (index < manager.MAX_TASK_FAILURES) { + if (index < MAX_TASK_FAILURES) { assert(!sched.taskSetsFailed.contains(taskSet.id)) } else { assert(sched.taskSetsFailed.contains(taskSet.id)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala deleted file mode 100644 index 1e676c1719..0000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.local - -import java.util.concurrent.Semaphore -import java.util.concurrent.CountDownLatch - -import scala.collection.mutable.HashMap - -import org.scalatest.{BeforeAndAfterEach, FunSuite} - -import org.apache.spark._ - - -class Lock() { - var finished = false - def jobWait() = { - synchronized { - while(!finished) { - this.wait() - } - } - } - - def jobFinished() = { - synchronized { - finished = true - this.notifyAll() - } - } -} - -object TaskThreadInfo { - val threadToLock = HashMap[Int, Lock]() - val threadToRunning = HashMap[Int, Boolean]() - val threadToStarted = HashMap[Int, CountDownLatch]() -} - -/* - * 1. each thread contains one job. - * 2. each job contains one stage. - * 3. each stage only contains one task. - * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure - * it will get cpu core resource, and will wait to finished after user manually - * release "Lock" and then cluster will contain another free cpu cores. - * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, - * thus it will be scheduled later when cluster has free cpu cores. 
- */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach { - - override def afterEach() { - super.afterEach() - System.clearProperty("spark.scheduler.mode") - } - - def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { - - TaskThreadInfo.threadToRunning(threadIndex) = false - val nums = sc.parallelize(threadIndex to threadIndex, 1) - TaskThreadInfo.threadToLock(threadIndex) = new Lock() - TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) - new Thread { - if (poolName != null) { - sc.setLocalProperty("spark.scheduler.pool", poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToStarted(number).countDown() - TaskThreadInfo.threadToLock(number).jobWait() - TaskThreadInfo.threadToRunning(number) = false - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - } - }.start() - } - - test("Local FIFO scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[4]", "test") - val sem = new Semaphore(0) - - createThread(1,null,sc,sem) - TaskThreadInfo.threadToStarted(1).await() - createThread(2,null,sc,sem) - TaskThreadInfo.threadToStarted(2).await() - createThread(3,null,sc,sem) - TaskThreadInfo.threadToStarted(3).await() - createThread(4,null,sc,sem) - TaskThreadInfo.threadToStarted(4).await() - // thread 5 and 6 (stage pending)must meet following two points - // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager - // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() - // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 - // So I just use "sleep" 1s here for each thread. - // TODO: any better solution? 
- createThread(5,null,sc,sem) - Thread.sleep(1000) - createThread(6,null,sc,sem) - Thread.sleep(1000) - - assert(TaskThreadInfo.threadToRunning(1) === true) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === false) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(1).jobFinished() - TaskThreadInfo.threadToStarted(5).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(3).jobFinished() - TaskThreadInfo.threadToStarted(6).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === false) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === true) - - TaskThreadInfo.threadToLock(2).jobFinished() - TaskThreadInfo.threadToLock(4).jobFinished() - TaskThreadInfo.threadToLock(5).jobFinished() - TaskThreadInfo.threadToLock(6).jobFinished() - sem.acquire(6) - } - - test("Local fair scheduler end-to-end test") { - System.setProperty("spark.scheduler.mode", "FAIR") - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) - - createThread(10,"1",sc,sem) - TaskThreadInfo.threadToStarted(10).await() - createThread(20,"2",sc,sem) - TaskThreadInfo.threadToStarted(20).await() - createThread(30,"3",sc,sem) - 
TaskThreadInfo.threadToStarted(30).await() - - assert(TaskThreadInfo.threadToRunning(10) === true) - assert(TaskThreadInfo.threadToRunning(20) === true) - assert(TaskThreadInfo.threadToRunning(30) === true) - - createThread(11,"1",sc,sem) - TaskThreadInfo.threadToStarted(11).await() - createThread(21,"2",sc,sem) - TaskThreadInfo.threadToStarted(21).await() - createThread(31,"3",sc,sem) - TaskThreadInfo.threadToStarted(31).await() - - assert(TaskThreadInfo.threadToRunning(11) === true) - assert(TaskThreadInfo.threadToRunning(21) === true) - assert(TaskThreadInfo.threadToRunning(31) === true) - - createThread(12,"1",sc,sem) - TaskThreadInfo.threadToStarted(12).await() - createThread(22,"2",sc,sem) - TaskThreadInfo.threadToStarted(22).await() - createThread(32,"3",sc,sem) - - assert(TaskThreadInfo.threadToRunning(12) === true) - assert(TaskThreadInfo.threadToRunning(22) === true) - assert(TaskThreadInfo.threadToRunning(32) === false) - - TaskThreadInfo.threadToLock(10).jobFinished() - TaskThreadInfo.threadToStarted(32).await() - - assert(TaskThreadInfo.threadToRunning(32) === true) - - //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager - // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. - //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
- createThread(23,"2",sc,sem) - createThread(33,"3",sc,sem) - Thread.sleep(1000) - - TaskThreadInfo.threadToLock(11).jobFinished() - TaskThreadInfo.threadToStarted(23).await() - - assert(TaskThreadInfo.threadToRunning(23) === true) - assert(TaskThreadInfo.threadToRunning(33) === false) - - TaskThreadInfo.threadToLock(12).jobFinished() - TaskThreadInfo.threadToStarted(33).await() - - assert(TaskThreadInfo.threadToRunning(33) === true) - - TaskThreadInfo.threadToLock(20).jobFinished() - TaskThreadInfo.threadToLock(21).jobFinished() - TaskThreadInfo.threadToLock(22).jobFinished() - TaskThreadInfo.threadToLock(23).jobFinished() - TaskThreadInfo.threadToLock(30).jobFinished() - TaskThreadInfo.threadToLock(31).jobFinished() - TaskThreadInfo.threadToLock(32).jobFinished() - TaskThreadInfo.threadToLock(33).jobFinished() - - sem.acquire(11) - } -} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index c016c51171..3898583275 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -22,12 +22,15 @@ import scala.collection.mutable import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ class KryoSerializerSuite extends FunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + test("basic types") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -57,7 +60,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("pairs") { - val 
ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -81,7 +84,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("Scala data structures") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -104,7 +107,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("ranges") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time @@ -125,9 +128,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("custom registrator") { - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -172,6 +173,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) } + test("kryo with SerializableHyperLogLog") { + assert(sc.parallelize( Array(1, 2, 3, 2, 3, 3, 2, 3, 1) ).countApproxDistinct(0.01) === 3) + } + test("kryo with reduce") { val control = 1 :: 2 :: Nil val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) @@ -186,18 +191,6 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } - - override def beforeAll() { - System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - 
System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - super.beforeAll() - } - - override def afterAll() { - super.afterAll() - System.clearProperty("spark.kryo.registrator") - System.clearProperty("spark.serializer") - } } object KryoTest { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 5b4d63b954..a0fc3445be 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,8 +31,10 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.{SparkConf, SparkContext} class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { + private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null var actorSystem: ActorSystem = null @@ -42,30 +44,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldHeartBeat: String = null // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val serializer = new KryoSerializer + conf.set("spark.kryoserializer.buffer.mb", "1") + val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) this.actorSystem = actorSystem - System.setProperty("spark.driver.port", boundPort.toString) - System.setProperty("spark.hostPort", "localhost:" + boundPort) + conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( - Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf)))), conf) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") + System.setProperty("os.arch", "amd64") + conf.set("os.arch", "amd64") + conf.set("spark.test.useCompressedOops", "true") + conf.set("spark.storage.disableBlockManagerHeartBeat", "true") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() // Set some value ... 
- System.setProperty("spark.hostPort", Utils.localHostName() + ":" + 1111) + conf.set("spark.hostPort", Utils.localHostName() + ":" + 1111) } after { @@ -86,13 +89,13 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master = null if (oldArch != null) { - System.setProperty("os.arch", oldArch) + conf.set("os.arch", oldArch) } else { System.clearProperty("os.arch") } if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) + conf.set("spark.test.useCompressedOops", oldOops) } else { System.clearProperty("spark.test.useCompressedOops") } @@ -133,7 +136,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -163,8 +166,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -179,7 +182,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ 
-227,7 +230,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -261,7 +264,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -277,7 +280,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -296,7 +299,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -333,7 +336,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -352,7 +355,7 @@ 
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -371,7 +374,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -390,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -413,7 +416,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -426,7 +429,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, 
serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -441,7 +444,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -456,7 +459,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -471,7 +474,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -486,7 +489,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -511,7 +514,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new 
BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -535,7 +538,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -581,7 +584,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 500) + store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -591,53 +594,53 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { - System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + conf.set("spark.shuffle.compress", "true") + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") store.stop() store = null - System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000) + 
conf.set("spark.shuffle.compress", "false") + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") store.stop() store = null - System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000) + conf.set("spark.broadcast.compress", "true") + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") store.stop() store = null - System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000) + conf.set("spark.broadcast.compress", "false") + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null - System.setProperty("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000) + conf.set("spark.rdd.compress", "true") + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null - System.setProperty("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000) + conf.set("spark.rdd.compress", "false") + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf) store.putSingle(rdd(0, 0), 
new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -651,7 +654,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer, 1200) + store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 070982e798..af4b31d53c 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -17,15 +17,18 @@ package org.apache.spark.storage -import java.io.{FileWriter, File} +import java.io.{File, FileWriter} import scala.collection.mutable import com.google.common.io.Files +import org.apache.spark.SparkConf import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import scala.util.Try +import akka.actor.{Props, ActorSelection, ActorSystem} -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { - +class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { + private val testConf = new SparkConf(false) val rootDir0 = 
Files.createTempDir() rootDir0.deleteOnExit() val rootDir1 = Files.createTempDir() @@ -35,21 +38,16 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // This suite focuses primarily on consolidation features, // so we coerce consolidation if not already enabled. - val consolidateProp = "spark.shuffle.consolidateFiles" - val oldConsolidate = Option(System.getProperty(consolidateProp)) - System.setProperty(consolidateProp, "true") + testConf.set("spark.shuffle.consolidateFiles", "true") val shuffleBlockManager = new ShuffleBlockManager(null) { + override def conf = testConf.clone var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) } var diskBlockManager: DiskBlockManager = _ - override def afterAll() { - oldConsolidate.map(c => System.setProperty(consolidateProp, c)) - } - override def beforeEach() { diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) shuffleBlockManager.idToSegmentMap.clear() diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala new file mode 100644 index 0000000000..67a57a0e7f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.scalatest.FunSuite +import org.apache.spark.scheduler._ +import org.apache.spark.{LocalSparkContext, SparkContext, Success} +import org.apache.spark.scheduler.SparkListenerTaskStart +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} + +class JobProgressListenerSuite extends FunSuite with LocalSparkContext { + test("test executor id to summary") { + val sc = new SparkContext("local", "test") + val listener = new JobProgressListener(sc) + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + + // nothing in it + assert(listener.stageIdToExecutorSummaries.size == 0) + + // finish this task, should get updated shuffleRead + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 1000) + + // finish a task with unknown executor-id, nothing should happen + taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.size == 1) + + // finish 
this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) + .shuffleRead == 2000) + + // finish this task, should get updated duration + shuffleReadMetrics.remoteBytesRead = 1000 + taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL) + taskInfo.finishTime = 1 + listener.onTaskEnd(new SparkListenerTaskEnd( + new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics)) + assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) + .shuffleRead == 1000) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 5aff26f9fc..11ebdc352b 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import org.scalatest.FunSuite import org.scalatest.BeforeAndAfterAll import org.scalatest.PrivateMethodTester +import org.apache.spark.SparkContext class DummyClass1 {} @@ -139,7 +140,6 @@ class SizeEstimatorSuite test("64-bit arch with no compressed oops") { val arch = System.setProperty("os.arch", "amd64") val oops = System.setProperty("spark.test.useCompressedOops", "false") - val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() |