-rw-r--r--  .gitignore | 1
-rw-r--r--  bin/compute-classpath.cmd | 54
-rwxr-xr-x  bin/compute-classpath.sh | 91
-rw-r--r--  conf/fairscheduler.xml.template | 15
-rwxr-xr-x  conf/spark-env.sh.template | 18
-rw-r--r--  core/src/main/java/spark/network/netty/FileClient.java | 28
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientHandler.java | 8
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 1
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 15
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 63
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 29
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 13
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 106
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 107
-rw-r--r--  core/src/main/scala/spark/api/python/PythonWorkerFactory.scala | 113
-rw-r--r--  core/src/main/scala/spark/deploy/master/Master.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 30
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala | 39
-rw-r--r--  core/src/main/scala/spark/executor/TaskMetrics.scala | 12
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 65
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 18
-rw-r--r--  core/src/main/scala/spark/scheduler/JobLogger.scala | 306
-rw-r--r--  core/src/main/scala/spark/scheduler/SparkListener.scala | 50
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala | 11
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 56
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala | 53
-rw-r--r--  core/src/main/scala/spark/storage/ShuffleBlockManager.scala | 2
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala | 4
-rw-r--r--  core/src/test/resources/fairscheduler.xml | 1
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala | 10
-rw-r--r--  core/src/test/scala/spark/PairRDDFunctionsSuite.scala | 287
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala | 19
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala | 16
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala | 120
-rw-r--r--  core/src/test/scala/spark/SharedSparkContext.scala | 25
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala | 299
-rw-r--r--  core/src/test/scala/spark/SizeEstimatorSuite.scala | 72
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala | 23
-rw-r--r--  core/src/test/scala/spark/UnpersistSuite.scala | 30
-rw-r--r--  core/src/test/scala/spark/UtilsSuite.scala | 53
-rw-r--r--  core/src/test/scala/spark/ZippedPartitionsSuite.scala | 3
-rw-r--r--  core/src/test/scala/spark/scheduler/JobLoggerSuite.scala | 104
-rw-r--r--  core/src/test/scala/spark/scheduler/SparkListenerSuite.scala | 2
-rw-r--r--  docs/configuration.md | 42
-rw-r--r--  docs/ec2-scripts.md | 5
-rw-r--r--  docs/python-programming-guide.md | 12
-rw-r--r--  docs/scala-programming-guide.md | 10
-rw-r--r--  docs/tuning.md | 6
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala | 2
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala | 50
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala | 9
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala | 9
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala | 9
-rw-r--r--  mllib/data/als/test.data | 16
-rwxr-xr-x  mllib/data/lr-data/random.data | 1000
-rw-r--r--  mllib/data/ridge-data/lpsa.data | 67
-rw-r--r--  mllib/src/main/scala/spark/mllib/clustering/KMeans.scala | 317
-rw-r--r--  mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala | 27
-rw-r--r--  mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala | 88
-rw-r--r--  mllib/src/main/scala/spark/mllib/optimization/Gradient.scala | 33
-rw-r--r--  mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala | 62
-rw-r--r--  mllib/src/main/scala/spark/mllib/optimization/Updater.scala | 27
-rw-r--r--  mllib/src/main/scala/spark/mllib/recommendation/ALS.scala | 389
-rw-r--r--  mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala | 23
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala | 158
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala | 41
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/Regression.scala | 21
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala | 183
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/RidgeRegressionGenerator.scala | 55
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/MLUtils.scala | 95
-rw-r--r--  mllib/src/test/resources/log4j.properties | 11
-rw-r--r--  mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala | 153
-rw-r--r--  mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala | 80
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala | 57
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala | 47
-rw-r--r--  pom.xml | 13
-rw-r--r--  project/SparkBuild.scala | 31
-rw-r--r--  project/plugins.sbt | 2
-rw-r--r--  python/pyspark/daemon.py | 164
-rw-r--r--  python/pyspark/serializers.py | 4
-rw-r--r--  python/pyspark/tests.py | 43
-rw-r--r--  python/pyspark/worker.py | 55
-rw-r--r--  repl/src/main/scala/spark/repl/SparkILoop.scala | 11
-rw-r--r--  repl/src/test/scala/spark/repl/ReplSuite.scala | 38
-rwxr-xr-x  run | 94
-rw-r--r--  run2.cmd | 46
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala | 9
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala | 60
-rw-r--r--  streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 123
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 94
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala | 20
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala | 30
-rw-r--r--  streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala | 7
-rw-r--r--  streaming/src/test/java/spark/streaming/JavaAPISuite.java | 14
-rw-r--r--  streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 4
-rw-r--r--  streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 11
97 files changed, 5314 insertions, 1041 deletions
diff --git a/.gitignore b/.gitignore
index b87fc1ee79..ae39c52b11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,4 +36,5 @@ streaming-tests.log
dependency-reduced-pom.xml
.ensime
.ensime_lucene
+checkpoint
derby.log
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
new file mode 100644
index 0000000000..44826f339c
--- /dev/null
+++ b/bin/compute-classpath.cmd
@@ -0,0 +1,54 @@
+@echo off
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+set SCALA_VERSION=2.9.3
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+set CORE_DIR=%FWDIR%core
+set REPL_DIR=%FWDIR%repl
+set EXAMPLES_DIR=%FWDIR%examples
+set BAGEL_DIR=%FWDIR%bagel
+set MLLIB_DIR=%FWDIR%mllib
+set STREAMING_DIR=%FWDIR%streaming
+set PYSPARK_DIR=%FWDIR%python
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
+set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
+set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
+set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
+set CLASSPATH=%CLASSPATH%;%MLLIB_DIR%\target\scala-%SCALA_VERSION%\classes
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem Add Scala standard library
+set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
new file mode 100755
index 0000000000..75c58d1181
--- /dev/null
+++ b/bin/compute-classpath.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+
+# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+# script and the ExecutorRunner in standalone cluster mode.
+
+SCALA_VERSION=2.9.3
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+# Load environment variables from conf/spark-env.sh, if it exists
+if [ -e $FWDIR/conf/spark-env.sh ] ; then
+ . $FWDIR/conf/spark-env.sh
+fi
+
+CORE_DIR="$FWDIR/core"
+REPL_DIR="$FWDIR/repl"
+REPL_BIN_DIR="$FWDIR/repl-bin"
+EXAMPLES_DIR="$FWDIR/examples"
+BAGEL_DIR="$FWDIR/bagel"
+MLLIB_DIR="$FWDIR/mllib"
+STREAMING_DIR="$FWDIR/streaming"
+PYSPARK_DIR="$FWDIR/python"
+
+# Build up classpath
+CLASSPATH="$SPARK_CLASSPATH"
+CLASSPATH="$CLASSPATH:$FWDIR/conf"
+CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
+if [ -n "$SPARK_TESTING" ] ; then
+ CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
+fi
+CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
+CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
+if [ -e "$FWDIR/lib_managed" ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
+ CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
+fi
+CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
+# Add the shaded JAR for Maven builds
+if [ -e $REPL_BIN_DIR/target ]; then
+ for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
+ CLASSPATH="$CLASSPATH:$jar"
+ done
+ # The shaded JAR doesn't contain examples, so include those separately
+ EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
+ CLASSPATH+=":$EXAMPLES_JAR"
+fi
+CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$MLLIB_DIR/target/scala-$SCALA_VERSION/classes"
+for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
+ CLASSPATH="$CLASSPATH:$jar"
+done
+
+# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
+# to avoid the -sources and -doc packages that are built by publish-local.
+if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then
+ # Use the JAR from the SBT build
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
+fi
+if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
+ # Use the JAR from the Maven build
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
+fi
+
+# Add hadoop conf dir - else FileSystem.*, etc fail !
+# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+# the configuration files.
+if [ "x" != "x$HADOOP_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
+fi
+if [ "x" != "x$YARN_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
+fi
+
+# Add Scala standard library
+if [ -z "$SCALA_LIBRARY_PATH" ]; then
+ if [ -z "$SCALA_HOME" ]; then
+ echo "SCALA_HOME is not set" >&2
+ exit 1
+ fi
+ SCALA_LIBRARY_PATH="$SCALA_HOME/lib"
+fi
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
+
+echo "$CLASSPATH"
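Since the script's only stdout is the finished classpath, any caller can capture it directly. A minimal Scala sketch of that pattern, mirroring the Utils.executeAndGetOutput call that ExecutorRunner makes later in this patch (the Spark home path is an illustrative placeholder):

    import scala.sys.process._

    object ComputeClasspathDemo {
      def main(args: Array[String]) {
        // "!!" runs the command, returns its stdout, and throws on non-zero exit.
        val classPath = Seq("/opt/spark/bin/compute-classpath.sh").!!.trim
        println("Computed classpath: " + classPath)
      }
    }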
diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template
new file mode 100644
index 0000000000..04a6b418dc
--- /dev/null
+++ b/conf/fairscheduler.xml.template
@@ -0,0 +1,15 @@
+<?xml version="1.0"?>
+<allocations>
+<pool name="production">
+ <minShare>2</minShare>
+ <weight>1</weight>
+ <schedulingMode>FAIR</schedulingMode>
+</pool>
+<pool name="test">
+ <minShare>3</minShare>
+ <weight>2</weight>
+ <schedulingMode>FIFO</schedulingMode>
+</pool>
+<pool name="data">
+</pool>
+</allocations>
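The template gives the new fair scheduler three pools with different minimum shares, weights, and per-pool scheduling modes. A hedged sketch of routing a job into one of them; setLocalProperty and the "spark.scheduler.pool" key are the names this API settled on in later releases, so treat both as assumptions at this commit:

    import spark.SparkContext

    object FairPoolDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "FairPoolDemo")
        // Assumption: pool selection via a thread-local property, as in later releases.
        sc.setLocalProperty("spark.scheduler.pool", "production")
        println(sc.parallelize(1 to 1000).count())  // runs in the "production" pool
        sc.stop()
      }
    }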
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 37565ca827..b8936314ec 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -3,8 +3,10 @@
# This file contains environment variables required to run Spark. Copy it as
# spark-env.sh and edit that to configure Spark for your site. At a minimum,
# the following two variables should be set:
-# - MESOS_NATIVE_LIBRARY, to point to your Mesos native library (libmesos.so)
-# - SCALA_HOME, to point to your Scala installation
+# - SCALA_HOME, to point to your Scala installation, or SCALA_LIBRARY_PATH to
+# point to the directory for Scala library JARs (if you install Scala as a
+# Debian or RPM package, these are in a separate path, often /usr/share/java)
+# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos
#
# If using the standalone deploy mode, you can also set variables for it:
# - SPARK_MASTER_IP, to bind the master to a different IP address
@@ -12,14 +14,6 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
-# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine
-#
-# Finally, Spark also relies on the following variables, but these can be set
-# on just the *master* (i.e. in your driver program), and will automatically
-# be propagated to workers:
-# - SPARK_MEM, to change the amount of memory used per node (this should
-# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g)
-# - SPARK_CLASSPATH, to add elements to Spark's classpath
-# - SPARK_JAVA_OPTS, to add JVM options
-# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries.
+# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes
+# to be spawned on every slave machine
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
index 3a62dacbc8..a4bb4bc701 100644
--- a/core/src/main/java/spark/network/netty/FileClient.java
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -8,15 +8,20 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
class FileClient {
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
+ private int connectTimeout = 60*1000; // 1 min
- public FileClient(FileClientHandler handler) {
+ public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
+ this.connectTimeout = connectTimeout;
}
public void init() {
@@ -25,25 +30,10 @@ class FileClient {
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
- public static final class ChannelCloseListener implements ChannelFutureListener {
- private FileClient fc = null;
-
- public ChannelCloseListener(FileClient fc){
- this.fc = fc;
- }
-
- @Override
- public void operationComplete(ChannelFuture future) {
- if (fc.bootstrap!=null){
- fc.bootstrap.shutdown();
- fc.bootstrap = null;
- }
- }
- }
-
public void connect(String host, int port) {
try {
// Start the connection attempt.
@@ -58,8 +48,8 @@ class FileClient {
public void waitForClose() {
try {
channel.closeFuture().sync();
- } catch (InterruptedException e){
- e.printStackTrace();
+ } catch (InterruptedException e) {
+ LOG.warn("FileClient interrupted", e);
}
}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
index 2069dee5ca..9fc9449827 100644
--- a/core/src/main/java/spark/network/netty/FileClientHandler.java
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
+ private volatile boolean handlerCalled = false;
+
+ public boolean isComplete() {
+ return handlerCalled;
+ }
+
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+ public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
@@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
// get file
if(in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
+ handlerCalled = true;
currentHeader = null;
ctx.close();
}
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index e1fb02157a..3239f4c385 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
+ shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index c9d698fc09..8b313c645f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -1,5 +1,6 @@
package spark
+import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@@ -65,8 +66,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator =
- new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
@@ -98,7 +98,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
- combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ // When deserializing, use a lazy val to create just one instance of the serializer per task
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
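The rewritten foldByKey no longer shares one zeroValue instance across keys; it ships the zero as serialized bytes and deserializes a fresh clone per key, which matters when the zero is mutable. A minimal usage sketch, assuming the foldByKey overload without an explicit Partitioner that the same file provides:

    import spark.SparkContext
    import spark.SparkContext._  // implicit conversion to PairRDDFunctions

    object FoldByKeyDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "FoldByKeyDemo")
        val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 5)))
        // Each key folds against its own deserialized clone of the zero value,
        // so a mutable zero (e.g. a collection) cannot leak state across keys.
        pairs.foldByKey(0)(_ + _).collect().foreach(println)  // (a,3) and (b,5)
        sc.stop()
      }
    }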
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 223dcdc19d..106fb2960f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -105,6 +105,9 @@ abstract class RDD[T: ClassManifest](
// Methods and fields available on all RDDs
// =======================================================================
+ /** The SparkContext that created this RDD. */
+ def sparkContext: SparkContext = sc
+
/** A unique ID for this RDD (within its SparkContext). */
val id: Int = sc.newRddId()
@@ -117,6 +120,14 @@ abstract class RDD[T: ClassManifest](
this
}
+ /** User-defined generator of this RDD */
+ var generator = Utils.getCallSiteInfo.firstUserClass
+
+ /** Reset generator */
+ def setGenerator(_generator: String) = {
+ generator = _generator
+ }
+
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
@@ -273,31 +284,35 @@ abstract class RDD[T: ClassManifest](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
- val multiplier = 3.0
- val initialCount = count()
+ var multiplier = 3.0
+ var initialCount = this.count()
var maxSelected = 0
+ if (num < 0) {
+ throw new IllegalArgumentException("Negative number of elements requested")
+ }
+
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
- if (num > initialCount) {
+ if (num > initialCount && !withReplacement) {
total = maxSelected
- fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
- } else if (num < 0) {
- throw(new IllegalArgumentException("Negative number of elements requested"))
+ fraction = multiplier * (maxSelected + 1) / initialCount
} else {
- fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
+ fraction = multiplier * (num + 1) / initialCount
total = num
}
val rand = new Random(seed)
- var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
+ var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+ // If the first sample didn't turn out large enough, keep trying to take samples;
+ // this shouldn't happen often because we use a big multiplier for the initial size
while (samples.length < total) {
- samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
+ samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
}
Utils.randomizeInPlace(samples, rand).take(total)
@@ -355,7 +370,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD created by piping elements to a forked external process.
*/
- def pipe(command: String, env: Map[String, String]): RDD[String] =
+ def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
@@ -366,24 +381,24 @@ abstract class RDD[T: ClassManifest](
* @param command command to run in forked process.
* @param env environment variables to set.
* @param printPipeContext Before piping elements, this function is called as an opportunity
- * to pipe context data. Print line function (like out.println) will be
+ * to pipe context data. Print line function (like out.println) will be
* passed as printPipeContext's parameter.
- * @param printPipeContext Use this function to customize how to pipe elements. This function
- * will be called with each RDD element as the 1st parameter, and the
- * print line function (like out.println()) as the 2nd parameter.
- * An example of pipe the RDD data of groupBy() in a streaming way,
- * instead of constructing a huge String to concat all the elements:
- * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
- * for (e <- record._2){f(e)}
+ * @param printRDDElement Use this function to customize how to pipe elements. This function
+ * will be called with each RDD element as the 1st parameter, and the
+ * print line function (like out.println()) as the 2nd parameter.
+ * An example of pipe the RDD data of groupBy() in a streaming way,
+ * instead of constructing a huge String to concat all the elements:
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
* @return the result RDD
*/
def pipe(
- command: Seq[String],
- env: Map[String, String] = Map(),
+ command: Seq[String],
+ env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
- new PipedRDD(this, command, env,
- if (printPipeContext ne null) sc.clean(printPipeContext) else null,
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ new PipedRDD(this, command, env,
+ if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
/**
@@ -840,7 +855,7 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
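With the takeSample fixes above, a request larger than the RDD is only clamped when sampling without replacement, and negative requests fail fast. A short usage sketch:

    import spark.SparkContext

    object TakeSampleDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "TakeSampleDemo")
        val rdd = sc.parallelize(1 to 100)
        // Without replacement the result is capped at the RDD's size.
        val distinct = rdd.takeSample(false, 5, 42)
        // With replacement duplicates are allowed, so num may exceed count().
        val withRep = rdd.takeSample(true, 200, 42)
        println(distinct.mkString(", "))
        println("with-replacement sample size: " + withRep.length)
        sc.stop()
      }
    }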
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 6c37203707..228e831dff 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -50,7 +50,6 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
@@ -117,13 +116,14 @@ class SparkContext(
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
- for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
}
}
+ // Since memory can be set with a system property too, use that
+ executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
if (environment != null) {
executorEnvs ++= environment
}
@@ -158,14 +158,12 @@ class SparkContext(
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
- val sparkMemEnv = System.getenv("SPARK_MEM")
- val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
- if (sparkMemEnvInt > memoryPerSlaveInt) {
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
throw new SparkException(
- "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
- memoryPerSlaveInt, sparkMemEnvInt))
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
val scheduler = new ClusterScheduler(this)
@@ -631,7 +629,7 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
@@ -714,7 +712,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
@@ -883,6 +881,15 @@ object SparkContext {
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
+
+ /** Get the amount of memory per executor requested through system properties or SPARK_MEM */
+ private[spark] val executorMemoryRequested = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
}
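Per-executor memory now resolves in order: the spark.executor.memory system property, then the SPARK_MEM environment variable, then a 512 MB default, and the result is exported to executors as SPARK_MEM. A minimal sketch:

    import spark.SparkContext

    object ExecutorMemoryDemo {
      def main(args: Array[String]) {
        // Must be set before the SparkContext is created; "1g" resolves to
        // 1024 MB via Utils.memoryStringToMb and takes precedence over SPARK_MEM.
        System.setProperty("spark.executor.memory", "1g")
        val sc = new SparkContext("local", "ExecutorMemoryDemo")
        // Executors launched by this context receive SPARK_MEM=1024m.
        sc.stop()
      }
    }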
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index be1a04d619..ec59b4f48f 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,5 +1,8 @@
package spark
+import collection.mutable
+import serializer.Serializer
+
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
@@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
+import spark.api.python.PythonWorkerFactory
/**
@@ -37,7 +41,10 @@ class SparkEnv (
// If executorId is NOT found, return defaultHostPort
var executorIdToHostPort: Option[(String, String) => String]) {
+ private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
def stop() {
+ pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@@ -50,6 +57,12 @@ class SparkEnv (
actorSystem.awaitTermination()
}
+ def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+ }
+ }
def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
val env = SparkEnv.get
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 645c18541e..f41efa9d29 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -522,13 +522,45 @@ private object Utils extends Logging {
execute(command, new File("."))
}
+ /**
+ * Execute a command and get its output, throwing an exception if it yields a code other than 0.
+ */
+ def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = {
+ val process = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ .start()
+ new Thread("read stderr for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val output = new StringBuffer
+ val stdoutThread = new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ output.append(line)
+ }
+ }
+ }
+ stdoutThread.start()
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ output.toString
+ }
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
- def getSparkCallSite: String = {
+ def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@@ -540,6 +572,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
+ var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@@ -554,13 +587,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
finished = true
}
}
}
- "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
@@ -602,4 +641,67 @@ private object Utils extends Logging {
}
return false
}
+
+ def isSpace(c: Char): Boolean = {
+ " \t\r\n".indexOf(c) != -1
+ }
+
+ /**
+ * Split a string of potentially quoted arguments from the command line the way that a shell
+ * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
+ * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
+ */
+ def splitCommandString(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var inWord = false
+ var inSingleQuote = false
+ var inDoubleQuote = false
+ var curWord = new StringBuilder
+ def endWord() {
+ buf += curWord.toString
+ curWord.clear()
+ }
+ var i = 0
+ while (i < s.length) {
+ var nextChar = s.charAt(i)
+ if (inDoubleQuote) {
+ if (nextChar == '"') {
+ inDoubleQuote = false
+ } else if (nextChar == '\\') {
+ if (i < s.length - 1) {
+ // Append the next character directly, because only " and \ may be escaped in
+ // double quotes after the shell's own expansion
+ curWord.append(s.charAt(i + 1))
+ i += 1
+ }
+ } else {
+ curWord.append(nextChar)
+ }
+ } else if (inSingleQuote) {
+ if (nextChar == '\'') {
+ inSingleQuote = false
+ } else {
+ curWord.append(nextChar)
+ }
+ // Backslashes are not treated specially in single quotes
+ } else if (nextChar == '"') {
+ inWord = true
+ inDoubleQuote = true
+ } else if (nextChar == '\'') {
+ inWord = true
+ inSingleQuote = true
+ } else if (!isSpace(nextChar)) {
+ curWord.append(nextChar)
+ inWord = true
+ } else if (inWord && isSpace(nextChar)) {
+ endWord()
+ inWord = false
+ }
+ i += 1
+ }
+ if (inWord || inDoubleQuote || inSingleQuote) {
+ endWord()
+ }
+ return buf
+ }
}
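Because Utils is private to the spark package, a quick check of splitCommandString has to live inside that package (the UtilsSuite added by this patch does exactly that). A sketch reproducing the doc comment's example:

    package spark

    object SplitCommandDemo {
      def main(args: Array[String]) {
        // The documented case: 'a "b c" d' parses into three arguments.
        println(Utils.splitCommandString("""a "b c" d"""))  // ArrayBuffer(a, b c, d)
        // Quoted values survive as single tokens, which SPARK_JAVA_OPTS parsing needs.
        println(Utils.splitCommandString("""-Dkey="some value" -Xmx512m"""))
        // prints: ArrayBuffer(-Dkey=some value, -Xmx512m)
      }
    }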
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 807119ca8c..31d8ea89d4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -2,10 +2,9 @@ package spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
-import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
@@ -16,16 +15,18 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
- envVars: java.util.Map[String, String],
+ envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,68 +37,57 @@ private[spark] class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
- val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
- // Add the environmental variables to the process.
- val currentEnvVars = pb.environment()
- for ((variable, value) <- envVars) {
- currentEnvVars.put(variable, value)
- }
-
- val proc = pb.start()
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val startTime = System.currentTimeMillis
val env = SparkEnv.get
-
- // Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + pythonExec) {
- override def run() {
- for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
- System.err.println(line)
- }
- }
- }.start()
+ val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
- val out = new PrintWriter(proc.getOutputStream)
- val dOut = new DataOutputStream(proc.getOutputStream)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ val printOut = new PrintWriter(stream)
// Partition index
- dOut.writeInt(split.index)
+ dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
- dOut.writeInt(broadcastVars.length)
+ dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
- dOut.writeLong(broadcast.id)
- dOut.writeInt(broadcast.value.length)
- dOut.write(broadcast.value)
- dOut.flush()
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
}
+ dataOut.flush()
// Serialized user code
for (elem <- command) {
- out.println(elem)
+ printOut.println(elem)
}
- out.flush()
+ printOut.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dOut)
+ PythonRDD.writeAsPickle(elem, dataOut)
}
- dOut.flush()
- out.flush()
- proc.getOutputStream.close()
+ dataOut.flush()
+ printOut.flush()
+ worker.shutdownOutput()
}
}.start()
// Return an iterator that reads lines from the process's stdout
- val stream = new DataInputStream(proc.getInputStream)
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
- _nextObj = read()
+ if (hasNext) {
+ // FIXME: can deadlock if worker is waiting for us to
+ // respond to current message (currently irrelevant because
+ // output is shutdown before we read any input)
+ _nextObj = read()
+ }
obj
}
@@ -108,6 +98,17 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
+ case -3 =>
+ // Timing data from worker
+ val bootTime = stream.readLong()
+ val initTime = stream.readLong()
+ val finishTime = stream.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+ read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@@ -115,23 +116,21 @@ private[spark] class PythonRDD[T: ClassManifest](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates; let's do that, breaking when we
+ // get a negative length record.
+ var len2 = stream.readInt()
+ while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
+ len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
- val exitStatus = proc.waitFor()
- if (exitStatus != 0) {
- throw new Exception("Subprocess exited with status " + exitStatus)
- }
- new Array[Byte](0)
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e => throw e
}
@@ -159,7 +158,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
- case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -215,7 +214,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
- throw new Exception("Unexpected RDD type")
+ throw new SparkException("Unexpected RDD type")
}
}
@@ -279,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
@@ -292,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
- val out = new DataOutputStream(socket.getOutputStream)
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
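The worker protocol above is length-prefixed: every record carries a 4-byte length, and negative lengths are control codes (-1 ends the data section and is followed by accumulator updates, -2 carries a serialized Python exception, -3 carries three longs of timing data). An illustrative reader for that framing; the class is hypothetical, only the wire format comes from the hunk above:

    import java.io.DataInputStream

    class FramedReader(stream: DataInputStream) {
      def nextRecord(): Option[Array[Byte]] = stream.readInt() match {
        case length if length >= 0 =>
          val obj = new Array[Byte](length)
          stream.readFully(obj)
          Some(obj)
        case -1 => None  // end of data; negative-terminated accumulator updates follow
        case -2 => throw new RuntimeException("worker signalled a Python exception")
        case -3 =>       // timing record: boot, init and finish timestamps
          stream.readLong(); stream.readLong(); stream.readLong()
          nextRecord()
        case other => throw new RuntimeException("unknown control code: " + other)
      }
    }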
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
new file mode 100644
index 0000000000..85d1dfeac8
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -0,0 +1,113 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+ extends Logging {
+ var daemon: Process = null
+ val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+ var daemonPort: Int = 0
+
+ def create(): Socket = {
+ synchronized {
+ // Start the daemon if it hasn't been started
+ startDaemon()
+
+ // Attempt to connect, restart and retry once if it fails
+ try {
+ new Socket(daemonHost, daemonPort)
+ } catch {
+ case exc: SocketException => {
+ logWarning("Python daemon unexpectedly quit, attempting to restart")
+ stopDaemon()
+ startDaemon()
+ new Socket(daemonHost, daemonPort)
+ }
+ case e => throw e
+ }
+ }
+ }
+
+ def stop() {
+ stopDaemon()
+ }
+
+ private def startDaemon() {
+ synchronized {
+ // Is it already running?
+ if (daemon != null) {
+ return
+ }
+
+ try {
+ // Create and start the daemon
+ val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+ val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+ val workerEnv = pb.environment()
+ workerEnv.putAll(envVars)
+ daemon = pb.start()
+
+ // Redirect the stderr to ours
+ new Thread("stderr reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ val in = daemon.getErrorStream
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ val in = new DataInputStream(daemon.getInputStream)
+ daemonPort = in.readInt()
+
+ // Redirect further stdout output to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+ } catch {
+ case e => {
+ stopDaemon()
+ throw e
+ }
+ }
+
+ // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+ // detect our disappearance.
+ }
+ }
+
+ private def stopDaemon() {
+ synchronized {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
+
+ daemon = null
+ daemonPort = 0
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index 3e965e82ac..87f304b6cd 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -278,7 +278,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
- app.driver ! ApplicationRemoved(state.toString)
+ if (state != ApplicationState.FINISHED) {
+ app.driver ! ApplicationRemoved(state.toString)
+ }
schedule()
}
}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 04a774658e..d7f58b2cb1 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -1,6 +1,7 @@
package spark.deploy.worker
import java.io._
+import java.lang.System.getenv
import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
@@ -40,7 +41,7 @@ private[spark] class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
- shutdownHook = new Thread() {
+ shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
@@ -77,9 +78,29 @@ private[spark] class ExecutorRunner(
def buildCommandSeq(): Seq[String] = {
val command = appDesc.command
- val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
- val runScript = new File(sparkHome, script).getCanonicalPath
- Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
+ val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java")
+ // SPARK-698: do not call the run.cmd script, as process.destroy()
+ // fails to kill a process tree on Windows
+ Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
+ command.arguments.map(substituteVariables)
+ }
+
+ /**
+ * Attention: this must always be aligned with the environment variables in the run scripts and
+ * the way the JAVA_OPTS are assembled there.
+ */
+ def buildJavaOpts(): Seq[String] = {
+ val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH"))
+ .map(p => List("-Djava.library.path=" + p))
+ .getOrElse(Nil)
+ val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
+ val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
+
+ // Figure out our classpath with the external compute-classpath script
+ val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
+ val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext))
+
+ Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
@@ -115,7 +136,6 @@ private[spark] class ExecutorRunner(
for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 890938d48b..2bf55ea9a9 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -42,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
@@ -88,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -104,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
+ m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
@@ -152,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
- loader = new URLClassLoader(urls, loader)
+ new ExecutorURLClassLoader(urls, loader)
+ }
- // If the REPL is in use, add another ClassLoader that will read
- // new classes defined by the REPL as the user types code
+ /**
+ * If the REPL is in use, add another ClassLoader that will read
+ * new classes defined by the REPL as the user types code
+ */
+ private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = System.getProperty("spark.repl.class.uri")
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
- loader = {
- try {
- val klass = Class.forName("spark.repl.ExecutorClassLoader")
- .asInstanceOf[Class[_ <: ClassLoader]]
- val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- constructor.newInstance(classUri, loader)
- } catch {
- case _: ClassNotFoundException => loader
- }
+ try {
+ val klass = Class.forName("spark.repl.ExecutorClassLoader")
+ .asInstanceOf[Class[_ <: ClassLoader]]
+ val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
+ return constructor.newInstance(classUri, parent)
+ } catch {
+ case _: ClassNotFoundException =>
+ logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
+ System.exit(1)
+ null
}
+ } else {
+ return parent
}
-
- return new ExecutorURLClassLoader(Array(), loader)
}
/**
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index a7c56c2371..1dc13754f9 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -2,6 +2,11 @@ package spark.executor
class TaskMetrics extends Serializable {
/**
+ * Hostname the task runs on
+ */
+ var hostname: String = _
+
+ /**
* Time taken on the executor to deserialize this task
*/
var executorDeserializeTime: Int = _
@@ -34,9 +39,14 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
+ * Time when shuffle finishes
+ */
+ var shuffleFinishTime: Long = _
+
+ /**
* Total number of blocks fetched in a shuffle (remote or local)
*/
- var totalBlocksFetched : Int = _
+ var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
index a91f5a886d..8d5194a737 100644
--- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -9,19 +9,36 @@ import io.netty.util.CharsetUtil
import spark.Logging
import spark.network.ConnectionManagerId
+import scala.collection.JavaConverters._
+
private[spark] class ShuffleCopier extends Logging {
- def getBlock(cmId: ConnectionManagerId, blockId: String,
+ def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val fc = new FileClient(handler)
- fc.init()
- fc.connect(cmId.host, cmId.port)
- fc.sendRequest(blockId)
- fc.waitForClose()
- fc.close()
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
@@ -44,20 +61,18 @@ private[spark] object ShuffleCopier extends Logging {
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- }
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
- logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
}
- def runGetBlock(host:String, port:Int, file:String){
- val handler = new ShuffleClientHandler(echoResultCollectCallBack)
- val fc = new FileClient(handler)
- fc.init();
- fc.connect(host, port)
- fc.sendRequest(file)
- fc.waitForClose();
- fc.close()
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
}
def main(args: Array[String]) {
@@ -71,14 +86,16 @@ private[spark] object ShuffleCopier extends Logging {
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
- for (i <- Range(0, threads)) {
- val runnable = new Runnable() {
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
def run() {
- runGetBlock(host, port, file)
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
}
- }
- copiers.execute(runnable)
- }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
copiers.shutdown
+ System.exit(0)
}
}
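The copier's connect timeout is now tunable through a system property read on every getBlock call (milliseconds, default 60000). A one-line configuration sketch:

    object NettyTimeoutDemo {
      def main(args: Array[String]) {
        // Two minutes, in milliseconds; must be set in the JVM doing the fetch.
        System.setProperty("spark.shuffle.netty.connect.timeout", "120000")
      }
    }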
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index b2c07891ab..c0baf43d43 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -62,7 +62,7 @@ class PipedRDD[T: ClassManifest](
val out = new PrintWriter(proc.getOutputStream)
// print the pipe context first
- if ( printPipeContext != null) {
+ if (printPipeContext != null) {
printPipeContext(out.println(_))
}
for (elem <- firstParent[T].iterator(split, context)) {
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 7feeb97542..cbd375e5c1 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -298,6 +298,7 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
+ sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
idToActiveJob(runId) = job
activeJobs += job
resultStageToJob(finalStage) = job
@@ -311,6 +312,8 @@ class DAGScheduler(
handleExecutorLost(execId)
case completion: CompletionEvent =>
+ sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
+ completion.reason, completion.taskInfo, completion.taskMetrics)))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -321,6 +324,7 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
}
return true
}
@@ -468,6 +472,7 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
+ sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
@@ -519,9 +524,11 @@ class DAGScheduler(
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
+ idToActiveJob -= stage.priority
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -645,7 +652,7 @@ class DAGScheduler(
"(generation " + currentGeneration + ")")
}
}
-
+
private def handleExecutorGained(execId: String, hostPort: String) {
// remove from failedGeneration(execId) ?
if (failedGeneration.contains(execId)) {
@@ -662,7 +669,10 @@ class DAGScheduler(
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- job.listener.jobFailed(new SparkException("Job failed: " + reason))
+ val error = new SparkException("Job failed: " + reason)
+ job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
+ idToActiveJob -= resultStage.priority
activeJobs -= job
resultStageToJob -= resultStage
}
@@ -739,6 +749,10 @@ class DAGScheduler(
sizeBefore = pendingTasks.size
pendingTasks.clearOldValues(cleanupTime)
logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
+
+ sizeBefore = stageToInfos.size
+ stageToInfos.clearOldValues(cleanupTime)
+ logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
}
def stop() {
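The hunks above make the DAGScheduler publish events at job start/end, stage submission, and task completion. A minimal consumer is sketched below; it is a hypothetical listener, written under the assumption that it can be registered with the scheduler's sparkListeners collection (the registration hook may differ in this revision).

package spark.scheduler   // the event case classes live in this package

import java.util.concurrent.atomic.AtomicInteger

// Hypothetical listener that counts the events wired up above.
class EventCountListener extends SparkListener {
  val jobsStarted = new AtomicInteger(0)
  val jobsSucceeded = new AtomicInteger(0)
  val tasksEnded = new AtomicInteger(0)

  override def onJobStart(jobStart: SparkListenerJobStart) {
    jobsStarted.incrementAndGet()
  }
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    tasksEnded.incrementAndGet()
  }
  override def onJobEnd(jobEnd: SparkListenerJobEnd) {
    jobEnd.jobResult match {
      case JobSucceeded => jobsSucceeded.incrementAndGet()
      case JobFailed(e) => println("Job failed: " + e.getMessage)
    }
  }
}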
diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala
new file mode 100644
index 0000000000..178bfaba3d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobLogger.scala
@@ -0,0 +1,306 @@
+package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+// Used to record runtime information for each job, including the RDD graph,
+// tasks' start/stop and shuffle information, and information from outside the scheduler
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark"
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ def getLogDir = logDir
+ def getJobIDtoPrintWriter = jobIDToPrintWriter
+ def getStageIDToJobID = stageIDToJobID
+ def getJobIDToStages = jobIDToStages
+ def getEventQueue = eventQueue
+
+ new Thread("JobLogger") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ logDebug("Got event of type " + event.getClass.getName)
+ event match {
+ case SparkListenerJobStart(job, properties) =>
+ processJobStartEvent(job, properties)
+ case SparkListenerStageSubmitted(stage, taskSize) =>
+ processStageSubmittedEvent(stage, taskSize)
+ case StageCompleted(stageInfo) =>
+ processStageCompletedEvent(stageInfo)
+ case SparkListenerJobEnd(job, result) =>
+ processJobEndEvent(job, result)
+ case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+ processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ // Create a folder for log files; the folder's name is the creation time of the JobLogger
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (!dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ // Create a log file for one job; the file name is the job ID
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ // Close the log file and clean up the stage-to-job mapping in stageIDToJobID
+ protected def closeLogWriter(jobID: Int) =
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+
+ // Write log information to the log file; the withTime parameter controls whether to record
+ // a timestamp with the information
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.priority == jobID) {
+ jobIDToStages.getOrElseUpdate(jobID, new ListBuffer[Stage]) += stage
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ val rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ case _ => rddList ++= getRddsInStage(dep.rdd)
+ }
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { parent =>
+ depStageDesc += "(" + parent.id + "," + parent.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ // Generate an indentation string of the given width
+ protected def indentString(indent: Int) = " " * indent
+
+ protected def getRddName(rdd: RDD[_]) =
+ Option(rdd.name).getOrElse(rdd.getClass.getName)
+
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+ }
+
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+ var stageInfo: String = ""
+ if (stage.isShuffleMap) {
+ stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
+ stage.shuffleDep.get.shuffleId
+ }else{
+ stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.priority == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+ } else
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
+ }
+
+ // Record task metrics into job log files
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics =
+ taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics =
+ taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ eventQueue.put(stageSubmitted)
+ }
+
+ protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
+ stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
+ }
+
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ eventQueue.put(stageCompleted)
+ }
+
+ protected def processStageCompletedEvent(stageInfo: StageInfo) {
+ stageLogInfo(stageInfo.stage.id, "STAGE_ID=" + stageInfo.stage.id + " STATUS=COMPLETED")
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ eventQueue.put(taskEnd)
+ }
+
+ protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ eventQueue.put(jobEnd)
+ }
+
+ protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
+ var info = "JOB_ID=" + job.runId
+ reason match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.runId)
+ }
+
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if (properties != null) {
+ val annotation = properties.getProperty("spark.job.annotation", "")
+ jobLogInfo(jobID, annotation, false)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ eventQueue.put(jobStart)
+ }
+
+ protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
+ createLogWriter(job.runId)
+ recordJobProperties(job.runId, properties)
+ buildJobDep(job.runId, job.finalStage)
+ recordStageDep(job.runId)
+ recordStageDepGraph(job.runId, job.finalStage)
+ jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
+ }
+}
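A usage sketch for the new logger follows. The registration call is an assumption (the hook-up point is not part of this file); log files land under SPARK_LOG_DIR (default /tmp/spark), one directory per logger instance and one file per job ID.

import spark.SparkContext
import spark.SparkContext._
import spark.scheduler.JobLogger

// Hypothetical wiring for JobLogger.
object JobLoggerExample {
  def main(args: Array[String]) {
    // Optional free-form annotation echoed into each job's log file.
    System.setProperty("spark.job.annotation", "wordcount-run-1")
    val sc = new SparkContext("local", "JobLoggerExample")
    val logger = new JobLogger()      // directory named by creation time
    // sc.addSparkListener(logger)    // assumed registration hook
    sc.parallelize(1 to 1000).map(x => (x % 10, 1)).reduceByKey(_ + _).collect()
    // A job's file then contains lines such as:
    //   JOB_ID=0 STATUS=STARTED
    //   STAGE_ID=1 STATUS=SUBMITTED TASK_SIZE=...
    //   <timestamp>: TASK_TYPE=RESULT_TASK STATUS=SUCCESS TID=... STAGE_ID=0 ...
    //   JOB_ID=0 STATUS=SUCCESS
    sc.stop()
  }
}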
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index a65140b145..bac984b5c9 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -1,27 +1,59 @@
package spark.scheduler
+import java.util.Properties
import spark.scheduler.cluster.TaskInfo
import spark.util.Distribution
-import spark.{Utils, Logging}
+import spark.{Logging, SparkContext, TaskEndReason, Utils}
import spark.executor.TaskMetrics
-trait SparkListener {
- /**
- * called when a stage is completed, with information on the completed stage
- */
- def onStageCompleted(stageCompleted: StageCompleted)
-}
-
sealed trait SparkListenerEvents
+case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
+
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) extends SparkListenerEvents
+
+case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+ extends SparkListenerEvents
+
+case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
+ extends SparkListenerEvents
+
+trait SparkListener {
+ /**
+ * Called when a stage is completed, with information on the completed stage
+ */
+ def onStageCompleted(stageCompleted: StageCompleted) { }
+
+ /**
+ * Called when a stage is submitted
+ */
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+
+ /**
+ * Called when a task ends
+ */
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+
+ /**
+ * Called when a job starts
+ */
+ def onJobStart(jobStart: SparkListenerJobStart) { }
+
+ /**
+ * Called when a job ends
+ */
+ def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+
+}
/**
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
- def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: StageCompleted) {
import spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index 9ac875de3a..8844057a5c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import spark.Utils
+import spark.{SparkContext, Utils}
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -14,14 +14,7 @@ private[spark] trait SchedulerBackend {
def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
- protected val executorMemory = {
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- Option(System.getProperty("spark.executor.memory"))
- .orElse(Option(System.getenv("SPARK_MEM")))
- .map(Utils.memoryStringToMb)
- .getOrElse(512)
- }
-
+ protected val executorMemory: Int = SparkContext.executorMemoryRequested
// TODO: Probably want to add a killTask too
}
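The inlined block that used to live here resolved executor memory from a property, an environment variable, or a default. Below is a standalone sketch of that resolution order, mirroring what SparkContext.executorMemoryRequested is expected to centralize; memoryStringToMb is re-sketched here only for self-containment.

// Standalone sketch of the executor-memory resolution order.
object ExecutorMemory {
  // Assumption: parses strings like "512m" or "2g"; plain numbers are MB.
  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
    else if (lower.endsWith("m")) lower.dropRight(1).toInt
    else lower.toInt
  }

  // Property wins over SPARK_MEM; 512 MB is the fallback.
  def requested: Int =
    Option(System.getProperty("spark.executor.memory"))
      .orElse(Option(System.getenv("SPARK_MEM")))
      .map(memoryStringToMb)
      .getOrElse(512)
}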
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 1d69d658f7..bec876213e 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -67,11 +67,20 @@ object BlockFetcherIterator {
throw new IllegalArgumentException("BlocksByAddress is null")
}
- protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + _totalBlocks + " blocks")
+ // Total number of blocks fetched (local + remote). Also the number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
protected var startTime = System.currentTimeMillis
- protected val localBlockIds = new ArrayBuffer[String]()
- protected val remoteBlockIds = new HashSet[String]()
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -124,13 +133,15 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
- val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
} else {
- remoteBlockIds ++= blockInfos.map(_._1)
+ numRemote += blockInfos.size
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
@@ -144,10 +155,10 @@ object BlockFetcherIterator {
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
curRequestSize += size
- } else if (size == 0) {
- _totalBlocks -= 1
- } else {
+ } else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
@@ -163,8 +174,8 @@ object BlockFetcherIterator {
}
}
}
- logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
- originalTotalBlocks + " blocks")
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
remoteRequests
}
@@ -172,7 +183,7 @@ object BlockFetcherIterator {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
- for (id <- localBlockIds) {
+ for (id <- localBlocksToFetch) {
getLocalFromDisk(id, serializer) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
@@ -198,7 +209,7 @@ object BlockFetcherIterator {
sendRequest(fetchRequests.dequeue())
}
- val numGets = remoteBlockIds.size - fetchRequests.size
+ val numGets = remoteRequests.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
@@ -210,7 +221,7 @@ object BlockFetcherIterator {
//an iterator that will read fetched blocks off the queue as they arrive.
@volatile protected var resultsGotten = 0
- override def hasNext: Boolean = resultsGotten < _totalBlocks
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
@@ -227,9 +238,9 @@ object BlockFetcherIterator {
}
// Implementing BlockFetchTracker trait.
- override def totalBlocks: Int = _totalBlocks
- override def numLocalBlocks: Int = localBlockIds.size
- override def numRemoteBlocks: Int = remoteBlockIds.size
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
override def remoteFetchTime: Long = _remoteFetchTime
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead
@@ -265,7 +276,7 @@ object BlockFetcherIterator {
}).toList
}
- //keep this to interrupt the threads when necessary
+ // keep this to interrupt the threads when necessary
private def stopCopiers() {
for (copier <- copiers) {
copier.interrupt()
@@ -291,7 +302,7 @@ object BlockFetcherIterator {
private var copiers: List[_ <: Thread] = null
override def initialize() {
- // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+ // Split Local Remote Blocks and set numBlocksToFetch
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
for (request <- Utils.randomize(remoteRequests)) {
@@ -311,10 +322,7 @@ object BlockFetcherIterator {
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
- // if all the results has been retrieved, shutdown the copiers
- if (resultsGotten == _totalBlocks && copiers != null) {
- stopCopiers()
- }
+ // If all the results have been retrieved, the copiers will exit automatically
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}
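The new bookkeeping separates blocks that exist (numLocal/numRemote) from blocks that must actually be fetched: zero-sized blocks are counted but never requested, and negative sizes are rejected. A self-contained sketch of that split, with a hypothetical Address type standing in for BlockManagerId:

import scala.collection.mutable.{ArrayBuffer, HashSet}

case class Address(host: String, port: Int)   // stand-in for BlockManagerId

class BlockSplit(self: Address, blocksByAddress: Seq[(Address, Seq[(String, Long)])]) {
  var numLocal = 0
  var numRemote = 0
  var numBlocksToFetch = 0                     // FetchResults expected
  val localBlocksToFetch = new ArrayBuffer[String]
  val remoteBlocksToFetch = new HashSet[String]

  for ((address, blockInfos) <- blocksByAddress) {
    if (address == self) {
      numLocal += blockInfos.size
      // Count zero-sized blocks, but never fetch them
      localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += blockInfos.count(_._2 != 0)
    } else {
      numRemote += blockInfos.size
      for ((blockId, size) <- blockInfos) {
        if (size > 0) {
          remoteBlocksToFetch += blockId
          numBlocksToFetch += 1
        } else if (size < 0) {
          throw new IllegalArgumentException("Negative block size " + size)
        }
      }
    }
  }
  def totalBlocks: Int = numLocal + numRemote
}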
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 9914beec99..da859eebcb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
+ private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
this
}
override def close() {
- objOut.close()
- bs.close()
- channel = null
- bs = null
- objOut = null
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
// Invoke the close callback handler.
super.close()
}
@@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
}
override def revertPartialWrites() {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
}
override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
objOut.writeObject(value)
}
@@ -196,7 +210,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
- throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task.
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
}
file
}
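The writer now opens lazily on first write, and close/commit/revert become no-ops until something has been written. Distilled below, independent of the BlockManager plumbing, as a sketch of the pattern (file length stands in for channel.position()); it pairs with the ShuffleBlockManager change that follows, which stops calling open() when tabulating a map task's bucket writers.

import java.io.{File, FileOutputStream, OutputStream}

// Sketch of the lazy-open writer pattern used above.
class LazyFileWriter(f: File) {
  private var out: OutputStream = null
  private var initialized = false
  private var lastValidPosition = 0L

  private def open() {
    out = new FileOutputStream(f, true)
    initialized = true
  }

  def write(bytes: Array[Byte]) {
    if (!initialized) open()   // open lazily, on first write only
    out.write(bytes)
  }

  def commit(): Long = {
    if (initialized) {
      out.flush()
      val prevPos = lastValidPosition
      lastValidPosition = f.length()   // stand-in for channel.position()
      lastValidPosition - prevPos
    } else {
      lastValidPosition                // nothing written: zero bytes committed
    }
  }

  def close() {
    if (initialized) {
      out.close()
      out = null
    }
  }
}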
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
index 49eabfb0d2..44638e0c2d 100644
--- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
}
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e93cc3b485..0cff0d5b01 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -1,6 +1,6 @@
package spark.util
-import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem}
+import akka.actor.{ActorRef, Props, ExtendedActorSystem, ActorSystem}
import com.typesafe.config.ConfigFactory
import akka.util.duration._
import akka.pattern.ask
@@ -56,7 +56,7 @@ private[spark] object AkkaUtils {
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
- val provider = actorSystem.asInstanceOf[ActorSystemImpl].provider
+ val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get
return (actorSystem, boundPort)
}
diff --git a/core/src/test/resources/fairscheduler.xml b/core/src/test/resources/fairscheduler.xml
index 5a688b0ebb..6e573b1883 100644
--- a/core/src/test/resources/fairscheduler.xml
+++ b/core/src/test/resources/fairscheduler.xml
@@ -1,3 +1,4 @@
+<?xml version="1.0"?>
<allocations>
<pool name="1">
<minShare>2</minShare>
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index ca385972fb..28a7b21b92 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -27,6 +27,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
}
+ test("basic checkpointing") {
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+ }
+
test("RDDs with one-to-one dependencies") {
testCheckpointing(_.map(x => x.toString))
testCheckpointing(_.flatMap(x => 1 to x))
diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000..682d2745bf
--- /dev/null
+++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
@@ -0,0 +1,287 @@
+package spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import com.google.common.io.Files
+
+import spark.rdd.ShuffledRDD
+import spark.SparkContext._
+
+class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("groupByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with partitioner") {
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
+
+ test("join") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("leftOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ }
+
+ test("rightOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ }
+
+ test("join with no matches") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("groupWith") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ }
+
+ test("zero-partition RDD") {
+ val emptyDir = Files.createTempDir()
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.size == 0)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
+
+ test("keys and values") {
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
+
+ test("default partitioner uses partition size") {
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 partitions
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
+ test("foldByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
+}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 16f93e71a3..99e433e3bd 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -6,8 +6,8 @@ import SparkContext._
import spark.util.StatCounter
import scala.math.abs
-class PartitioningSuite extends FunSuite with LocalSparkContext {
-
+class PartitioningSuite extends FunSuite with SharedSparkContext {
+
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("RangePartitioner equality") {
- sc = new SparkContext("local", "test")
-
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
@@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("HashPartitioner not equal to RangePartitioner") {
- sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioner preservation") {
- sc = new SparkContext("local", "test")
-
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
val grouped2 = rdd.groupByKey(2)
@@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioning Java arrays should fail") {
- sc = new SparkContext("local", "test")
val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
val arrPairs: RDD[(Array[Int], Int)] =
sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
@@ -120,21 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
-
- test("Zero-length partitions should be correctly handled") {
+
+ test("zero-length partitions should be correctly handled") {
// Create RDD with some consecutive empty partitions (including the "first" one)
- sc = new SparkContext("local", "test")
val rdd: RDD[Double] = sc
.parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
.filter(_ >= 0.0)
-
+
// Run the partitions, including the consecutive empty ones, through StatCounter
val stats: StatCounter = rdd.stats();
assert(abs(6.0 - stats.sum) < 0.01);
assert(abs(6.0/2 - rdd.mean) < 0.01);
assert(abs(1.0 - rdd.variance) < 0.01);
assert(abs(1.0 - rdd.stdev) < 0.01);
-
+
// Add other tests here for classes that should be able to handle empty partitions correctly
}
}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index ed075f93ec..1c9ca50811 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -3,10 +3,9 @@ package spark
import org.scalatest.FunSuite
import SparkContext._
-class PipedRDDSuite extends FunSuite with LocalSparkContext {
-
+class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
test("basic pipe") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@@ -20,12 +19,11 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("advanced pipe") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val bl = sc.broadcast(List("0"))
- val piped = nums.pipe(Seq("cat"),
- Map[String, String](),
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Int, f: String=> Unit) => f(i + "_"))
@@ -43,8 +41,8 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
val d = nums1.groupBy(str=>str.split("\t")(0)).
- pipe(Seq("cat"),
- Map[String, String](),
+ pipe(Seq("cat"),
+ Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
assert(d.size === 8)
@@ -59,7 +57,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("pipe with env variable") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
@@ -69,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("pipe with non-zero exit status") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe("cat nonexistent_file")
intercept[SparkException] {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 67f3332d44..e41ae385c0 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -7,10 +7,9 @@ import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD}
-class RDDSuite extends FunSuite with LocalSparkContext {
+class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
@@ -46,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("SparkContext.union") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
@@ -55,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("aggregate") {
- sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@@ -75,57 +72,14 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("basic checkpointing") {
- import java.io.File
- val checkpointDir = File.createTempFile("temp", "")
- checkpointDir.delete()
-
- sc = new SparkContext("local", "test")
- sc.setCheckpointDir(checkpointDir.toString)
- val parCollection = sc.makeRDD(1 to 4)
- val flatMappedRDD = parCollection.flatMap(x => 1 to x)
- flatMappedRDD.checkpoint()
- assert(flatMappedRDD.dependencies.head.rdd == parCollection)
- val result = flatMappedRDD.collect()
- Thread.sleep(1000)
- assert(flatMappedRDD.dependencies.head.rdd != parCollection)
- assert(flatMappedRDD.collect() === result)
-
- checkpointDir.deleteOnExit()
- }
-
test("basic caching") {
- sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
- test("unpersist RDD") {
- sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
- rdd.count
- assert(sc.persistentRdds.isEmpty === false)
- rdd.unpersist()
- assert(sc.persistentRdds.isEmpty === true)
-
- failAfter(Span(3000, Millis)) {
- try {
- while (! sc.getRDDStorageInfo.isEmpty) {
- Thread.sleep(200)
- }
- } catch {
- case _ => { Thread.sleep(10) }
- // Do nothing. We might see exceptions because block manager
- // is racing this thread to remove entries from the driver.
- }
- }
- assert(sc.getRDDStorageInfo.isEmpty === true)
- }
-
test("caching with failures") {
- sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@@ -148,7 +102,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("empty RDD") {
- sc = new SparkContext("local", "test")
val empty = new EmptyRDD[Int](sc)
assert(empty.count === 0)
assert(empty.collect().size === 0)
@@ -168,37 +121,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("cogrouped RDDs") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
- val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
-
- // Use cogroup function
- val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
- assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped(2) === (Seq("two"), Seq("two1")))
- assert(cogrouped(3) === (Seq("three"), Seq()))
-
- // Construct CoGroupedRDD directly, with map side combine enabled
- val cogrouped1 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- true).collectAsMap()
- assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
-
- // Construct CoGroupedRDD directly, with map side combine disabled
- val cogrouped2 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- false).collectAsMap()
- assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
- }
-
- test("coalesced RDDs") {
- sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
val coalesced1 = data.coalesce(2)
@@ -236,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("zipped RDDs") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
@@ -248,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("partition pruning") {
- sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
@@ -260,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("mapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int) => new Random(index + 42))
@@ -279,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("flatMapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int) => new Random(index + 42))
@@ -301,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("filterWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int) => new Random(index + 42))
@@ -319,7 +236,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("top with predefined ordering") {
- sc = new SparkContext("local", "test")
val nums = Array.range(1, 100000)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
@@ -328,7 +244,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("top with custom ordering") {
- sc = new SparkContext("local", "test")
val words = Vector("a", "b", "c", "d")
implicit val ord = implicitly[Ordering[String]].reverse
val rdd = sc.makeRDD(words, 2)
@@ -336,4 +251,37 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(topK.size === 2)
assert(topK.sorted === Array("b", "a"))
}
+
+ test("takeSample") {
+ val data = sc.parallelize(1 to 100, 2)
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=false, 20, seed)
+ assert(sample.size === 20) // Got exactly 20 elements
+ assert(sample.toSet.size === 20) // Elements are distinct
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=false, 200, seed)
+ assert(sample.size === 100) // Got only 100 elements
+ assert(sample.toSet.size === 100) // Elements are distinct
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 20, seed)
+ assert(sample.size === 20) // Got exactly 20 elements
+ assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 100, seed)
+ assert(sample.size === 100) // Got exactly 100 elements
+ // Chance of getting all distinct elements is astronomically low, so test we got < 100
+ assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ }
+ for (seed <- 1 to 5) {
+ val sample = data.takeSample(withReplacement=true, 200, seed)
+ assert(sample.size === 200) // Got exactly 200 elements
+ // Chance of getting all distinct elements is still quite low, so test we got < 100
+ assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ }
+ }
}
diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala
new file mode 100644
index 0000000000..1da79f9824
--- /dev/null
+++ b/core/src/test/scala/spark/SharedSparkContext.scala
@@ -0,0 +1,25 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
+trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
+
+ @transient private var _sc: SparkContext = _
+
+ def sc: SparkContext = _sc
+
+ override def beforeAll() {
+ _sc = new SparkContext("local", "test")
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (_sc != null) {
+ LocalSparkContext.stop(_sc)
+ _sc = null
+ }
+ super.afterAll()
+ }
+}
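Suites opt in by mixing in the trait, as in this hypothetical example:

import org.scalatest.FunSuite
import spark.SharedSparkContext

class ExampleSuite extends FunSuite with SharedSparkContext {
  test("reuses one SparkContext across all tests in the suite") {
    assert(sc.parallelize(1 to 3).count() === 3)
  }
}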
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index b967016cf7..950218fa28 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD
import spark.SparkContext._
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
-
- test("groupByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with duplicates") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with negative key hash codes") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesForMinus1 = groups.find(_._1 == -1).get._2
- assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with many output partitions") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey(10).collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
test("groupByKey with compression") {
try {
- System.setProperty("spark.blockManager.compress", "true")
+ System.setProperty("spark.shuffle.compress", "true")
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
@@ -77,234 +32,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
}
- test("reduceByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with collectAsMap") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collectAsMap()
- assert(sums.size === 2)
- assert(sums(1) === 7)
- assert(sums(2) === 1)
- }
-
- test("reduceByKey with many output partitons") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_, 10).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with partitioner") {
- sc = new SparkContext("local", "test")
- val p = new Partitioner() {
- def numPartitions = 2
- def getPartition(key: Any) = key.asInstanceOf[Int]
- }
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
- val sums = pairs.reduceByKey(_+_)
- assert(sums.collect().toSet === Set((1, 4), (0, 1)))
- assert(sums.partitioner === Some(p))
- // count the dependencies to make sure there is only 1 ShuffledRDD
- val deps = new HashSet[RDD[_]]()
- def visit(r: RDD[_]) {
- for (dep <- r.dependencies) {
- deps += dep.rdd
- visit(dep.rdd)
- }
- }
- visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
- }
-
- test("join") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
- }
-
- test("join all-to-all") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 6)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (1, 'y')),
- (1, (2, 'x')),
- (1, (2, 'y')),
- (1, (3, 'x')),
- (1, (3, 'y'))
- ))
- }
-
- test("leftOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.leftOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (1, Some('x'))),
- (1, (2, Some('x'))),
- (2, (1, Some('y'))),
- (2, (1, Some('z'))),
- (3, (1, None))
- ))
- }
-
- test("rightOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.rightOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (Some(1), 'x')),
- (1, (Some(2), 'x')),
- (2, (Some(1), 'y')),
- (2, (Some(1), 'z')),
- (4, (None, 'w'))
- ))
- }
-
- test("join with no matches") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 0)
- }
-
- test("join with many output partitions") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2, 10).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
- }
-
- test("groupWith") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.groupWith(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
- (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
- (3, (ArrayBuffer(1), ArrayBuffer())),
- (4, (ArrayBuffer(), ArrayBuffer('w')))
- ))
- }
-
- test("zero-partition RDD") {
- sc = new SparkContext("local", "test")
- val emptyDir = Files.createTempDir()
- val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.partitions.size == 0)
- assert(file.collect().toList === Nil)
- // Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
- }
-
- test("keys and values") {
- sc = new SparkContext("local", "test")
- val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
- assert(rdd.keys.collect().toList === List(1, 2))
- assert(rdd.values.collect().toList === List("a", "b"))
- }
-
- test("default partitioner uses partition size") {
- sc = new SparkContext("local", "test")
- // specify 2000 partitions
- val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
- // do a map, which loses the partitioner
- val b = a.map(a => (a, (a * 2).toString))
- // then a group by, and see we didn't revert to 2 partitions
- val c = b.groupByKey()
- assert(c.partitions.size === 2000)
- }
-
- test("default partitioner uses largest partitioner") {
- sc = new SparkContext("local", "test")
- val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
- val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
- val c = a.join(b)
- assert(c.partitions.size === 2000)
- }
-
- test("subtract") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array(1, 2, 3), 2)
- val b = sc.parallelize(Array(2, 3, 4), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set(1))
- assert(c.partitions.size === a.partitions.size)
- }
-
- test("subtract with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
- }
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set((1, "a"), (3, "c")))
- // Ideally we could keep the original partitioner...
- assert(c.partitioner === None)
- }
-
- test("subtractByKey") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
- val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitions.size === a.partitions.size)
- }
-
- test("subtractByKey with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
- }
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitioner.get === p)
- }
-
test("shuffle non-zero block size") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
val NUM_BLOCKS = 3
@@ -367,6 +94,30 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
+
+ // NOTE: The default Java serializer should create zero-sized blocks
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
+
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
+
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index e235ef2f67..b5c8525f91 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -35,7 +35,7 @@ class SizeEstimatorSuite
var oldOops: String = _
override def beforeAll() {
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
oldOops = System.setProperty("spark.test.useCompressedOops", "true")
}
@@ -46,54 +46,54 @@ class SizeEstimatorSuite
}
test("simple classes") {
- expect(16)(SizeEstimator.estimate(new DummyClass1))
- expect(16)(SizeEstimator.estimate(new DummyClass2))
- expect(24)(SizeEstimator.estimate(new DummyClass3))
- expect(24)(SizeEstimator.estimate(new DummyClass4(null)))
- expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
+ assert(SizeEstimator.estimate(new DummyClass1) === 16)
+ assert(SizeEstimator.estimate(new DummyClass2) === 16)
+ assert(SizeEstimator.estimate(new DummyClass3) === 24)
+ assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
+ assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
}
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- expect(40)(SizeEstimator.estimate(DummyString("")))
- expect(48)(SizeEstimator.estimate(DummyString("a")))
- expect(48)(SizeEstimator.estimate(DummyString("ab")))
- expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
+ assert(SizeEstimator.estimate(DummyString("")) === 40)
+ assert(SizeEstimator.estimate(DummyString("a")) === 48)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 48)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
}
test("primitive arrays") {
- expect(32)(SizeEstimator.estimate(new Array[Byte](10)))
- expect(40)(SizeEstimator.estimate(new Array[Char](10)))
- expect(40)(SizeEstimator.estimate(new Array[Short](10)))
- expect(56)(SizeEstimator.estimate(new Array[Int](10)))
- expect(96)(SizeEstimator.estimate(new Array[Long](10)))
- expect(56)(SizeEstimator.estimate(new Array[Float](10)))
- expect(96)(SizeEstimator.estimate(new Array[Double](10)))
- expect(4016)(SizeEstimator.estimate(new Array[Int](1000)))
- expect(8016)(SizeEstimator.estimate(new Array[Long](1000)))
+ assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
+ assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
+ assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
+ assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
+ assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
+ assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
+ assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
}
test("object arrays") {
// Arrays containing nulls should just have one pointer per element
- expect(56)(SizeEstimator.estimate(new Array[String](10)))
- expect(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
+ assert(SizeEstimator.estimate(new Array[String](10)) === 56)
+ assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
- expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
- expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
- expect(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
- expect(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
+ assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
// Past size 100, we only sample 100 elements, but we should still get the right size.
- expect(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
+ assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
- expect(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
- expect(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
+ assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
+ assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
// Same thing with huge array containing the same element many times. Note that this won't
// return exactly 4032 because it can't tell that *all* the elements will equal the first
@@ -111,10 +111,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- expect(40)(SizeEstimator.estimate(DummyString("")))
- expect(48)(SizeEstimator.estimate(DummyString("a")))
- expect(48)(SizeEstimator.estimate(DummyString("ab")))
- expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
+ assert(SizeEstimator.estimate(DummyString("")) === 40)
+ assert(SizeEstimator.estimate(DummyString("a")) === 48)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 48)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
resetOrClear("os.arch", arch)
}
@@ -128,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- expect(56)(SizeEstimator.estimate(DummyString("")))
- expect(64)(SizeEstimator.estimate(DummyString("a")))
- expect(64)(SizeEstimator.estimate(DummyString("ab")))
- expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
+ assert(SizeEstimator.estimate(DummyString("")) === 56)
+ assert(SizeEstimator.estimate(DummyString("a")) === 64)
+ assert(SizeEstimator.estimate(DummyString("ab")) === 64)
+ assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 495f957e53..f7bf207c68 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
-
+class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+
test("sortByKey") {
- sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("large array") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("large array with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("large array with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("sort descending") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("sort descending with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 1)
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
-
+
test("sort descending with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("more partitions than elements") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
@@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("empty RDD") {
- sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("partition balancing") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey()
assert(sorted.collect() === pairArr.sortBy(_._1))
@@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("partition balancing for descending sort") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala
new file mode 100644
index 0000000000..94776e7572
--- /dev/null
+++ b/core/src/test/scala/spark/UnpersistSuite.scala
@@ -0,0 +1,30 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
+import spark.SparkContext._
+
+class UnpersistSuite extends FunSuite with LocalSparkContext {
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty === false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ // Sleep briefly and retry: we might see exceptions because the block
+ // manager is racing this thread to remove entries from the driver.
+ case _ => Thread.sleep(10)
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty === true)
+ }
+}
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index ed4701574f..4a113e16bf 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite {
assert(os.toByteArray.toList.equals(bytes.toList))
}
- test("memoryStringToMb"){
- assert(Utils.memoryStringToMb("1") == 0)
- assert(Utils.memoryStringToMb("1048575") == 0)
- assert(Utils.memoryStringToMb("3145728") == 3)
+ test("memoryStringToMb") {
+ assert(Utils.memoryStringToMb("1") === 0)
+ assert(Utils.memoryStringToMb("1048575") === 0)
+ assert(Utils.memoryStringToMb("3145728") === 3)
- assert(Utils.memoryStringToMb("1024k") == 1)
- assert(Utils.memoryStringToMb("5000k") == 4)
- assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K"))
+ assert(Utils.memoryStringToMb("1024k") === 1)
+ assert(Utils.memoryStringToMb("5000k") === 4)
+ assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
- assert(Utils.memoryStringToMb("1024m") == 1024)
- assert(Utils.memoryStringToMb("5000m") == 5000)
- assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M"))
+ assert(Utils.memoryStringToMb("1024m") === 1024)
+ assert(Utils.memoryStringToMb("5000m") === 5000)
+ assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
- assert(Utils.memoryStringToMb("2g") == 2048)
- assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G"))
+ assert(Utils.memoryStringToMb("2g") === 2048)
+ assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
- assert(Utils.memoryStringToMb("2t") == 2097152)
- assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T"))
+ assert(Utils.memoryStringToMb("2t") === 2097152)
+ assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
+ }
+
+ test("splitCommandString") {
+ assert(Utils.splitCommandString("") === Seq())
+ assert(Utils.splitCommandString("a") === Seq("a"))
+ assert(Utils.splitCommandString("aaa") === Seq("aaa"))
+ assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("'b c'") === Seq("b c"))
+ assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
+ assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
+ assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
+ assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
+ assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
+ assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
+ assert(Utils.splitCommandString("'a'b") === Seq("ab"))
+ assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
+ assert(Utils.splitCommandString("''") === Seq(""))
+ assert(Utils.splitCommandString("\"\"") === Seq(""))
}
}
diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
index 5f60aa75d7..96cb295f45 100644
--- a/core/src/test/scala/spark/ZippedPartitionsSuite.scala
+++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
@@ -17,9 +17,8 @@ object ZippedPartitionsSuite {
}
}
-class ZippedPartitionsSuite extends FunSuite with LocalSparkContext {
+class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
test("print sizes") {
- sc = new SparkContext("local", "test")
val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000000..699901f1a1
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,104 @@
+package spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import spark._
+import spark.SparkContext._
+
+
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("inner method") {
+ sc = new SparkContext("local", "joblogger")
+ val joblogger = new JobLogger {
+ def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+ def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+ def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+ def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+ }
+ type MyRDD = RDD[(Int, Int)]
+ def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]]
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ }
+ }
+ val jobID = 5
+ val parentRdd = makeRdd(4, Nil)
+ val shuffleDep = new ShuffleDependency(parentRdd, null)
+ val rootRdd = makeRdd(4, List(shuffleDep))
+ val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
+ val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ parentRdd.setName("MyRDD")
+ joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+ joblogger.createLogWriterTest(jobID)
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.buildJobDepTest(jobID, rootStage)
+ joblogger.getJobIDToStages.get(jobID).get.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+ joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+ joblogger.closeLogWriterTest(jobID)
+ joblogger.getStageIDToJobID.size should be (0)
+ joblogger.getJobIDToStages.size should be (0)
+ joblogger.getJobIDtoPrintWriter.size should be (0)
+ }
+
+ test("inner variables") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ override protected def closeLogWriter(jobID: Int) =
+ getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ }
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.getStageIDToJobID.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(0))
+ joblogger.getStageIDToJobID.get(1) should be (Some(0))
+ joblogger.getJobIDToStages.size should be (1)
+ }
+
+
+ test("interface functions") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ var onTaskEndCount = 0
+ var onJobEndCount = 0
+ var onJobStartCount = 0
+ var onStageCompletedCount = 0
+ var onStageSubmittedCount = 0
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+ override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+ override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.onJobStartCount should be (1)
+ joblogger.onJobEndCount should be (1)
+ joblogger.onTaskEndCount should be (8)
+ joblogger.onStageSubmittedCount should be (2)
+ joblogger.onStageCompletedCount should be (2)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
index 42a87d8b90..48aa67c543 100644
--- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
- def onStageCompleted(stage: StageCompleted) {
+ override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
}
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 17fdbf04d1..3266db7af1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -22,26 +22,30 @@ the copy executable.
Inside `spark-env.sh`, you *must* set at least the following two variables:
-* `SCALA_HOME`, to point to your Scala installation.
+* `SCALA_HOME`, to point to your Scala installation, or `SCALA_LIBRARY_PATH` to point to the directory for Scala
+ library JARs (if you install Scala as a Debian or RPM package, there is no `SCALA_HOME`, but these libraries
+ are in a separate path, typically /usr/share/java; look for `scala-library.jar`).
* `MESOS_NATIVE_LIBRARY`, if you are [running on a Mesos cluster](running-on-mesos.html).
-In addition, there are four other variables that control execution. These can be set *either in `spark-env.sh`
-or in each job's driver program*, because they will automatically be propagated to workers from the driver.
-For a multi-user environment, we recommend setting the in the driver program instead of `spark-env.sh`, so
-that different user jobs can use different amounts of memory, JVM options, etc.
+In addition, there are four other variables that control execution. These should be set *in the environment that
+launches the job's driver program* instead of `spark-env.sh`, because they will be automatically propagated to
+workers. Setting these per-job instead of in `spark-env.sh` ensures that different jobs can have different settings
+for these variables.
-* `SPARK_MEM`, to set the amount of memory used per node (this should be in the same format as the
- JVM's -Xmx option, e.g. `300m` or `1g`)
* `SPARK_JAVA_OPTS`, to add JVM options. This includes any system properties that you'd like to pass with `-D`.
* `SPARK_CLASSPATH`, to add elements to Spark's classpath.
* `SPARK_LIBRARY_PATH`, to add search directories for native libraries.
+* `SPARK_MEM`, to set the amount of memory used per node. This should be in the same format as the
+ JVM's -Xmx option, e.g. `300m` or `1g`. Note that this option will soon be deprecated in favor of
+ the `spark.executor.memory` system property, so we recommend using that in new code.
-Note that if you do set these in `spark-env.sh`, they will override the values set by user programs, which
-is undesirable; you can choose to have `spark-env.sh` set them only if the user program hasn't, as follows:
+Beware that if you do set these variables in `spark-env.sh`, they will override the values set by user programs,
+which is undesirable; if you prefer, you can choose to have `spark-env.sh` set them only if the user program
+hasn't, as follows:
{% highlight bash %}
-if [ -z "$SPARK_MEM" ] ; then
- SPARK_MEM="1g"
+if [ -z "$SPARK_JAVA_OPTS" ] ; then
+ SPARK_JAVA_OPTS="-verbose:gc"
fi
{% endhighlight %}
@@ -55,11 +59,18 @@ val sc = new SparkContext(...)
{% endhighlight %}
Most of the configurable system properties control internal settings that have reasonable default values. However,
-there are at least four properties that you will commonly want to control:
+there are at least five properties that you will commonly want to control:
<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
+ <td>spark.executor.memory</td>
+ <td>512m</td>
+ <td>
+ Amount of memory to use per executor process, in the same format as JVM memory strings (e.g. `512m`, `2g`).
+ </td>
+</tr>
+<tr>
<td>spark.serializer</td>
<td>spark.JavaSerializer</td>
<td>
@@ -260,6 +271,13 @@ Apart from these, the following properties are also available, and may be useful
applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
</td>
</tr>
+<tr>
+ <td>spark.streaming.blockInterval</td>
+ <td>200</td>
+ <td>
+ Duration (in milliseconds) over which to batch new objects coming from network receivers.
+ </td>
+</tr>
</table>
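For instance, a driver program can set the new property (or any other property in this table) before constructing its context; a minimal sketch, with a placeholder master URL and app name:

{% highlight scala %}
import spark.SparkContext

// System properties must be set before the SparkContext is constructed.
System.setProperty("spark.executor.memory", "2g")
System.setProperty("spark.serializer", "spark.KryoSerializer")

val sc = new SparkContext("local[4]", "ConfigExample")
{% endhighlight %}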
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index dc57035eba..eab8a0ff20 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -106,9 +106,8 @@ permissions on your private key file, you can run `launch` with the
# Configuration
You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such
-as JVM options and, most crucially, the amount of memory to use per machine (`SPARK_MEM`).
-This file needs to be copied to **every machine** to reflect the change. The easiest way to do this
-is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master,
+as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to
+do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master,
then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers.
The [configuration guide](configuration.html) describes the available configuration options.
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 3a7a8db4a6..e8aaac74d0 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -17,24 +17,23 @@ There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of different types.
* PySpark does not currently support the following Spark features:
- Special functions on RDDs of doubles, such as `mean` and `stdev`
- - `lookup`
+ - `lookup`, `sample` and `sort`
- `persist` at storage levels other than `MEMORY_ONLY`
- - `sample`
- - `sort`
+ - Execution on Windows -- this is slated for a future release
In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
{% highlight python %}
logData = sc.textFile(logFile).cache()
-errors = logData.filter(lambda s: 'ERROR' in s.split())
+errors = logData.filter(lambda line: "ERROR" in line)
{% endhighlight %}
You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
{% highlight python %}
def is_error(line):
- return 'ERROR' in line.split()
+ return "ERROR" in line
errors = logData.filter(is_error)
{% endhighlight %}
@@ -43,8 +42,7 @@ Functions can access objects in enclosing scopes, although modifications to thos
{% highlight python %}
error_keywords = ["Exception", "Error"]
def is_error(line):
- words = line.split()
- return any(keyword in words for keyword in error_keywords)
+ return any(keyword in line for keyword in error_keywords)
errors = logData.filter(is_error)
{% endhighlight %}
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index b0da130fcb..e9cf9ef36f 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars])
The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later.
-In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use
+In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use
{% highlight bash %}
$ MASTER=local[4] ./spark-shell
{% endhighlight %}
+Or, to also add `code.jar` to its classpath, use:
+
+{% highlight bash %}
+$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell
+{% endhighlight %}
+
### Master URLs
The master URL passed to Spark can be in one of the following formats:
@@ -78,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio
* `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them).
* `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies.
-If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed.
+If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed.
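Outside the shell, the equivalent of `ADD_JARS` is the `jars` constructor parameter described above; a minimal sketch with placeholder paths:

{% highlight scala %}
import spark.SparkContext

// Master URL, sparkHome and the JAR path below are placeholders.
val sc = new SparkContext(
  "spark://master-host:7077",  // cluster to connect to
  "MyJob",                     // appName shown in the web UI
  "/opt/spark",                // sparkHome on the worker machines
  Seq("target/code.jar"))      // JARs deployed to all workers
{% endhighlight %}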
# Resilient Distributed Datasets (RDDs)
diff --git a/docs/tuning.md b/docs/tuning.md
index 32c7ab86e9..5ffca54481 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -157,9 +157,9 @@ their work directories), *not* on your driver program.
**Cache Size Tuning**
-One important configuration parameter for GC is the amount of memory that should be used for
-caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
- 33% of memory is available for any objects created during task execution.
+One important configuration parameter for GC is the amount of memory that should be used for caching RDDs.
+By default, Spark uses 66% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to
+cache RDDs. This means that 33% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
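A minimal sketch of that call, assuming the `spark.storage.memoryFraction` property used by this version of Spark (it must run before the `SparkContext` is created):

{% highlight scala %}
// Lower the fraction of executor memory used for the RDD cache from 66% to 50%.
System.setProperty("spark.storage.memoryFraction", "0.5")
{% endhighlight %}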
diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index c3a9e491ba..9202e65e09 100644
--- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -37,7 +37,7 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
+ val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala
new file mode 100644
index 0000000000..51c3c9f9b4
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala
@@ -0,0 +1,50 @@
+package spark.streaming.examples
+
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: StatefulNetworkWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999`
+ */
+object StatefulNetworkWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: StatefulNetworkWordCount <master> <hostname> <port>\n" +
+ "In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ val currentCount = values.foldLeft(0)(_ + _)
+
+ val previousCount = state.getOrElse(0)
+
+ Some(currentCount + previousCount)
+ }
+
+ // Create the context with a 1 second batch size
+ val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1),
+ System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+ ssc.checkpoint(".")
+
+ // Create a NetworkInputDStream on target ip:port and count the
+ // words in the input stream of \n delimited text (e.g. generated by 'nc')
+ val lines = ssc.socketTextStream(args(1), args(2).toInt)
+ val words = lines.flatMap(_.split(" "))
+ val wordDstream = words.map(x => (x, 1))
+
+ // Update the cumulative count using updateStateByKey
+ // This will give a DStream made of state (which is the cumulative count of the words)
+ val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
+ stateDstream.print()
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
index a9642100e3..528778ed72 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
@@ -26,8 +26,8 @@ import spark.SparkContext._
*/
object TwitterAlgebirdCMS {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterAlgebirdCMS <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterAlgebirdCMS <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
@@ -40,12 +40,11 @@ object TwitterAlgebirdCMS {
// K highest frequency elements to take
val TOPK = 10
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+ val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER)
val users = stream.map(status => status.getUser.getId)
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
index f3288bfb85..896e9fd8af 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
@@ -21,20 +21,19 @@ import spark.streaming.dstream.TwitterInputDStream
*/
object TwitterAlgebirdHLL {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterAlgebirdHLL <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterAlgebirdHLL <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
/** Bit size parameter for HyperLogLog, trades off accuracy vs size */
val BIT_SIZE = 12
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+ val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER)
val users = stream.map(status => status.getUser.getId)
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
index 9d4494c6f2..65f0b6d352 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
@@ -12,18 +12,17 @@ import spark.SparkContext._
*/
object TwitterPopularTags {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterPopularTags <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterPopularTags <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters)
+ val stream = ssc.twitterStream(None, filters)
val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
diff --git a/mllib/data/als/test.data b/mllib/data/als/test.data
new file mode 100644
index 0000000000..e476cc23e0
--- /dev/null
+++ b/mllib/data/als/test.data
@@ -0,0 +1,16 @@
+1,1,5.0
+1,2,1.0
+1,3,5.0
+1,4,1.0
+2,1,5.0
+2,2,1.0
+2,3,5.0
+2,4,1.0
+3,1,1.0
+3,2,5.0
+3,3,1.0
+3,4,5.0
+4,1,1.0
+4,2,5.0
+4,3,1.0
+4,4,5.0
diff --git a/mllib/data/lr-data/random.data b/mllib/data/lr-data/random.data
new file mode 100755
index 0000000000..29bcb8acba
--- /dev/null
+++ b/mllib/data/lr-data/random.data
@@ -0,0 +1,1000 @@
+0.0,-0.19138793197590276 0.7834675900121327
+1.0,3.712420417753061 3.55967640829891
+0.0,-0.3173743619974614 0.9034702789806682
+1.0,4.759494447180777 3.407011867344781
+0.0,-0.7078607074437426 -0.7866705652344417
+1.0,2.6708084832010215 2.5322909406378016
+0.0,-0.07553885038446313 -0.1297104483563081
+1.0,2.759487072285262 2.474689814713741
+0.0,-2.2199161547238107 0.7543109438660762
+1.0,1.922617509832946 1.9412373902594937
+0.0,0.8140942462004225 1.883920822277784
+1.0,1.7649295902120172 3.8195077526061363
+0.0,-1.1173052428096684 -1.468964723960145
+1.0,1.8733449544967458 2.913026590975709
+0.0,-0.11212965215910947 1.068087981775071
+1.0,2.3368459971730227 5.453870208593922
+0.0,-1.2802488543364463 -0.47218504171867676
+1.0,4.1917343620336895 3.5602286778418355
+0.0,0.5995976502137177 -0.797374550890321
+1.0,3.721592294428238 4.824418090974808
+0.0,-0.0721649164244053 -1.3952880192542576
+1.0,3.609764030146346 3.4730043476891277
+0.0,-1.5078269860498976 -2.6460421495665987
+1.0,1.8510254911824193 1.6748364225650059
+0.0,1.021485727769095 -0.14476425336866738
+1.0,4.10105000223134 2.3772502437548493
+0.0,2.6132710211418675 -1.061646527586342
+1.0,2.6444875273854653 4.043302750329545
+0.0,1.115723715938777 0.38401588153403887
+1.0,2.045759949164019 3.156447533448806
+0.0,-1.0543022640565405 -0.6820337845705753
+1.0,3.535337069948117 3.8121122972294965
+0.0,0.9427529503486505 -0.25123516319259886
+1.0,3.9611643301316795 3.3144121016644443
+0.0,-0.15013188927817916 0.8178862482229886
+1.0,3.200504584029051 2.3088398886136057
+0.0,0.819731993393585 -0.47386644109886344
+1.0,3.283317566020217 3.4828146842654513
+0.0,-2.3283941193793303 -0.6148925379529
+1.0,3.901670215294089 3.6356776610143324
+0.0,-0.28635769830042973 0.049586437072917544
+1.0,3.1114746381043927 3.6314805300338775
+0.0,-1.3085536069757229 0.11172767926766304
+1.0,3.3676979357140744 4.689661419564771
+0.0,-1.5820787210442733 1.3226576351191428
+1.0,2.5957586701668207 3.0648240201825923
+0.0,-2.116823743560968 0.272822309954307
+1.0,3.31672509500716 3.870172182480263
+0.0,0.09751166932653511 0.6469052579904877
+1.0,2.0609623373451305 3.9496181906908694
+0.0,0.5238217321419351 -1.2424816480725946
+1.0,3.5731384504449717 5.293293512805712
+0.0,-0.8507917425723299 -1.2243124053200718
+1.0,3.3060954421001867 3.1337045819604565
+0.0,1.5066706426420082 0.04176666807070882
+1.0,4.197316426430547 2.327643377792433
+0.0,-1.8068158696573955 -1.6380836149377855
+1.0,3.568239793850545 3.561688791420822
+0.0,0.4705756905309871 1.1991675114038487
+1.0,4.85003762884306 4.253420553408024
+0.0,0.7595792932847568 0.014062431397674205
+1.0,1.6984862661221896 1.7746925013882613
+0.0,0.1132294255888917 -0.09228036942051128
+1.0,3.766092539171029 2.765647342841482
+0.0,1.053401788561791 -1.0588667339849278
+1.0,2.780021685872393 3.239478188786074
+0.0,0.4042022490052266 1.0982210323828034
+1.0,2.4939569547402063 2.4615506964861273
+0.0,0.4469359967563411 0.3880418183993791
+1.0,2.7943749030887486 3.742182807141721
+0.0,-0.4418685162293727 0.802180923066725
+1.0,3.711213212127241 4.620177703831104
+0.0,0.10737314976605918 -1.5716142960765325
+1.0,4.0522289913808365 3.77562942835957
+0.0,1.4798827061781141 1.1638601205648005
+1.0,3.6758023575825547 3.115500589955362
+0.0,-1.803338141681238 -0.639996207387159
+1.0,2.044667029270621 3.04922768663927
+0.0,-0.06067427095346295 1.394611410740688
+1.0,4.626495834477846 2.995800202291488
+0.0,-0.2770274350630315 0.4521526506693692
+1.0,3.130857841268635 3.76858860814448
+0.0,2.163400739017478 -1.303601716798734
+1.0,2.9131896969824367 3.4288919990054167
+0.0,-0.7145108501670207 1.4189762494365543
+1.0,3.535768896041034 1.4894011726406373
+0.0,1.605614523747256 0.29974289519139824
+1.0,2.413678734728178 2.1826316767457183
+0.0,-0.8821932593373774 0.26432786248412726
+1.0,2.0878695933047116 3.5277388966365177
+0.0,-1.107001191509183 0.38421647065699477
+1.0,2.6462094774496454 2.273786785429519
+0.0,1.0712046043765102 -1.1889735666835115
+1.0,3.7458483094910666 1.3868020542832566
+0.0,-0.8403883736429167 -0.7163969561320671
+1.0,3.3359151000342195 3.2382001552279576
+0.0,0.13309387098922537 0.938761191821517
+1.0,2.083439571838502 3.2204948086228944
+0.0,1.3030219848568272 0.5976630914634896
+1.0,2.7602376200551317 2.200505791897739
+0.0,-0.9458633178207942 0.0490955863627428
+1.0,3.7998466026531883 1.9291683955712686
+0.0,-1.327236501803235 0.06915643957270164
+1.0,3.4740573335685925 2.1080735512507114
+0.0,0.8627688253416859 -1.961802291046532
+1.0,3.5108780392869776 3.9854745964798326
+0.0,-0.69537574439301 0.2436269580373554
+1.0,2.920286302932126 4.704192389485899
+0.0,-2.031190954684878 -0.7843052045579578
+1.0,1.6768848711259499 1.345658047606076
+0.0,0.9234894202027507 -0.38179572928866495
+1.0,3.1710339307651334 4.129874876536583
+0.0,-2.5086697007630376 -0.2638692986795807
+1.0,2.079400422215581 3.124756711992435
+0.0,-0.1388012859869782 0.3698243463601514
+1.0,2.665728164475424 4.574860576068532
+0.0,0.11967116650891912 -0.8792117975750646
+1.0,3.042630437105455 2.7245525508413677
+0.0,0.6078023848042808 -0.7977233104047035
+1.0,3.3340709038589638 4.962729210819017
+0.0,0.6373101353982795 1.1335021278327686
+1.0,3.3821397455119446 4.349379573895378
+0.0,-0.9140176931412027 -0.03428220013900756
+1.0,4.579963977595727 3.8322809335521484
+0.0,-0.43958506434874983 0.21259366700539037
+1.0,2.644701808902675 3.945416465403505
+0.0,-1.119921743746522 -0.2089105317801997
+1.0,2.5480553203091922 3.123344220515146
+0.0,0.8723990414181355 1.11150972420879
+1.0,4.479600967837827 2.8645066949820057
+0.0,-0.003869320481891422 0.24756134775982133
+1.0,3.237294368758498 4.642548547098718
+0.0,0.34643329685515545 0.029869480691029456
+1.0,2.6324740490008893 1.2577448307260846
+0.0,-0.4416403319035849 -1.4597062027342758
+1.0,1.764049052224297 3.649850384544675
+0.0,0.6779287737716254 -1.9489876700506967
+1.0,1.4286669812409405 2.4906452014102416
+0.0,-1.2271599940693638 0.9869686407012563
+1.0,3.6244117441765993 2.36879554315985
+0.0,-0.11422653411940642 0.4741905017884626
+1.0,3.6192153991840694 2.149436181779614
+0.0,0.45425900443207484 -1.357987041493406
+1.0,4.312295702128074 3.7596991900930252
+0.0,-0.35153502234686884 -0.6297451691082592
+1.0,3.4901363450669476 2.0630236379093243
+0.0,-1.5343533005821828 -0.23745688647461852
+1.0,4.775056734905926 5.291243824646301
+0.0,-1.032123659747431 0.8458711875294105
+1.0,2.3091889606097844 3.3688150059111215
+0.0,0.7854236849909306 0.6742463927844289
+1.0,3.284779531346899 2.855746734955609
+0.0,0.380579394855332 -1.2378905330462027
+1.0,2.540193014555953 3.245568950444961
+0.0,-0.5491810448400926 -2.3179482776107894
+1.0,3.481785462949587 1.8870182253717969
+0.0,-0.06833732101790825 2.178923334945784
+1.0,1.1663083809702222 1.8919272314310458
+0.0,-0.7801536433937879 -1.4185984368350903
+1.0,1.457713814592066 3.0323739348144048
+0.0,-0.16377716798970973 0.09678021896691058
+1.0,2.2294515799173094 1.6179126855486068
+0.0,-0.5845552895984718 -0.8095679531228397
+1.0,2.024328902209618 2.4660315284543888
+0.0,0.2037503424802764 1.5767438723426828
+1.0,3.5058983262252643 3.292836693091364
+0.0,-1.4004772080893082 0.6150928060180622
+1.0,4.610936499146778 3.3674445809820313
+0.0,-0.7325641160695897 -3.0469742419403225
+1.0,2.6778956983269926 4.049681967443553
+0.0,-0.3375932473421461 -0.32976087151423067
+1.0,3.975838378562512 1.2032482992228626
+0.0,-1.6622711226380826 -0.6954676646542216
+1.0,3.1601568512397256 2.7472491112914357
+0.0,0.6739969973916968 1.3608866192945286
+1.0,3.097978499063888 3.88429576456391
+0.0,-0.16445244300279913 0.631410854999902
+1.0,4.244875698991619 3.0464568222900477
+0.0,0.1749522197766453 -0.3295077792829936
+1.0,4.158913950688044 1.1836177376726964
+0.0,-1.8286320279969996 -0.6355826362111864
+1.0,2.4795264391445326 0.8073937061906746
+0.0,-0.5095499320702017 -0.8451757050184052
+1.0,3.6489546081475206 2.7405880916534957
+0.0,-0.11733097334574003 0.020300758125140466
+1.0,1.9034123919197892 4.036941742254072
+0.0,-0.4678304671259669 -0.7653895561277071
+1.0,2.555027220737054 4.205906511993216
+0.0,0.1952150967011765 1.2402178923240337
+1.0,3.532371144429582 2.395018092924601
+0.0,1.4682834110821084 2.2292327929025078
+1.0,2.1160331256749663 3.7157102308564824
+0.0,1.3973790173654674 -1.1902799121683607
+1.0,3.4775573554170616 3.0459058509488557
+0.0,-2.215337088722839 0.7693588032777773
+1.0,2.3298220860458976 1.5924630285528396
+0.0,1.260641664088144 1.5474089692944746
+1.0,4.460878990061944 2.595950219349794
+0.0,-1.8214944389802914 -1.9733205363211535
+1.0,4.41874870213851 2.4975116019313264
+0.0,1.2037921250123007 -0.7057578432831773
+1.0,3.042628088030598 3.7366256492570136
+0.0,-0.02609770715133313 -0.01975791007372346
+1.0,1.123824442324706 3.5115607224884466
+0.0,0.3466005704292144 -1.206858960323042
+1.0,3.044152779557358 2.4308738719304266
+0.0,-0.8292396838183249 -0.5768591341562801
+1.0,2.9898679252543325 3.3291086316901484
+0.0,0.6033357093153775 0.18738779274832332
+1.0,3.2777482224094916 2.2676548172839714
+0.0,-0.7104360487845565 -1.0365712508175688
+1.0,2.617802272534323 1.887796671556582
+0.0,-0.21008998836798706 -2.4424443035468957
+1.0,3.9387085143031317 2.368798316318223
+0.0,-0.65027380204969 0.4757828709083824
+1.0,1.6786020855223545 1.62019388696364
+0.0,0.40325101156361803 0.26629562725726075
+1.0,2.4614637796912167 2.778406744842399
+0.0,-0.4327374795655596 0.5643009301153851
+1.0,2.6419358755663103 2.1911675067034206
+0.0,-0.06058610052148417 0.6118154934715632
+1.0,4.134485645832481 4.214482766162727
+0.0,-2.091472947105952 -0.21279450874188077
+1.0,3.7664041746453503 0.5848083052756543
+0.0,0.20187441248519114 0.7310035835212488
+1.0,3.6821251396696817 1.2016937526237272
+0.0,0.16248871053987612 -0.8547163523143474
+1.0,3.1725037691095834 3.051265058839004
+0.0,-1.7466975308858639 -0.048497170816597705
+1.0,4.296665913992498 4.432036327276331
+0.0,-0.49371042139965376 -1.3162216335880739
+1.0,3.0767376272412292 2.4082404056282467
+0.0,0.6517145281009619 -0.15229289422910688
+1.0,3.8556129079007406 4.932746403550176
+0.0,2.467072616559744 -0.6570760874457315
+1.0,3.8722558954619446 2.398547361219584
+0.0,-0.996362973160808 -0.24663573264285635
+1.0,2.058960472055059 0.09020868936476445
+0.0,1.1921444033047794 -1.2205820383864918
+1.0,3.499255855340612 4.26015377680707
+0.0,0.46495431359796363 -0.3535071804767937
+1.0,3.2772715993311534 1.8496849599545144
+0.0,0.9200766227075026 1.0153595739730128
+1.0,3.7395665378166516 4.161859093428991
+0.0,-1.3445731221950805 0.3711182438638966
+1.0,1.974184816991473 2.3758202020218637
+0.0,0.25747673028745044 1.4898729695115611
+1.0,3.643667737073963 2.5171980898063024
+0.0,-0.7491175934837044 1.807998586131331
+1.0,3.024294668483263 2.745713910567566
+0.0,-2.9902104324990075 0.48847563269083094
+1.0,2.693457241550706 4.067192099378729
+0.0,1.0010822910854564 1.065617155304199
+1.0,2.6231328305267576 3.2530925652040796
+0.0,-1.569524799794976 0.10080365850268516
+1.0,5.543177898986999 3.149276748958176
+0.0,-0.2697035609845456 -0.3834981890675749
+1.0,5.5737716796876935 3.134627621089238
+0.0,0.16848836970122472 1.7680681560270155
+1.0,2.984578320659214 3.8081853301923743
+0.0,2.00864307305994 -1.1769936806590435
+1.0,2.4301644281026538 1.5357007015355957
+0.0,-1.251515087462618 -1.0023388301407077
+1.0,2.7783106123714036 3.4753675099443138
+0.0,1.2067779830446301 -1.1138369735803868
+1.0,2.660559526103853 0.9246419639107195
+0.0,-0.2120078291751072 0.553871125085326
+1.0,3.2961674182984613 4.1840551114889655
+0.0,-1.7407002661640898 -0.13494920714243758
+1.0,2.61652747199719 2.606431158365525
+0.0,0.1810536358726569 -0.7041543708042312
+1.0,0.6618977487425206 4.43976232230529
+0.0,-1.1056190552516114 -0.26273698119076755
+1.0,3.245745718364984 0.9585399121419127
+0.0,0.451245033031027 0.3966692171364385
+1.0,0.7000962854359294 2.5787278270774685
+0.0,-0.20657738352563298 -0.3054434424581368
+1.0,2.194893094322135 1.2265276851138993
+0.0,1.6478689673866447 -1.2217538409516264
+1.0,2.6520153534620268 4.253943157694819
+0.0,-1.091459682813003 -1.5933476790183565
+1.0,2.381978388803204 2.5725801073346375
+0.0,-1.7089448316753346 -0.40058783295112843
+1.0,4.692976595302646 2.293610804758882
+0.0,-0.8154594160076379 0.9100123432125261
+1.0,1.8893957859271135 2.365552941116367
+0.0,1.4750445045587657 -0.5730495722105764
+1.0,4.627946484342315 4.01023129091373
+0.0,-0.5740578222548407 -0.9010801407945085
+1.0,1.1844352711236998 1.0077910117111921
+0.0,-1.1904557430938465 -0.972229300373332
+1.0,1.9514043869587852 2.6603232743467817
+0.0,-0.11744191317950421 1.8160954524210857
+1.0,2.796337014232012 3.45131164191957
+0.0,1.1908754571951825 1.37388641966138
+1.0,3.1347230127964805 3.4874636513372774
+0.0,1.4279445191621287 0.4142573535049987
+1.0,3.2845746999649457 2.942571828876143
+0.0,1.0418078095097314 -0.515727237947711
+1.0,3.0672407807876674 3.593602465858237
+0.0,0.1070041194341431 0.013584199138111364
+1.0,2.831124413123504 2.5083468687281196
+0.0,1.9088191143015583 1.1943157723052062
+1.0,2.888463730373365 3.8588231186101716
+0.0,0.3344825700647222 1.4902421889158837
+1.0,5.1805240354926285 2.347000348613805
+0.0,-0.14736761539184529 -1.3764336595247777
+1.0,4.945788020165247 4.520764535128319
+0.0,0.48089579766964224 -1.0406729486881927
+1.0,3.115699146536788 3.0271206455481905
+0.0,0.8816867514268375 -0.7885530518936628
+1.0,3.293642905051253 4.129500570671647
+0.0,0.021019117419869213 -1.0983625263034136
+1.0,3.4712873315273884 2.8896550248710255
+0.0,1.336463967380889 0.1782538924176004
+1.0,2.9674559623039674 2.1702990000666977
+0.0,-0.9137873001694705 -1.6488427315604255
+1.0,2.425720985355789 3.336546225859983
+0.0,-2.3622279944776245 0.33443034793657744
+1.0,3.557057454549674 0.9654984504665607
+0.0,0.4924227412613347 0.8572441753897001
+1.0,2.903599258175698 1.9821387894597133
+0.0,-0.562864152759892 -1.41025535274598
+1.0,2.621542267864135 3.0896861639721602
+0.0,-0.9659016052287058 1.8601390770202668
+1.0,2.73394050343452 1.5908844566159697
+0.0,0.316736908826005 0.2857224419323005
+1.0,2.3312567009140532 5.596694984859762
+0.0,0.3137619371424862 -0.1840942808000176
+1.0,3.857644883242267 1.7425846536145542
+0.0,-0.10204795362718587 3.253153279848385
+1.0,1.991635750012152 3.0091345292604816
+0.0,0.6187841242310289 0.9589700354301842
+1.0,2.9773010080735895 3.723750625441197
+0.0,-0.8890787476930039 0.6057780620635984
+1.0,3.2341068438464773 4.238588226643048
+0.0,-0.6100941277292691 -1.5125630779121992
+1.0,3.378840902739636 2.0705801293719017
+0.0,1.9736225258875286 1.725383750563661
+1.0,1.8874237286900284 3.9061132751393997
+0.0,-0.0823939289302894 1.8958431169469556
+1.0,1.5927855001333566 4.6310125064091965
+0.0,0.3112044157520983 -1.7878471816057036
+1.0,4.34881513764263 3.4693940014863784
+0.0,1.052103622850019 -0.16912252356217902
+1.0,3.167179956507673 2.8792495587252507
+0.0,0.16791453003538387 -0.8546142448164881
+1.0,3.0538805073215953 3.4494667407676842
+0.0,-0.9500475678227512 0.06998146933806365
+1.0,3.8909913837847467 2.6813428719208763
+0.0,-0.09976816220585052 -1.4875944011133129
+1.0,3.1791447205478742 4.424991854067018
+0.0,1.0999643223476656 -1.1200747827607145
+1.0,5.222367041159025 1.2015274537211948
+0.0,-0.2848179798736651 0.401703345435371
+1.0,3.92690552314874 0.5307127426832543
+0.0,-0.6771410319499919 -0.5806616553853885
+1.0,3.611779415106116 3.3322298911093533
+0.0,-1.359189339369671 -0.03773529290863042
+1.0,4.696002594470123 1.4346348756461187
+0.0,-1.0094856636150293 0.19687532044013809
+1.0,3.2169383066148383 3.2307201581236473
+0.0,0.7836015359045666 0.2941037782687062
+1.0,3.7317041306588012 3.7985843457251107
+0.0,-0.3693168101963429 1.4513472421644549
+1.0,4.398703283685875 2.654636797434109
+0.0,0.02043081741683321 0.20805199015337653
+1.0,2.324187503797731 3.8819865944906566
+0.0,1.671377007435211 1.3731572027338659
+1.0,4.534630721644852 1.1543799480085444
+0.0,-0.3253127279932509 -0.8285225286171498
+1.0,3.993821155042294 0.7056403589045206
+0.0,1.194500226045371 0.638917136862092
+1.0,2.72148063695256 3.858678264350294
+0.0,-0.1905653672336637 0.8969404368665279
+1.0,1.9587911397509248 3.937696894952624
+0.0,-1.1358853052995896 1.4443151501322575
+1.0,3.7551091652428026 2.475478572543473
+0.0,-0.9167034706173607 -1.7549316646340103
+1.0,1.4669571532496661 3.2025879996118567
+0.0,-0.9673112226998997 0.13104324478779786
+1.0,5.129589009385082 2.962228456981596
+0.0,-1.038791699676283 0.3394661925580474
+1.0,4.0067362767396055 3.7808733451013863
+0.0,0.4607763000001474 0.3165842402170894
+1.0,3.470781763864157 3.1917117382789906
+0.0,-1.0759836593672722 2.1677955321765423
+1.0,1.8061608083541592 2.1368201192592524
+0.0,0.18913968729195288 -0.6832055159990379
+1.0,2.222086435460701 2.462434683952491
+0.0,1.1697195016246194 -0.6482703204844716
+1.0,0.9469729137532825 2.564223951962673
+0.0,-0.2596612587018774 1.3675954564898984
+1.0,3.3498722540414603 2.8411678301395655
+0.0,0.15549061976540607 -0.8795816620250406
+1.0,3.2166810907529517 3.3909740833940147
+0.0,-0.27777898312342497 1.5708467895548373
+1.0,3.5590852623593734 3.022687446035052
+0.0,0.8854804450462548 -0.1674059547432505
+1.0,5.592380230543062 2.046846128948299
+0.0,-0.38403645419139704 -0.6879614453050698
+1.0,1.2059037878354082 3.1373448113023263
+0.0,-0.9332349591768346 0.3271191223126651
+1.0,2.6941262027196444 2.0016455336591275
+0.0,1.985628476449888 -1.720937514961405
+1.0,1.52678578836386 3.6524268651279113
+0.0,0.14930924959259012 0.3549736192569231
+1.0,2.5081810800507904 4.502494324423253
+0.0,1.3659157029970181 -1.4064298168920828
+1.0,2.8947698041280185 3.871692848909248
+0.0,-0.19002791703482588 0.8099829390725909
+1.0,3.0481549176670555 4.05245395484312
+0.0,-0.014729952199541938 0.43445426055411474
+1.0,3.0874888030440486 3.89317889717026
+0.0,0.9521743475193137 0.16292125350371375
+1.0,3.0564028575123805 3.150394468127784
+0.0,-2.5565867181635724 1.1693524400747453
+1.0,3.963399476624186 2.655863627219969
+0.0,2.0594134768376584 1.4326082874689938
+1.0,3.9415985004601524 4.816989711315565
+0.0,0.4986273362656531 -0.30506819506279537
+1.0,2.7697598834307633 2.0292290332215512
+0.0,-0.4716043983943112 1.4692631198715722
+1.0,3.4127279940145883 3.078218915501194
+0.0,-0.28649487641740207 -0.8009455078808752
+1.0,2.645854233845017 4.028461076417125
+0.0,-1.2333241385253426 -0.2850384355482007
+1.0,2.4938754741404976 1.3466482769013481
+0.0,0.6872021385233428 -0.5159203960430369
+1.0,3.136974388668967 1.69291587793452
+0.0,0.9532239280401443 2.619265789851879
+1.0,2.570576389986536 2.548658346643033
+0.0,-1.030037965987706 0.2814883160676786
+1.0,2.510605023939257 2.3227098241155213
+0.0,2.4171507836629256 1.245606490445435
+1.0,3.5520681299250985 0.7442734445298673
+0.0,1.1940577980770877 1.6319950123919318
+1.0,2.708933998825159 2.118496371335553
+0.0,0.26808250222082186 2.5727974909556437
+1.0,3.221534693193204 3.073316472650363
+0.0,-0.6915734756410544 0.25168141600713434
+1.0,1.839319878312068 1.765565689559382
+0.0,1.708990562782385 1.1196517028520787
+1.0,2.1942131633492643 3.733776318231434
+0.0,1.4884941762679373 -0.5221400677305167
+1.0,2.425026062564176 4.814343944240822
+0.0,-1.3572570451352999 0.04542725800519613
+1.0,3.211869589232063 0.01498355271713292
+0.0,1.6170759581287553 0.7420944718274473
+1.0,1.8096883146020295 1.2063063122336204
+0.0,0.8326608996906895 -0.9760063002065638
+1.0,3.60415819299222 3.905143144181063
+0.0,0.9709971797789466 -1.0644382680658016
+1.0,2.8104103693138778 3.5792951568581017
+0.0,-1.021059644329913 -0.25967578007654707
+1.0,2.4020556940935216 3.8705560506781826
+0.0,-2.704107564850001 -0.14300257306795375
+1.0,3.7681081908063643 2.5433599278958297
+0.0,-0.537043950598385 0.8892208622861
+1.0,3.894301374710518 2.76168141850308
+0.0,-0.8416385593366815 1.3377079857054535
+1.0,1.4560861866861152 1.9464951398785584
+0.0,0.8974462212548237 -0.9027814165394935
+1.0,2.848274393366227 4.089266410865265
+0.0,-1.9874388443190703 -2.0515326123686
+1.0,1.7443330286532606 5.182730816947559
+0.0,1.9345124573698136 0.15482916596109797
+1.0,3.730890742221753 3.4571088485293173
+0.0,-0.7591467032951466 0.7817400181511722
+1.0,1.9612060838774241 1.7874104906670758
+0.0,0.04241602781710118 1.7624663777014242
+1.0,2.983106574446788 2.057794179835603
+0.0,-2.2675373876565272 0.1810247094230928
+1.0,1.8242036739605434 3.2897838599534053
+0.0,0.42135250345103276 0.9201551657148959
+1.0,2.3324158301116547 3.2735600739611406
+0.0,-2.503382611181759 -0.604428052499623
+1.0,2.1068571110070753 1.3987709205712464
+0.0,-0.25006447102137164 1.1597904649452788
+1.0,3.6610503210650105 2.389802330720335
+0.0,0.6655774387829471 -0.7657689612002381
+1.0,3.85820287126228 5.653287382126853
+0.0,0.08244241317513575 0.4755361735454262
+1.0,3.6029514045048234 3.0483730792265247
+0.0,1.0276000901424318 -0.569237094330588
+1.0,2.484863163042475 3.4464671311141046
+0.0,0.24588867824456415 -0.7355421671684942
+1.0,2.8757627634577396 1.3730139621444188
+0.0,0.911649033206053 -1.0562220913143838
+1.0,0.6701966948829261 3.8815519088585195
+0.0,1.0649444423673609 0.5738944212075908
+1.0,3.1272553354329955 5.18450239514651
+0.0,-1.8305691156390467 -1.2811179644895232
+1.0,4.326027257587544 1.9589219729995737
+0.0,-0.2278417247639679 -0.6436775444106994
+1.0,3.9854139754166136 2.8662622299102947
+0.0,-0.33177487577648573 0.7122237484053809
+1.0,2.7631237758865255 2.490470927953921
+0.0,-0.2989203275224733 -0.9063254275476191
+1.0,2.7739570950234254 3.333596743208583
+0.0,-0.12025132003053318 -1.2251715775331837
+1.0,3.9028268386113307 2.580334438085556
+0.0,0.3114518803226873 0.35489645702286177
+1.0,2.8765994073916112 4.251640702192294
+0.0,-3.0895947568085367 -1.0526550179589378
+1.0,3.5182345295490216 2.764855512391279
+0.0,0.5749621254042305 0.7148834016467635
+1.0,4.039448299164001 2.377396087740471
+0.0,1.7077800661629936 -0.23711282974122355
+1.0,2.883211311171089 3.5259606315833287
+0.0,-1.0304518163976537 -0.16271910447066004
+1.0,3.8284470175501504 1.0841759781704199
+0.0,-1.3620621426919217 0.8678141368192274
+1.0,3.831976508070298 2.3592788803510505
+0.0,0.8398199934902235 0.8458121179021545
+1.0,2.166979759191688 4.408250411844058
+0.0,-1.2009412161006234 -0.04486968047943732
+1.0,3.0041897020427517 1.67577082931885
+0.0,-1.0550850035108499 2.6114061208535673
+1.0,1.46399823823424 3.6863318429400627
+0.0,-0.439942118867861 0.8107733517611471
+1.0,2.799907981207793 3.1021389011201244
+0.0,0.40512996190803663 -0.2720769110918539
+1.0,2.936414720731187 2.6121553148876706
+0.0,0.7864503163458285 0.879685137879171
+1.0,3.497848931993103 3.93953696354328
+0.0,1.0898800025299487 -0.3780987477521812
+1.0,3.0737866861658834 3.8281246288654067
+0.0,1.0100369320198321 -0.36412797089680377
+1.0,4.977156552398557 1.9361263628969327
+0.0,1.1948682006514484 -1.0421380659408503
+1.0,2.3707352395183743 3.319087891488442
+0.0,0.14662871945444525 -1.125277513770441
+1.0,4.18636170602371 5.079790109963499
+0.0,0.5213830491310841 2.5489667538554355
+1.0,3.456121838657517 2.9777488007628823
+0.0,1.3942157902546204 -0.7392170745991694
+1.0,4.027857416272539 2.5520251242493615
+0.0,0.6677437543225546 -0.7054702957392922
+1.0,2.419993627501343 3.147115729790262
+0.0,-1.1891285195785104 0.7121837556662985
+1.0,2.6768950566988114 2.746092902448666
+0.0,-0.5581632736462642 -0.8475377022167101
+1.0,2.2877649074222144 3.360822129377224
+0.0,0.12427410923130733 -0.029877611579596446
+1.0,2.1363649823278976 2.040672619624904
+0.0,0.164296403698455 -0.7853340225962958
+1.0,2.2867454265483063 2.920796736914219
+0.0,0.030938689766481568 0.02840531713718885
+1.0,4.935402862397514 4.984097800264938
+0.0,-0.49323021214001667 -0.009344009957387383
+1.0,2.2590589178865788 2.784700488476081
+0.0,-1.7996451721642797 -0.08927843209025701
+1.0,2.7189425454136047 3.366984002518318
+0.0,-0.4732503966611213 2.41667617281343
+1.0,1.914172722581019 2.723688261246487
+0.0,0.6854209215843875 -0.6321377274037409
+1.0,4.7025333481932705 2.6561807763401646
+0.0,0.016511529980536163 -0.4064291762993186
+1.0,1.3841179371371182 3.367159685928979
+0.0,-0.525665902025766 0.3189849885462113
+1.0,2.1237941386456276 3.4141040859263914
+0.0,-1.3977733609952327 1.6180332199555512
+1.0,3.3282228318571496 2.9879449742002184
+0.0,-1.3911999737510374 -0.47876736354905697
+1.0,3.071461319022103 3.902142645231827
+0.0,-1.4616870328596612 0.4234223737141411
+1.0,3.3069543201402576 1.3522887907099401
+0.0,0.1771175002160632 0.7092577154896049
+1.0,2.561517669553921 3.2663130772229185
+0.0,0.8635080818806004 1.7578935533355913
+1.0,3.3054989034355793 3.4205399612822633
+0.0,-0.5525474134214131 -0.008874526853035592
+1.0,5.024607965706471 3.377256085775693
+0.0,0.6499316691799448 0.7636813929956143
+1.0,1.7211648540475015 3.7290596058136307
+0.0,-0.4312096678787339 0.4723353140241522
+1.0,1.6269397815780402 1.9613109767814954
+0.0,0.06589250830042476 0.5659627954925366
+1.0,1.4141705667382305 2.9411215895612255
+0.0,-0.30655047441372724 1.134312621267185
+1.0,4.079371134159225 3.7127217011979767
+0.0,-0.11148410319718746 1.504423362990177
+1.0,3.21908765035085 1.5284527951297098
+0.0,0.38879874604519066 -0.7718569898512835
+1.0,3.0387686435299197 1.9571679686339727
+0.0,0.0432538958325193 -0.609046739618082
+1.0,3.858513576900389 2.3343789318227595
+0.0,-1.594606569379673 2.0291869081775498
+1.0,4.418575803606943 3.634284954659144
+0.0,-1.5657043498774568 0.48528442006547645
+1.0,3.7474369990653518 2.417108621170513
+0.0,-0.4087178618516316 -0.5585629524971241
+1.0,2.8830052178069345 2.714807180476644
+0.0,1.0200529614238536 1.633454495011907
+1.0,2.161101444560085 2.722233198993495
+0.0,0.8905571055499505 0.3531260808046299
+1.0,1.5770402091220281 2.5197577954902615
+0.0,0.19603489193696402 0.4391781215510938
+1.0,3.285302297900197 2.5981032583297274
+0.0,-1.7728311957227578 2.226646036588897
+1.0,2.212402423781055 2.994783519362575
+0.0,-0.26351331835428804 0.6197161896115081
+1.0,2.5101464936050144 2.747453537535198
+0.0,1.083443472210967 -0.7471502465676395
+1.0,2.618022142084275 3.201094589808021
+0.0,-0.10243507468644107 -1.5307780048431203
+1.0,2.0479014235932986 2.7174445598757764
+0.0,-0.2530316183327909 1.5105959457792464
+1.0,2.616239369128394 3.1011058356715644
+0.0,2.0703487677159997 -1.23039689097027
+1.0,2.00559575849234 3.088170264353322
+0.0,0.751453701775929 -0.34079600956200146
+1.0,2.6436129383324625 0.6934715851263205
+0.0,0.4735774669250165 0.24981500600111478
+1.0,3.614102521076285 3.297655445774221
+0.0,-0.8397190394129946 2.0791729859494583
+1.0,2.5800847823336372 2.312770726398467
+0.0,0.9528690775719402 -4.054641847252764
+1.0,1.6631425491523402 4.465488566725185
+0.0,-0.40442215938144854 2.1662912065078923
+1.0,3.2025444402071472 0.954639816329502
+0.0,0.8484611241529962 -0.6531501762867838
+1.0,2.907155165379039 4.494838051538261
+0.0,1.1473298350419248 -0.7604213061923158
+1.0,4.406872541176625 2.616395889868952
+0.0,-1.0643453307576694 0.32269083514118757
+1.0,3.4229771635424653 5.404174358063928
+0.0,0.8223012341648268 -2.0705983787489455
+1.0,0.6519219290294926 3.317297519573949
+0.0,0.6661739745821234 0.21368601256080724
+1.0,2.8092516816651187 2.9407143882873363
+0.0,-2.0396349059310626 0.6660958962860263
+1.0,1.621401319049101 2.120514741629026
+0.0,-0.6673242389540511 -1.033336539766657
+1.0,2.4729967381312257 2.0622671692969314
+0.0,0.318696287733599 0.7696143248064906
+1.0,-0.3310542190127661 2.503572170101248
+0.0,-0.024545405442632163 1.2826535279165514
+1.0,2.08361065329982 1.7709137020843035
+0.0,-0.03325908838419148 2.127731976717063
+1.0,0.8920712229737089 2.267227052639782
+0.0,2.4226620796703706 -1.5422597801969735
+1.0,2.6125707261695665 4.136941962252239
+0.0,0.710000430684373 -0.2365544035810329
+1.0,3.587983407259662 2.371118916918134
+0.0,1.548716105657387 2.6039797648647527
+1.0,2.288647833469394 2.8514285941696564
+0.0,0.5407956769257948 -1.4250712589214616
+1.0,3.9999271279969157 4.647262641336589
+0.0,0.46916438504363506 -0.16114805677977867
+1.0,3.9351714928555133 3.017851089635014
+0.0,-0.24683125971847 0.8686956304798523
+1.0,2.445900548419883 2.601998949302925
+0.0,0.9708272515136681 0.9540365110832763
+1.0,2.0889493306284472 1.670700190658552
+0.0,0.7573519355244429 -0.6731075400854291
+1.0,2.9938559890272676 0.5796453404844417
+0.0,-0.42350233780111274 0.1072223004754211
+1.0,3.22502989165533 3.2744724666391045
+0.0,-0.051171179793716125 0.035749085667007977
+1.0,4.256076524642883 3.956646576238979
+0.0,0.44715068158575316 -0.10904823199444005
+1.0,3.754239074295241 2.4862504435534283
+0.0,-0.12025734941101636 0.6682754649328633
+1.0,2.9673795614648815 3.6207880514009263
+0.0,-2.250093626462795 -0.49148713538228506
+1.0,1.7335315087131171 4.234455598757855
+0.0,-0.5145677322324603 -1.8872464244504652
+1.0,3.1524408905920547 2.534903833671654
+0.0,1.4188237424906527 -1.987300018397619
+1.0,3.025903676999244 2.1652631630581847
+0.0,0.5008343534015861 0.28011601768758965
+1.0,2.0039218613662197 2.3639397631018015
+0.0,1.342528231824729 1.0036076495884643
+1.0,3.3281244751369985 2.4251038991267277
+0.0,-0.38845861664115766 -1.5147629282596704
+1.0,2.613448357242925 4.463712912575443
+0.0,-0.19439583983218703 0.676381234314577
+1.0,1.0400516553104269 2.3981508685333424
+0.0,0.9469554018478826 -0.08144910777086176
+1.0,3.179705969662961 3.768848690124549
+0.0,0.39855441813668835 -1.6301847736954416
+1.0,2.1915941615815226 2.7947789889097763
+0.0,1.6023287643577222 0.05432794979410767
+1.0,1.5758610206949497 3.8709473262823777
+0.0,-1.3109119301269387 -0.8645189055395048
+1.0,3.715865055565244 1.9360512196442488
+0.0,-0.2073998491467907 -1.178882579876182
+1.0,2.565062666629786 2.3121370465462494
+0.0,-0.41397768670851737 -0.6674761320605563
+1.0,2.941938460212705 3.537877403937825
+0.0,0.5954231185191001 1.6839554319972647
+1.0,4.591360208911688 1.4381368838271187
+0.0,-1.3221878199013057 0.786799353955043
+1.0,0.6498018470693379 2.2143413646510095
+0.0,0.5346452265922554 0.45599002729248733
+1.0,2.668100742914233 2.679883986650412
+0.0,-0.22428284967184606 -1.0003823373608314
+1.0,4.233871998643562 3.3423521548333897
+0.0,0.7800144346305873 1.6512542456242612
+1.0,3.3192955924982677 4.664828345688715
+0.0,-0.9059493298933676 -0.42207747354389447
+1.0,3.1776956110847916 1.1393123509452483
+0.0,-0.5246202787832872 1.0246845701853746
+1.0,4.732113325540828 1.29018271893586
+0.0,0.9863596225434407 0.7506968948666005
+1.0,2.911409852038849 2.626474556246977
+0.0,0.8545346747310709 -2.1711133879380955
+1.0,2.476689592134109 4.03136160709651
+0.0,0.43108249592457043 0.4589971218864913
+1.0,3.2333287857145825 2.188137362144206
+0.0,1.4405649581445525 0.4131214094941824
+1.0,2.0631468420251093 3.807898318807702
+0.0,0.43964401099781425 0.6669437158150616
+1.0,2.165843657939062 4.109647016182597
+0.0,-0.9735452695016392 -0.6172105570335473
+1.0,3.169794653766589 3.2721053734106
+0.0,1.3129166037688875 -1.2040138532590103
+1.0,2.211361701514339 1.025981622029549
+0.0,0.3653350359702278 0.5229315457444437
+1.0,3.372206428302252 4.163685355869495
+0.0,-0.8690030167652726 0.3226849491596335
+1.0,4.188509026227427 2.1137749377457076
+0.0,2.2174789916979933 0.8249932442083762
+1.0,3.9224824525785706 2.9436443006575925
+0.0,0.1370905200148926 -0.043320354739616776
+1.0,3.1118662077850807 1.4983207834379917
+0.0,-0.5304073850344787 -0.4219778391981189
+1.0,1.2153552376808336 3.4749521622043438
+0.0,-2.545970043914331 -0.5480647959096547
+1.0,1.8097968872175412 4.733523163055134
+0.0,-0.5599306916727819 0.4648015112295201
+1.0,3.0242901796172204 4.354893518146392
+0.0,-0.49175893973189483 1.8635231981223406
+1.0,3.923889822736733 4.199324033436554
+0.0,0.32931083529824645 -1.2038529291812745
+1.0,2.8430570026355904 3.2581768028655214
+0.0,0.08015643729775149 -0.5281238499521005
+1.0,1.0251176552841985 2.452443183841665
+0.0,-1.4000614002792062 -0.4723026702712555
+1.0,4.642753244692533 3.5777684251625153
+0.0,-0.9732069449126244 -0.7507666182081589
+1.0,2.284811103731081 2.6226837934175817
+0.0,1.4938320459354653 1.2271703303402608
+1.0,2.5217907633717935 1.9804499278889345
+0.0,0.9177851256816916 -1.196945923903535
+1.0,2.650515007788954 0.9818159554114416
+0.0,-0.4172435945582116 0.11930551874205601
+1.0,1.8203127944592765 3.3069324017397594
+0.0,0.08195935202288789 -0.2585763476071969
+1.0,2.14910426585678 4.146147361847687
+0.0,1.578290774885182 0.16149960053586573
+1.0,1.2607405323635168 2.940350340912184
+0.0,1.6722138822230346 -0.5454073192477626
+1.0,0.3769561517619793 4.029314828130509
+0.0,-0.012008811772440746 0.2577932550827986
+1.0,2.330909580388283 3.1650439747088024
+0.0,-1.4224384024201595 -0.6369918128076046
+1.0,3.451178380794735 2.7553545272536746
+0.0,-0.7913135079702314 -0.012217405089490006
+1.0,3.7918310740082424 3.3927876820084033
+0.0,0.41016650792928255 0.3521369094279198
+1.0,2.380867149491576 3.7533007228820754
+0.0,-0.2787273586680994 1.3553543015884186
+1.0,2.8933236071325226 1.7975563396445144
+0.0,-0.4868680345968448 0.058461169788172784
+1.0,3.484434144626577 3.5622013162506683
+0.0,1.171904838026115 0.1162839888503951
+1.0,1.8132727587691455 2.238018140780368
+0.0,0.8114997821213137 -1.712768034302675
+1.0,2.977061410695451 2.802894970831404
+0.0,1.7141760742336318 0.5672102391229309
+1.0,3.2929421353515185 3.3754831695793945
+0.0,-2.280170614413754 -0.4912881923146271
+1.0,4.182771547422101 3.5331418354105812
+0.0,-0.2544453921577854 0.4682744998445509
+1.0,1.9236524545763007 2.628837510538455
+0.0,0.6645491524745186 -2.398604366119661
+1.0,3.50840713613987 3.7182332137428955
+0.0,-1.4532823239751684 -0.9916580822162051
+1.0,2.769613688635247 4.72661442603805
+0.0,-1.090104082054257 0.486265921887567
+1.0,3.4900626627065003 3.03025323652533
+0.0,1.4518716691137106 -0.10218738652959546
+1.0,2.745034544461333 4.366809709694589
+0.0,-0.17197050309086373 0.13673125942508174
+1.0,2.4934379443680985 2.954734256628178
+0.0,0.14078971520128297 -0.5401300324197861
+1.0,3.640563349517043 5.163454382169049
+0.0,1.0264020194022627 -0.8738489740165843
+1.0,3.791458514669831 2.2038333093620834
+0.0,-3.075231830613813 2.04054404065675
+1.0,4.647422323558612 3.5220753128741427
+0.0,-0.6423734479152313 0.5403500050100541
+1.0,1.5985339514690007 2.73447434771563
+0.0,-0.04474684215568748 -0.21477212224970194
+1.0,2.6701891009654792 3.9776885659794505
+0.0,-0.4714276238216119 1.4235807729101415
+1.0,3.5551789183755806 2.7057825768035104
+0.0,1.108254774651522 0.8596053056731966
+1.0,3.0623366138774983 2.718494058918926
+0.0,-1.375827910513567 0.011994162356159788
+1.0,3.841407434840553 2.8434319292302304
+0.0,-0.7149712282755271 0.1811986378283469
+1.0,5.155524316715826 2.1468464150279747
+0.0,-0.06822014690491127 -0.15801546435311806
+1.0,3.4838423066641173 4.211572262022802
+0.0,1.455177312877137 -0.9388697017811595
+1.0,3.917344840727481 3.569507254920478
+0.0,-2.080636526173827 -1.2489913979804321
+1.0,4.904327940183608 3.4289745068714295
+0.0,-1.4744723958060084 0.2930577753686633
+1.0,2.810346752831796 2.4062885063635333
+0.0,-0.17365054648101302 -2.26263747840141
+1.0,4.077713960215311 3.841309768575811
+0.0,1.581178479362914 -0.9672846912018417
+1.0,4.516244757634386 2.9078781629204054
+0.0,-1.5890391289381882 -0.4092245513024253
+1.0,3.359480708344044 3.7375262649030123
+0.0,1.5675385032786122 0.9010632060589036
+1.0,3.8564874267647644 3.060660915266198
+0.0,-0.2482500870678099 0.29655946916337894
+1.0,3.1672692968701397 1.1973226392521306
+0.0,-1.4471523637168304 0.5370395414503478
+1.0,4.814859889188941 2.229750617440331
+0.0,0.2812295731325761 0.6044036116090106
+1.0,2.4884527354338903 1.4171627784171204
+0.0,1.173099753717184 0.7948729712563257
+1.0,1.5092479631180256 4.1412277875509105
+0.0,-1.1453508695714685 -0.15567849492271865
+1.0,1.9397046305500465 3.430755367623314
+0.0,-1.6689604208958047 -1.161942047896626
+1.0,4.287905082572467 2.643797664646416
+0.0,0.5691715436318573 -0.6013793142266736
+1.0,2.622904412483301 1.769830678112635
+0.0,-1.0627706066421603 -1.2962746926911266
+1.0,2.5818494635089886 2.9547836545958663
+0.0,-1.555832778500785 0.6050365213516793
+1.0,0.6877755924513469 3.0627330470806617
+0.0,-0.6945984937358738 -0.5355659085722678
+1.0,3.631758943383 2.6990914911890194
+0.0,-0.10204034384758799 1.2650405538373874
+1.0,2.8618200471403488 2.7676923144816237
+0.0,-1.2337428464512885 -0.7151041760567872
+1.0,3.5209869997316807 3.280763138579491
+0.0,0.3700095159793621 -0.8614396246939711
+1.0,2.698616090611572 3.2205340189872795
+0.0,-0.8069663812258417 -0.07956402748767083
+1.0,2.929873320056276 4.030067053746698
+0.0,-1.2316919288622938 1.245687935224532
+1.0,2.9285679560367055 2.9682906465530783
+0.0,-0.3965578686363537 1.1748126835359254
+1.0,4.002714110052464 4.370338584188975
+0.0,-0.6084107635744659 -0.6092872315132073
+1.0,3.293912876563504 3.5843332356258464
+0.0,-0.8145032742370918 1.4050967895930515
+1.0,1.991600071099763 2.343264260750465
+0.0,-0.9433799779882722 1.5943129187456013
+1.0,2.369037146473894 1.9827898318071764
+0.0,-0.26885731570182714 0.47421918725401946
+1.0,3.263006333756187 3.0441051541001443
+0.0,0.21785408377528742 0.5754303556190559
+1.0,2.941128899266118 1.240818619804987
+0.0,0.736142634408259 -1.3173589352849961
+1.0,3.2027184783050644 2.9218716893221766
+0.0,1.9216539101612737 -2.2400666381338694
+1.0,2.4823406743823426 3.429705681271458
+0.0,0.0666674809216063 -0.976496437708073
+1.0,3.206108328915537 2.0828009180110976
+0.0,-0.11582094814525531 2.5093876016868366
+1.0,2.5373176496966328 2.32926952602907
+0.0,-0.9237765727032562 0.9342845305943139
+1.0,2.5300867778672123 3.2754703213122753
+0.0,0.13837351460348038 0.2533025702882705
+1.0,4.556185356940701 0.7629684714626066
+0.0,-1.8251759895063635 0.6966019254550819
+1.0,4.905392053322123 4.111245902434462
+0.0,0.09886105139472441 1.4093224263552915
+1.0,2.0484713074013223 4.874632770975326
+0.0,-0.040609033066195156 -1.3446008307073973
+1.0,3.678642687565624 4.156505531118834
+0.0,0.052003196801406706 1.2239229001362555
+1.0,3.4376496474012876 2.417529764306501
+0.0,-0.09054032070414311 -1.7571173217955876
+1.0,3.230032966809188 3.5965216835420546
+0.0,0.9100014718072797 0.5615698517199065
+1.0,3.938728443662248 3.2945250621813273
+0.0,-0.9205165004286314 -0.01425448590777016
+1.0,1.907285344344031 3.8629943281683987
+0.0,-0.8160057252300347 -0.2757475590440447
+1.0,2.3076630082503926 3.2283118851645476
+0.0,1.3000520665928303 0.581203895654615
+1.0,3.8425274250736887 3.6133028383400414
+0.0,0.13694776598217193 -1.1659103408047182
+1.0,2.688548985689179 1.5486856086329917
+0.0,-0.14378057635986438 -1.4649914115754739
+1.0,3.923705106138171 3.8281415874634783
+0.0,1.3334544187579878 -0.048721556115349604
+1.0,3.320777445436592 2.947489296620178
+0.0,-0.36251547004650103 -0.2886015741883188
+1.0,3.2163584307843567 2.9285953038088373
+0.0,0.5437339741631225 -0.23459273264636704
+1.0,2.820666118654177 4.0305429519659395
+0.0,0.04808393980018175 0.42285718084497675
+1.0,1.4686721107589078 2.6605885841423067
+0.0,1.1873828480862414 0.5487600196906772
+1.0,3.425690422789916 4.252827757634791
+0.0,-0.7323210179394448 -0.9818194354330615
+1.0,3.018263609974841 2.914037267945018
+0.0,1.005159548514262 -0.5055899932767433
+1.0,4.566046579419102 5.545663797862058
+0.0,-0.7129346827436536 2.2938920919917742
+1.0,2.869336979055624 2.5688122980246684
+0.0,1.5201806096451054 -0.7414084378784415
+1.0,1.71558426191034 2.4576286538624794
+0.0,0.8090326808020629 0.26208059965589425
+1.0,3.0163716479573077 2.4747608384001056
+0.0,0.47627288733283857 1.3085076289292734
+1.0,3.3891272567835684 3.20832981462489
+0.0,1.0488767400026389 1.2318533170755142
+1.0,3.3428160616141853 2.5497426855885075
+0.0,-0.6411040361810151 -0.4290410178863531
+1.0,2.219119637941564 2.6621113083439254
+0.0,1.5621125506487947 0.7273124535333745
+1.0,3.1459765929197636 1.3663869759433418
+0.0,-0.05263982623034547 0.43675636434345644
+1.0,1.890191705836878 3.435071392429276
+0.0,0.28718983621307775 -2.438042507707637
+1.0,5.717207001359904 2.2303522388797035
+0.0,0.17636841934036573 -0.2202348356695646
+1.0,2.7426941364254294 3.9506423829670734
+0.0,-1.118995077703066 0.6062681312772151
+1.0,4.510963440028501 2.4497214672006575
+0.0,0.07601426739661686 1.4712413920907517
+1.0,2.472822799411239 4.045939967967948
+0.0,-2.2061186560242603 0.32560701091997957
+1.0,3.250675248798315 3.268273446922124
+0.0,-0.024542349115316425 1.5505593308513355
+1.0,2.5654508852779654 2.9476923150082874
+0.0,0.8070230851041806 1.0614288963806608
+1.0,4.0121013342203655 1.7608333223695753
+0.0,-0.6895596222836047 0.035498410809669464
+1.0,1.697905057706837 4.053746875797327
+0.0,-0.3311042917990167 -0.09180266122060314
+1.0,3.720796880080382 4.467214289132983
+0.0,-0.318673057944378 -3.1474317710285202
+1.0,4.809204233917482 4.55250051737848
+0.0,0.596445093094233 0.41780789823963405
+1.0,4.432965399675368 3.4638105151117617
+0.0,-0.10285141484897965 1.747950423830727
+1.0,2.1513849154027014 3.9020766404442933
+0.0,1.5988780419195843 -0.08753929889987294
+1.0,0.9867334105272594 3.017081919852008
+0.0,-1.4952194834476749 1.0187701527429442
+1.0,2.2468599817570376 2.5883807516977395
+0.0,-1.804930212071194 0.3519094744696904
+1.0,4.1524048686549975 2.39387437993355
+0.0,0.7077190974093445 0.5703893640810606
+1.0,3.551726989450847 2.4786821848615985
+0.0,1.866022101379231 0.23733176192158173
+1.0,2.636453843734601 3.2607059005922467
+0.0,1.0052825898444602 0.5988275134415102
+1.0,2.643754787324359 3.72363185525656
+0.0,-0.9925822461102075 0.060644514219670244
+1.0,3.8994350969658136 1.9246001662480055
+0.0,0.6513177047637154 0.04450296971216735
+1.0,2.4564101844841106 3.6785165656991596
+0.0,0.2606556093620563 -0.6172755504020078
+1.0,2.4170362032345674 0.8639272362396189
+0.0,-0.6416537078444019 1.8622433251026849
+1.0,2.0247632881021267 2.538336421666863
+0.0,-1.0177991501405648 -0.8522549981552515
+1.0,3.3426117902650185 3.1635532244875586
+0.0,-0.08963512689480763 1.4555128614393191
+1.0,3.7470117779591092 3.414476280017385
+0.0,0.7721815837750134 -0.17297061945116646
+1.0,3.823597567639877 4.2427688079492665
+0.0,-0.6905817293226868 0.5838402640342898
+1.0,3.005258204213709 2.7252310853631125
+0.0,0.963732273262942 -1.3950688358262504
+1.0,3.2803836447761934 3.448945851174787
+0.0,-0.11576488451784747 1.8796627145034757
+1.0,3.905782244273501 3.3853014175990412
+0.0,0.3786078767939069 0.4054987293824608
+1.0,4.251338642737948 3.2212804055347375
+0.0,1.785664685579919 -0.4528337660796719
+1.0,0.9522164714530392 4.648272724469027
+0.0,2.06805484281029 0.3211833348167774
+1.0,3.2063266406360875 3.20907719820361
+0.0,-0.18542396323311192 -0.4721814985954186
+1.0,1.2468417100913183 2.988063666542869
+0.0,-0.9089767150726245 0.049627884005341995
+1.0,3.570670591235201 1.812766580123238
+0.0,1.9973417232460495 -0.17709723581574177
+1.0,2.810527831677345 2.0292239826226717
+0.0,0.06390562956663569 0.9110683296487658
+1.0,4.449308253046676 2.5895593413305997
+0.0,-0.18596846882351442 1.2495641818989083
+1.0,2.1189215966743986 3.7928094437779283
diff --git a/mllib/data/ridge-data/lpsa.data b/mllib/data/ridge-data/lpsa.data
new file mode 100644
index 0000000000..fdd16e36b4
--- /dev/null
+++ b/mllib/data/ridge-data/lpsa.data
@@ -0,0 +1,67 @@
+-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
+-0.1625189,-2.16691708463163 -0.807993896938655 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+0.3715636,-0.507874475300631 -0.458834049396776 -0.250631301876899 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+0.7654678,-2.03612849966376 -0.933954647105133 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+0.8544153,-0.557312518810673 -0.208756571683607 -0.787896192088153 0.990146852537193 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.2669476,-0.929360463147704 -0.0578991819441687 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.2669476,-2.28833047634983 -0.0706369432557794 -0.116315079324086 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.2669476,0.223498042876113 -1.41471935455355 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
+1.3480731,0.107785900236813 -1.47221551299731 0.420949810887169 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.687186906466865
+1.446919,0.162180092313795 -1.32557369901905 0.286633588334355 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.4701758,-1.49795329918548 -0.263601072284232 0.823898478545609 0.788388310173035 -0.522940888712441 -0.29928234305568 0.342627053981254 0.199211097885341
+1.4929041,0.796247055396743 0.0476559407005752 0.286633588334355 -1.02470580167082 -0.522940888712441 0.394013435896129 -1.04215728919298 -0.864466507337306
+1.5581446,-1.62233848461465 -0.843294091975396 -3.07127197548598 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.5993876,-0.990720665490831 0.458513517212311 0.823898478545609 1.07379746308195 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.6389967,-0.171901281967138 -0.489197399065355 -0.65357996953534 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.6956156,-1.60758252338831 -0.590700340358265 -0.65357996953534 -0.619561070667254 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+1.7137979,0.366273918511144 -0.414014962912583 -0.116315079324086 0.232904453212813 -0.522940888712441 0.971228997418125 0.342627053981254 1.26288870310799
+1.8000583,-0.710307384579833 0.211731938156277 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.442797990776478 0.342627053981254 1.61744790484887
+1.8484548,-0.262791728113881 -1.16708345615721 0.420949810887169 0.0846342590816532 -0.522940888712441 0.163172393491611 0.342627053981254 1.97200710658975
+1.8946169,0.899043117369237 -0.590700340358265 0.152317365781542 -1.02470580167082 -0.522940888712441 1.28643254437683 -1.04215728919298 -0.864466507337306
+1.9242487,-0.903451690500615 1.07659722048274 0.152317365781542 1.28380453408541 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
+2.008214,-0.0633337899773081 -1.38088970920094 0.958214701098423 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+2.0476928,-1.15393789990757 -0.961853075398404 -0.116315079324086 -1.02470580167082 -0.522940888712441 -0.442797990776478 -1.04215728919298 -0.864466507337306
+2.1575593,0.0620203721138446 0.0657973885499142 1.22684714620405 -0.468824786336838 -0.522940888712441 1.31421001659859 1.72741139715549 -0.332627704725983
+2.1916535,-0.75731027755674 -2.92717970468456 0.018001143228728 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
+2.2137539,1.11226993252773 1.06484916245061 0.555266033439982 0.877691038550889 1.89254797819741 1.43890404648442 0.342627053981254 0.376490698755783
+2.2772673,-0.468768642850639 -1.43754788774533 -1.05652863719378 0.576050411655607 -0.522940888712441 0.0120483832567209 0.342627053981254 -0.687186906466865
+2.2975726,-0.618884859896728 -1.1366360750781 -0.519263746982526 -1.02470580167082 -0.522940888712441 -0.863171185425945 3.11219574032972 1.97200710658975
+2.3272777,-0.651431999123483 0.55329161145762 -0.250631301876899 1.11210019001038 -0.522940888712441 -0.179808625688859 -1.04215728919298 -0.864466507337306
+2.5217206,0.115499102435224 -0.512233676577595 0.286633588334355 1.13650173283446 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.155348103855541
+2.5533438,0.266341329949937 -0.551137885443386 -0.384947524429713 0.354857790686005 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
+2.5687881,1.16902610257751 0.855491905752846 2.03274448152093 1.22628985326088 1.89254797819741 2.02833774827712 3.11219574032972 2.68112551007152
+2.6567569,-0.218972367124187 0.851192298581141 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 0.908329501367106
+2.677591,0.263121415733908 1.4142681068416 0.018001143228728 1.35980653053822 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+2.7180005,-0.0704736333296423 1.52000996595417 0.286633588334355 1.39364261119802 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.332627704725983
+2.7942279,-0.751957286017338 0.316843561689933 -1.99674219506348 0.911736065044475 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+2.8063861,-0.685277652430997 1.28214038482516 0.823898478545609 0.232904453212813 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
+2.8124102,-0.244991501432929 0.51882005949686 -0.384947524429713 0.823246560137838 -0.522940888712441 -0.863171185425945 0.342627053981254 0.553770299626224
+2.8419982,-0.75731027755674 2.09041984898851 1.22684714620405 1.53428167116843 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+2.8535925,1.20962937075363 -0.242882661178889 1.09253092365124 -1.02470580167082 -0.522940888712441 1.24263233939889 3.11219574032972 2.50384590920108
+2.9204698,0.570886990493502 0.58243883987948 0.555266033439982 1.16006887775962 -0.522940888712441 1.07357183940747 0.342627053981254 1.61744790484887
+2.9626924,0.719758684343624 0.984970304132004 1.09253092365124 1.52137230773457 -0.522940888712441 -0.179808625688859 0.342627053981254 -0.509907305596424
+2.9626924,-1.52406140158064 1.81975700990333 0.689582255992796 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+2.9729753,-0.132431544081234 2.68769877553723 1.09253092365124 1.53428167116843 -0.522940888712441 -0.442797990776478 0.342627053981254 -0.687186906466865
+3.0130809,0.436161292804989 -0.0834447307428255 -0.519263746982526 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
+3.0373539,-0.161195191984091 -0.671900359186746 1.7641120364153 1.13650173283446 -0.522940888712441 -0.863171185425945 0.342627053981254 0.0219314970149
+3.2752562,1.39927182372944 0.513852869452676 0.689582255992796 -1.02470580167082 1.89254797819741 1.49394503405693 0.342627053981254 -0.155348103855541
+3.3375474,1.51967002306341 -0.852203755696565 0.555266033439982 -0.104527297798983 1.89254797819741 1.85927724828569 0.342627053981254 0.908329501367106
+3.3928291,0.560725834706224 1.87867703391426 1.09253092365124 1.39364261119802 -0.522940888712441 0.486423065822545 0.342627053981254 1.26288870310799
+3.4355988,1.00765532502814 1.69426310090641 1.89842825896812 1.53428167116843 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.509907305596424
+3.4578927,1.10152996153577 -0.10927271844907 0.689582255992796 -1.02470580167082 1.89254797819741 1.97630171771485 0.342627053981254 1.61744790484887
+3.5160131,0.100001934217311 -1.30380956369388 0.286633588334355 0.316555063757567 -0.522940888712441 0.28786643052924 0.342627053981254 0.553770299626224
+3.5307626,0.987291634724086 -0.36279314978779 -0.922212414640967 0.232904453212813 -0.522940888712441 1.79270085261407 0.342627053981254 1.26288870310799
+3.5652984,1.07158528137575 0.606453149641961 1.7641120364153 -0.432854616994416 1.89254797819741 0.528504607720369 0.342627053981254 0.199211097885341
+3.5876769,0.180156323255198 0.188987436375017 -0.519263746982526 1.09956763075594 -0.522940888712441 0.708239632330506 0.342627053981254 0.199211097885341
+3.6309855,1.65687973755377 -0.256675483533719 0.018001143228728 -1.02470580167082 1.89254797819741 1.79270085261407 0.342627053981254 1.26288870310799
+3.6800909,0.5720085322365 0.239854450210939 -0.787896192088153 1.0605418233138 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+3.7123518,0.323806133438225 -0.606717660886078 -0.250631301876899 -1.02470580167082 1.89254797819741 0.342907418101747 0.342627053981254 0.199211097885341
+3.9843437,1.23668206715898 2.54220539083611 0.152317365781542 -1.02470580167082 1.89254797819741 1.89037692416194 0.342627053981254 1.26288870310799
+3.993603,0.180156323255198 0.154448192444669 1.62979581386249 0.576050411655607 1.89254797819741 0.708239632330506 0.342627053981254 1.79472750571931
+4.029806,1.60906277046565 1.10378605019827 0.555266033439982 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
+4.1295508,1.0036214996026 0.113496885050331 -0.384947524429713 0.860016436332751 1.89254797819741 -0.863171185425945 0.342627053981254 -0.332627704725983
+4.3851468,1.25591974271076 0.577607033774471 0.555266033439982 -1.02470580167082 1.89254797819741 1.07357183940747 0.342627053981254 1.26288870310799
+4.6844434,2.09650591351268 0.625488598331018 -2.66832330782754 -1.02470580167082 1.89254797819741 1.67954222367555 0.342627053981254 0.553770299626224
+5.477509,1.30028987435881 0.338383613253713 0.555266033439982 1.00481276295349 1.89254797819741 1.24263233939889 0.342627053981254 1.97200710658975
diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
new file mode 100644
index 0000000000..b0e141ff32
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala
@@ -0,0 +1,317 @@
+package spark.mllib.clustering
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import spark.{SparkContext, RDD}
+import spark.SparkContext._
+import spark.Logging
+import spark.mllib.util.MLUtils
+
+import org.jblas.DoubleMatrix
+
+
+/**
+ * K-means clustering with support for multiple parallel runs and a k-means++-like initialization
+ * mode (the k-means|| algorithm by Bahmani et al.). When multiple concurrent runs are requested,
+ * they are executed together with joint passes over the data for efficiency.
+ *
+ * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given
+ * to it should be cached by the user.
+ */
+class KMeans private (
+ var k: Int,
+ var maxIterations: Int,
+ var runs: Int,
+ var initializationMode: String,
+ var initializationSteps: Int,
+ var epsilon: Double)
+ extends Serializable with Logging
+{
+ private type ClusterCenters = Array[Array[Double]]
+
+ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
+
+ /** Set the number of clusters to create (k). Default: 2. */
+ def setK(k: Int): KMeans = {
+ this.k = k
+ this
+ }
+
+ /** Set maximum number of iterations to run. Default: 20. */
+ def setMaxIterations(maxIterations: Int): KMeans = {
+ this.maxIterations = maxIterations
+ this
+ }
+
+ /**
+ * Set the initialization algorithm. This can be either "random" to choose random points as
+ * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
+ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
+ */
+ def setInitializationMode(initializationMode: String): KMeans = {
+ if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
+ throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
+ }
+ this.initializationMode = initializationMode
+ this
+ }
+
+ /**
+ * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm
+ * this many times with random starting conditions (configured by the initialization mode), then
+ * return the best clustering found over any run. Default: 1.
+ */
+ def setRuns(runs: Int): KMeans = {
+ if (runs <= 0) {
+ throw new IllegalArgumentException("Number of runs must be positive")
+ }
+ this.runs = runs
+ this
+ }
+
+ /**
+ * Set the number of steps for the k-means|| initialization mode. This is an advanced
+   * setting -- the default of 5 is almost always enough.
+ */
+ def setInitializationSteps(initializationSteps: Int): KMeans = {
+ if (initializationSteps <= 0) {
+ throw new IllegalArgumentException("Number of initialization steps must be positive")
+ }
+ this.initializationSteps = initializationSteps
+ this
+ }
+
+ /**
+   * Set the distance threshold within which we consider centers to have converged.
+   * If all centers move less than this Euclidean distance, we stop iterating that run.
+ */
+ def setEpsilon(epsilon: Double): KMeans = {
+ this.epsilon = epsilon
+ this
+ }
+
+ /**
+ * Train a K-means model on the given set of points; `data` should be cached for high
+ * performance, because this is an iterative algorithm.
+ */
+ def train(data: RDD[Array[Double]]): KMeansModel = {
+ // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable
+
+ val sc = data.sparkContext
+
+ val centers = if (initializationMode == KMeans.RANDOM) {
+ initRandom(data)
+ } else {
+ initKMeansParallel(data)
+ }
+
+ val active = Array.fill(runs)(true)
+ val costs = Array.fill(runs)(0.0)
+
+ var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
+ var iteration = 0
+
+ // Execute iterations of Lloyd's algorithm until all runs have converged
+ while (iteration < maxIterations && !activeRuns.isEmpty) {
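+      // A WeightedPoint is the running (vector sum, point count) for one cluster
+      // in one run; merging two adds the sums in place and adds the counts.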
+ type WeightedPoint = (DoubleMatrix, Long)
+ def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
+ (p1._1.addi(p2._1), p1._2 + p2._2)
+ }
+
+ val activeCenters = activeRuns.map(r => centers(r)).toArray
+ val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
+
+ // Find the sum and count of points mapping to each center
+ val totalContribs = data.mapPartitions { points =>
+ val runs = activeCenters.length
+ val k = activeCenters(0).length
+ val dims = activeCenters(0)(0).length
+
+ val sums = Array.fill(runs, k)(new DoubleMatrix(dims))
+ val counts = Array.fill(runs, k)(0L)
+
+ for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) {
+ val (bestCenter, cost) = KMeans.findClosest(centers, point)
+ costAccums(runIndex) += cost
+ sums(runIndex)(bestCenter).addi(new DoubleMatrix(point))
+ counts(runIndex)(bestCenter) += 1
+ }
+
+ val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
+ ((i, j), (sums(i)(j), counts(i)(j)))
+ }
+ contribs.iterator
+ }.reduceByKey(mergeContribs).collectAsMap()
+
+ // Update the cluster centers and costs for each active run
+ for ((run, i) <- activeRuns.zipWithIndex) {
+ var changed = false
+ for (j <- 0 until k) {
+ val (sum, count) = totalContribs((i, j))
+ if (count != 0) {
+ val newCenter = sum.divi(count).data
+ if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
+ changed = true
+ }
+ centers(run)(j) = newCenter
+ }
+ }
+ if (!changed) {
+ active(run) = false
+ logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
+ }
+ costs(run) = costAccums(i).value
+ }
+
+ activeRuns = activeRuns.filter(active(_))
+ iteration += 1
+ }
+
+ val bestRun = costs.zipWithIndex.min._2
+ new KMeansModel(centers(bestRun))
+ }
+
+ /**
+ * Initialize `runs` sets of cluster centers at random.
+ */
+ private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
+ // Sample all the cluster centers in one pass to avoid repeated scans
+ val sample = data.takeSample(true, runs * k, new Random().nextInt())
+ Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k))
+ }
+
+ /**
+ * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al.
+ * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries
+ * to find dissimilar cluster centers by starting with a random center and then doing
+ * passes where more centers are chosen with probability proportional to their squared distance
+ * to the current cluster set. It results in a provable approximation to an optimal clustering.
+ *
+ * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
+ */
+ private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
+ // Initialize each run's center to a random point
+ val seed = new Random().nextInt()
+ val sample = data.takeSample(true, runs, seed)
+ val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
+
+ // On each step, sample 2 * k points on average for each run with probability proportional
+ // to their squared distance from that run's current centers
+ for (step <- 0 until initializationSteps) {
+ val centerArrays = centers.map(_.toArray)
+ val sumCosts = data.flatMap { point =>
+ for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
+ }.reduceByKey(_ + _).collectAsMap()
+ val chosen = data.mapPartitionsWithIndex { (index, points) =>
+ val rand = new Random(seed ^ (step << 16) ^ index)
+ for {
+ p <- points
+ r <- 0 until runs
+ if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r)
+ } yield (r, p)
+ }.collect()
+ for ((r, p) <- chosen) {
+ centers(r) += p
+ }
+ }
+
+ // Finally, we might have a set of more than k candidate centers for each run; weigh each
+ // candidate by the number of points in the dataset mapping to it and run a local k-means++
+ // on the weighted centers to pick just k of them
+ val centerArrays = centers.map(_.toArray)
+ val weightMap = data.flatMap { p =>
+ for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0)
+ }.reduceByKey(_ + _).collectAsMap()
+ val finalCenters = (0 until runs).map { r =>
+ val myCenters = centers(r).toArray
+ val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
+ LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
+ }
+
+ finalCenters.toArray
+ }
+}
+
+
+/**
+ * Top-level methods for calling K-means clustering.
+ */
+object KMeans {
+ // Initialization mode names
+ val RANDOM = "random"
+ val K_MEANS_PARALLEL = "k-means||"
+
+ def train(
+ data: RDD[Array[Double]],
+ k: Int,
+ maxIterations: Int,
+ runs: Int,
+ initializationMode: String)
+ : KMeansModel =
+ {
+ new KMeans().setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ .train(data)
+ }
+
+ def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
+ train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
+ }
+
+ def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = {
+ train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
+ }
+
+ /**
+   * Return the index of the closest center in `centers` to `point`, as well as its squared distance.
+ */
+ private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double])
+ : (Int, Double) =
+ {
+ var bestDistance = Double.PositiveInfinity
+ var bestIndex = 0
+ for (i <- 0 until centers.length) {
+ val distance = MLUtils.squaredDistance(point, centers(i))
+ if (distance < bestDistance) {
+ bestDistance = distance
+ bestIndex = i
+ }
+ }
+ (bestIndex, bestDistance)
+ }
+
+ /**
+ * Return the K-means cost of a given point against the given cluster centers.
+ */
+ private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = {
+ var bestDistance = Double.PositiveInfinity
+ for (i <- 0 until centers.length) {
+ val distance = MLUtils.squaredDistance(point, centers(i))
+ if (distance < bestDistance) {
+ bestDistance = distance
+ }
+ }
+ bestDistance
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 4) {
+ println("Usage: KMeans <master> <input_file> <k> <max_iterations>")
+ System.exit(1)
+ }
+ val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
+ val sc = new SparkContext(master, "KMeans")
+ val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble))
+ val model = KMeans.train(data, k, iters)
+ val cost = model.computeCost(data)
+ println("Cluster centers:")
+ for (c <- model.clusterCenters) {
+ println(" " + c.mkString(" "))
+ }
+ println("Cost: " + cost)
+ System.exit(0)
+ }
+}
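
A minimal training sketch against the API above (not part of this commit): it assumes a live SparkContext `sc` and a hypothetical file of space-separated feature vectors, and caches the input because KMeans makes several passes over it.

    val points = sc.textFile("data/kmeans_points.txt")
      .map(line => line.split(' ').map(_.toDouble))
      .cache()
    val model = KMeans.train(points, 3, 20, 5, KMeans.K_MEANS_PARALLEL)
    println("Total cost: " + model.computeCost(points))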
diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala
new file mode 100644
index 0000000000..4fd0646160
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala
@@ -0,0 +1,27 @@
+package spark.mllib.clustering
+
+import spark.RDD
+import spark.SparkContext._
+import spark.mllib.util.MLUtils
+
+
+/**
+ * A clustering model for K-means. Each point belongs to the cluster with the closest center.
+ */
+class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable {
+ /** Total number of clusters. */
+ def k: Int = clusterCenters.length
+
+ /** Return the cluster index that a given point belongs to. */
+ def predict(point: Array[Double]): Int = {
+ KMeans.findClosest(clusterCenters, point)._1
+ }
+
+ /**
+ * Return the K-means cost (sum of squared distances of points to their nearest center) for this
+ * model on the given data.
+ */
+ def computeCost(data: RDD[Array[Double]]): Double = {
+ data.map(p => KMeans.pointCost(clusterCenters, p)).sum
+ }
+}
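
A hedged follow-on sketch of using a trained model; the point values are invented and `points` is the RDD from the sketch above.

    val model = KMeans.train(points, 3, 20)         // k = 3, up to 20 iterations
    val cluster = model.predict(Array(0.25, -1.5))  // index of the nearest center
    val cost = model.computeCost(points)            // sum of squared distances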
diff --git a/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala
new file mode 100644
index 0000000000..e12b3be251
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala
@@ -0,0 +1,88 @@
+package spark.mllib.clustering
+
+import scala.util.Random
+
+import org.jblas.{DoubleMatrix, SimpleBlas}
+
+/**
+ * A utility object to run K-means locally. This is private to the ML package because it's used
+ * in the initialization of KMeans but not meant to be publicly exposed.
+ */
+private[mllib] object LocalKMeans {
+ /**
+ * Run K-means++ on the weighted point set `points`. This first does the K-means++
+   * initialization procedure and then rounds of Lloyd's algorithm.
+ */
+ def kMeansPlusPlus(
+ seed: Int,
+ points: Array[Array[Double]],
+ weights: Array[Double],
+ k: Int,
+ maxIterations: Int)
+ : Array[Array[Double]] =
+ {
+ val rand = new Random(seed)
+ val dimensions = points(0).length
+ val centers = new Array[Array[Double]](k)
+
+ // Initialize centers by sampling using the k-means++ procedure
+ centers(0) = pickWeighted(rand, points, weights)
+ for (i <- 1 until k) {
+ // Pick the next center with a probability proportional to cost under current centers
+ val curCenters = centers.slice(0, i)
+ val sum = points.zip(weights).map { case (p, w) =>
+ w * KMeans.pointCost(curCenters, p)
+ }.sum
+ val r = rand.nextDouble() * sum
+ var cumulativeScore = 0.0
+ var j = 0
+ while (j < points.length && cumulativeScore < r) {
+ cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
+ j += 1
+ }
+ centers(i) = points(j-1)
+ }
+
+ // Run up to maxIterations iterations of Lloyd's algorithm
+ val oldClosest = Array.fill(points.length)(-1)
+ var iteration = 0
+ var moved = true
+ while (moved && iteration < maxIterations) {
+ moved = false
+ val sums = Array.fill(k)(new DoubleMatrix(dimensions))
+ val counts = Array.fill(k)(0.0)
+ for ((p, i) <- points.zipWithIndex) {
+ val index = KMeans.findClosest(centers, p)._1
+ SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index))
+ counts(index) += weights(i)
+ if (index != oldClosest(i)) {
+ moved = true
+ oldClosest(i) = index
+ }
+ }
+ // Update centers
+ for (i <- 0 until k) {
+ if (counts(i) == 0.0) {
+ // Assign center to a random point
+ centers(i) = points(rand.nextInt(points.length))
+ } else {
+ centers(i) = sums(i).divi(counts(i)).data
+ }
+ }
+ iteration += 1
+ }
+
+ centers
+ }
+
+ private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = {
+ val r = rand.nextDouble() * weights.sum
+ var i = 0
+ var curWeight = 0.0
+ while (i < data.length && curWeight < r) {
+ curWeight += weights(i)
+ i += 1
+ }
+ data(i - 1)
+ }
+}
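+
+// Usage sketch within the mllib package (values hypothetical): given points and
+// per-point weights of equal length, pick k centers and refine them locally:
+//   val centers = LocalKMeans.kMeansPlusPlus(42, points, weights, k = 3, maxIterations = 30)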
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
new file mode 100644
index 0000000000..90b0999a5e
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -0,0 +1,33 @@
+package spark.mllib.optimization
+
+import org.jblas.DoubleMatrix
+
+abstract class Gradient extends Serializable {
+ /**
+ * Compute the gradient for a given row of data.
+ *
+ * @param data - One row of data. Row matrix of size 1xn where n is the number of features.
+ * @param label - Label for this data item.
+ * @param weights - Column matrix containing weights for every feature.
+ *
+ * @return A tuple of (gradient, loss) computed for this data row.
+ */
+ def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
+ (DoubleMatrix, Double)
+}
+
+class LogisticGradient extends Gradient {
+ override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
+ (DoubleMatrix, Double) = {
+ val margin: Double = -1.0 * data.dot(weights)
+ val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
+
+ val gradient = data.mul(gradientMultiplier)
+ val loss =
+ if (margin > 0) {
+ math.log(1 + math.exp(0 - margin))
+ } else {
+ math.log(1 + math.exp(margin)) - margin
+ }
+
+ (gradient, loss)
+ }
+}
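+
+// Hand-worked check of the loss above (not from the original source): for
+// data = [1.0], weights = [0.0], label = 1.0, the margin is 0, the predicted
+// probability is 0.5, and the loss is log(1 + exp(0)) = log 2 ≈ 0.693, which
+// matches -log P(label = 1) for a logistic model.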
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
new file mode 100644
index 0000000000..eff853f379
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -0,0 +1,62 @@
+package spark.mllib.optimization
+
+import spark.{Logging, RDD, SparkContext}
+import spark.SparkContext._
+
+import org.jblas.DoubleMatrix
+
+import scala.collection.mutable.ArrayBuffer
+
+
+object GradientDescent {
+
+ /**
+ * Run gradient descent in parallel using mini batches.
+ * Based on Matlab code written by John Duchi.
+ *
+ * @param data - Input data for SGD. RDD of form (label, [feature values]).
+ * @param gradient - Gradient object that will be used to compute the gradient.
+ * @param updater - Updater object that will be used to update the model.
+ * @param stepSize - step size to be used for each update.
+ * @param numIters - number of iterations of SGD to run.
+ * @param miniBatchFraction - fraction of the input data set that should be used for
+ * one iteration of SGD. Default value 1.0.
+ *
+ * @return weights - Column matrix containing weights for every feature.
+ * @return lossHistory - Array containing the loss computed for every iteration.
+ */
+ def runMiniBatchSGD(
+ data: RDD[(Double, Array[Double])],
+ gradient: Gradient,
+ updater: Updater,
+ stepSize: Double,
+ numIters: Int,
+ miniBatchFraction: Double = 1.0): (DoubleMatrix, Array[Double]) = {
+
+ val lossHistory = new ArrayBuffer[Double](numIters)
+
+ val nfeatures: Int = data.take(1)(0)._2.length
+ val nexamples: Long = data.count()
+ val miniBatchSize = nexamples * miniBatchFraction
+
+ // Initialize weights as a column matrix
+ var weights = DoubleMatrix.ones(nfeatures)
+ var reg_val = 0.0
+
+ for (i <- 1 to numIters) {
+ val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map {
+ case (y, features) =>
+ val featuresRow = new DoubleMatrix(features.length, 1, features:_*)
+ val (grad, loss) = gradient.compute(featuresRow, y, weights)
+ (grad, loss)
+ }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
+
+ lossHistory.append(lossSum / miniBatchSize + reg_val)
+ val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i)
+ weights = update._1
+ reg_val = update._2
+ }
+
+ (weights, lossHistory.toArray)
+ }
+}
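+
+// Usage sketch (assumes an RDD[(Double, Array[Double])] named `data`):
+//   val (weights, lossHistory) = GradientDescent.runMiniBatchSGD(
+//     data, new LogisticGradient(), new SimpleUpdater(), 1.0, 100)
+// miniBatchFraction is left at its default of 1.0, i.e. full-batch gradient descent.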
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
new file mode 100644
index 0000000000..ea80bfcbfd
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -0,0 +1,27 @@
+package spark.mllib.optimization
+
+import org.jblas.DoubleMatrix
+
+abstract class Updater extends Serializable {
+ /**
+ * Compute an updated value for weights given the gradient, stepSize and iteration number.
+ *
+ * @param weightsOld - Column matrix of size nx1 where n is the number of features.
+ * @param gradient - Column matrix of size nx1 where n is the number of features.
+ * @param stepSize - step size across iterations
+ * @param iter - Iteration number
+ *
+ * @return weightsNew - Column matrix containing updated weights
+ * @return reg_val - regularization value
+ */
+ def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
+ (DoubleMatrix, Double)
+}
+
+class SimpleUpdater extends Updater {
+ override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
+ stepSize: Double, iter: Int): (DoubleMatrix, Double) = {
+ val normGradient = gradient.mul(stepSize / math.sqrt(iter))
+ (weightsOld.sub(normGradient), 0.0)
+ }
+}
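+
+// With SimpleUpdater the effective step decays as stepSize / sqrt(iter): at
+// iter = 4 a gradient g moves the weights by (stepSize / 2) * g. The returned
+// regularization value is always 0.0 because no penalty term is applied.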
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
new file mode 100644
index 0000000000..6c9fb2359c
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -0,0 +1,389 @@
+package spark.mllib.recommendation
+
+import scala.collection.mutable.{ArrayBuffer, BitSet}
+import scala.util.Random
+
+import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
+import spark.storage.StorageLevel
+import spark.SparkContext._
+
+import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
+
+
+/**
+ * Out-link information for a user or product block. This includes the original user/product IDs
+ * of the elements within this block, and the list of destination blocks that each user or
+ * product will need to send its feature vector to.
+ */
+private[recommendation] case class OutLinkBlock(
+ elementIds: Array[Int], shouldSend: Array[BitSet])
+
+
+/**
+ * In-link information for a user (or product) block. This includes the original user/product IDs
+ * of the elements within this block, as well as an array of indices and ratings that specify
+ * which user in the block will be rated by which products from each product block (or vice-versa).
+ * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays,
+ * indices and ratings, for the i'th product that will be sent to us by product block b (call this
+ * P). These arrays represent the users that product P had ratings for (by their index in this
+ * block), as well as the corresponding rating for each one. We can thus use this information when
+ * we get product block b's message to update the corresponding users.
+ */
+private[recommendation] case class InLinkBlock(
+ elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]])
+
+
+/**
+ * Alternating Least Squares matrix factorization.
+ *
+ * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
+ * of factors (referred to as "users" and "products") into blocks and reduces communication by only
+ * sending one copy of each user vector to each product block on each iteration, and only for the
+ * product blocks that need that user's feature vector. This is achieved by precomputing some
+ * information about the ratings matrix to determine the "out-links" of each user (which blocks of
+ * products it will contribute to) and "in-link" information for each product (which of the feature
+ * vectors it receives from each user block it will depend on). This allows us to send only an
+ * array of feature vectors between each user block and product block, and have the product block
+ * find the users' ratings and update the products based on these messages.
+ */
+class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double)
+ extends Serializable
+{
+ def this() = this(-1, 10, 10, 0.01)
+
+ /**
+ * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
+ * number of blocks. Default: -1.
+ */
+ def setBlocks(numBlocks: Int): ALS = {
+ this.numBlocks = numBlocks
+ this
+ }
+
+ /** Set the rank of the feature matrices computed (number of features). Default: 10. */
+ def setRank(rank: Int): ALS = {
+ this.rank = rank
+ this
+ }
+
+ /** Set the number of iterations to run. Default: 10. */
+ def setIterations(iterations: Int): ALS = {
+ this.iterations = iterations
+ this
+ }
+
+ /** Set the regularization parameter, lambda. Default: 0.01. */
+ def setLambda(lambda: Double): ALS = {
+ this.lambda = lambda
+ this
+ }
+
+ /**
+ * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
+ * Returns a MatrixFactorizationModel with feature vectors for each user and product.
+ */
+ def train(ratings: RDD[(Int, Int, Double)]): MatrixFactorizationModel = {
+ val numBlocks = if (this.numBlocks == -1) {
+ math.max(ratings.context.defaultParallelism, ratings.partitions.size)
+ } else {
+ this.numBlocks
+ }
+
+ val partitioner = new HashPartitioner(numBlocks)
+
+ val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, (u, p, r)) }
+ val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, (p, u, r)) }
+
+ val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
+ val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
+
+ // Initialize user and product factors randomly
+ val seed = new Random().nextInt()
+ var users = userOutLinks.mapValues(_.elementIds.map(u => randomFactor(rank, seed ^ u)))
+ var products = productOutLinks.mapValues(_.elementIds.map(p => randomFactor(rank, seed ^ ~p)))
+
+ for (iter <- 0 until iterations) {
+ // perform ALS update
+ products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda)
+ users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda)
+ }
+
+ // Flatten and cache the two final RDDs to un-block them
+ val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
+ for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
+ }
+ val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
+ for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
+ }
+
+ usersOut.persist()
+ productsOut.persist()
+
+ new MatrixFactorizationModel(rank, usersOut, productsOut)
+ }
+
+ /**
+ * Make the out-links table for a block of the users (or products) dataset given the list of
+ * (user, product, rating) values for the users in that block (or the opposite for products).
+ */
+ private def makeOutLinkBlock(numBlocks: Int, ratings: Array[(Int, Int, Double)]): OutLinkBlock = {
+ val userIds = ratings.map(_._1).distinct.sorted
+ val numUsers = userIds.length
+ val userIdToPos = userIds.zipWithIndex.toMap
+ val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
+ for ((u, p, r) <- ratings) {
+ shouldSend(userIdToPos(u))(p % numBlocks) = true
+ }
+ OutLinkBlock(userIds, shouldSend)
+ }
+
+ /**
+ * Make the in-links table for a block of the users (or products) dataset given a list of
+ * (user, product, rating) values for the users in that block (or the opposite for products).
+ */
+ private def makeInLinkBlock(numBlocks: Int, ratings: Array[(Int, Int, Double)]): InLinkBlock = {
+ val userIds = ratings.map(_._1).distinct.sorted
+ val numUsers = userIds.length
+ val userIdToPos = userIds.zipWithIndex.toMap
+ val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
+ for (productBlock <- 0 until numBlocks) {
+ val ratingsInBlock = ratings.filter(t => t._2 % numBlocks == productBlock)
+ val ratingsByProduct = ratingsInBlock.groupBy(_._2) // (p, Seq[(u, p, r)])
+ .toArray
+ .sortBy(_._1)
+ .map{case (p, rs) => (rs.map(t => userIdToPos(t._1)), rs.map(_._3))}
+ ratingsForBlock(productBlock) = ratingsByProduct
+ }
+ InLinkBlock(userIds, ratingsForBlock)
+ }
+
+ /**
+ * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for
+ * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
+ * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
+ */
+ private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, (Int, Int, Double))])
+ : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
+ {
+ val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
+ val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
+ val ratings = elements.map(_._2).toArray
+ val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
+ val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
+ Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
+ }, true)
+ links.persist(StorageLevel.MEMORY_AND_DISK)
+ (links.mapValues(_._1), links.mapValues(_._2))
+ }
+
+ /**
+ * Make a random factor vector with the given seed.
+ * TODO: Initialize things using mapPartitionsWithIndex to make it faster?
+ */
+ private def randomFactor(rank: Int, seed: Int): Array[Double] = {
+ val rand = new Random(seed)
+ Array.fill(rank)(rand.nextDouble)
+ }
+
+ /**
+ * Compute the user feature vectors given the current products (or vice-versa). This first joins
+ * the products with their out-links to generate a set of messages to each destination block
+ * (specifically, the features for the products that user block cares about), then groups these
+ * by destination and joins them with the in-link info to figure out how to update each user.
+ * It returns an RDD of new feature vectors for each user block.
+ */
+ private def updateFeatures(
+ products: RDD[(Int, Array[Array[Double]])],
+ productOutLinks: RDD[(Int, OutLinkBlock)],
+ userInLinks: RDD[(Int, InLinkBlock)],
+ partitioner: Partitioner,
+ rank: Int,
+ lambda: Double)
+ : RDD[(Int, Array[Array[Double]])] =
+ {
+ val numBlocks = products.partitions.size
+ productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) =>
+ val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]])
+ for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) {
+ if (outLinkBlock.shouldSend(p)(userBlock)) {
+ toSend(userBlock) += factors(p)
+ }
+ }
+ toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
+ }.groupByKey(partitioner)
+ .join(userInLinks)
+ .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) }
+ }
+
+ /**
+ * Compute the new feature vectors for a block of the users matrix given the list of factors
+ * it received from each product and its InLinkBlock.
+ */
+ def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
+ rank: Int, lambda: Double)
+ : Array[Array[Double]] =
+ {
+ // Sort the incoming block factor messages by block ID and make them an array
+ val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Array[Double]]]
+ val numBlocks = blockFactors.length
+ val numUsers = inLinkBlock.elementIds.length
+
+ // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since
+ // the matrices are symmetric
+ val triangleSize = rank * (rank + 1) / 2
+ val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize))
+ val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank))
+
+ // Some temp variables to avoid memory allocation
+ val tempXtX = DoubleMatrix.zeros(triangleSize)
+ val fullXtX = DoubleMatrix.zeros(rank, rank)
+
+ // Compute the XtX and Xy values for each user by adding products it rated in each product block
+ for (productBlock <- 0 until numBlocks) {
+ for (p <- 0 until blockFactors(productBlock).length) {
+ val x = new DoubleMatrix(blockFactors(productBlock)(p))
+ fillXtX(x, tempXtX)
+ val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
+ for (i <- 0 until us.length) {
+ userXtX(us(i)).addi(tempXtX)
+ SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ }
+ }
+ }
+
+ // Solve the least-squares problem for each user and return the new feature vectors
+ userXtX.zipWithIndex.map{ case (triangularXtX, index) =>
+ // Compute the full XtX matrix from the lower-triangular part we got above
+ fillFullMatrix(triangularXtX, fullXtX)
+ // Add regularization
+ (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
+ // Solve the resulting matrix, which is symmetric and positive-definite
+ Solve.solvePositive(fullXtX, userXy(index)).data
+ }
+ }
+
+ /**
+ * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing
+ * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values
+ * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order.
+ */
+ private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) {
+ var i = 0
+ var pos = 0
+ while (i < x.length) {
+ var j = 0
+ while (j <= i) {
+ xtxDest.data(pos) = x.data(i) * x.data(j)
+ pos += 1
+ j += 1
+ }
+ i += 1
+ }
+ }
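+ // Example of the packing above: with rank = 2 and x = [a, b], xtxDest holds
+ // [a*a, b*a, b*b], i.e. the (0,0), (1,0), (1,1) entries of x * x^T in order.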
+
+ /**
+ * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
+ * matrix that it represents, storing it into destMatrix.
+ */
+ private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) {
+ val rank = destMatrix.rows
+ var i = 0
+ var pos = 0
+ while (i < rank) {
+ var j = 0
+ while (j <= i) {
+ destMatrix.data(i*rank + j) = triangularMatrix.data(pos)
+ destMatrix.data(j*rank + i) = triangularMatrix.data(pos)
+ pos += 1
+ j += 1
+ }
+ i += 1
+ }
+ }
+}
+
+
+/**
+ * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
+ */
+object ALS {
+ /**
+ * Train a matrix factorization model given an RDD of ratings given by users to some products,
+ * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
+ * product of two lower-rank matrices of a given rank (number of features). To solve for these
+ * features, we run a given number of iterations of ALS. This is done using a level of
+ * parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ */
+ def train(
+ ratings: RDD[(Int, Int, Double)],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int)
+ : MatrixFactorizationModel =
+ {
+ new ALS(blocks, rank, iterations, lambda).train(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of ratings given by users to some products,
+ * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
+ * product of two lower-rank matrices of a given rank (number of features). To solve for these
+ * features, we run a given number of iterations of ALS. The level of parallelism is determined
+ * automatically based on the number of partitions in `ratings`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ */
+ def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int, lambda: Double)
+ : MatrixFactorizationModel =
+ {
+ train(ratings, rank, iterations, lambda, -1)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of ratings given by users to some products,
+ * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
+ * product of two lower-rank matrices of a given rank (number of features). To solve for these
+ * features, we run a given number of iterations of ALS. The level of parallelism is determined
+ * automatically based on the number of partitions in `ratings`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ */
+ def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int)
+ : MatrixFactorizationModel =
+ {
+ train(ratings, rank, iterations, 0.01, -1)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir>")
+ System.exit(1)
+ }
+ val (master, ratingsFile, rank, iters, outputDir) =
+ (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
+ val sc = new SparkContext(master, "ALS")
+ val ratings = sc.textFile(ratingsFile).map { line =>
+ val fields = line.split(',')
+ (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
+ }
+ val model = ALS.train(ratings, rank, iters)
+ model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
+ .saveAsTextFile(outputDir + "/userFeatures")
+ model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
+ .saveAsTextFile(outputDir + "/productFeatures")
+ println("Final user/product features written to " + outputDir)
+ System.exit(0)
+ }
+}
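+
+// Usage sketch (hypothetical RDD of (user, product, rating) triples):
+//   val model = ALS.train(ratings, 10, 20, 0.01)  // rank, iterations, lambda
+// This variant auto-configures the level of parallelism; pass a fifth `blocks`
+// argument to control the number of blocks explicitly.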
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala
new file mode 100644
index 0000000000..fb812a6dbe
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -0,0 +1,23 @@
+package spark.mllib.recommendation
+
+import spark.RDD
+import spark.SparkContext._
+
+import org.jblas._
+
+class MatrixFactorizationModel(
+ val rank: Int,
+ val userFeatures: RDD[(Int, Array[Double])],
+ val productFeatures: RDD[(Int, Array[Double])])
+ extends Serializable
+{
+ /** Predict the rating of one user for one product. */
+ def predict(user: Int, product: Int): Double = {
+ val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
+ val productVector = new DoubleMatrix(productFeatures.lookup(product).head)
+ userVector.dot(productVector)
+ }
+
+ // TODO: Figure out what good bulk prediction methods would look like.
+ // Probably want a way to get the top users for a product or vice-versa.
+}
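+
+// Usage sketch (hypothetical IDs), given a model returned by ALS.train:
+//   val rating = model.predict(42, 7)
+// Each call performs two RDD lookups, so this form suits ad-hoc queries rather
+// than bulk scoring (see the TODO above).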
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
new file mode 100644
index 0000000000..448ab9dce9
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
@@ -0,0 +1,158 @@
+package spark.mllib.regression
+
+import spark.{Logging, RDD, SparkContext}
+import spark.mllib.optimization._
+import spark.mllib.util.MLUtils
+
+import org.jblas.DoubleMatrix
+
+/**
+ * Logistic Regression using Stochastic Gradient Descent.
+ * Based on Matlab code written by John Duchi.
+ */
+class LogisticRegressionModel(
+ val weights: DoubleMatrix,
+ val intercept: Double,
+ val losses: Array[Double]) extends RegressionModel {
+
+ override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+ testData.map { x =>
+ val margin = new DoubleMatrix(1, x.length, x:_*).mmul(this.weights).get(0) + this.intercept
+ 1.0 / (1.0 + math.exp(-margin))
+ }
+ }
+
+ override def predict(testData: Array[Double]): Double = {
+ val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
+ val margin = dataMat.mmul(this.weights).get(0) + this.intercept
+ 1.0 / (1.0 + math.exp(-margin))
+ }
+}
+
+class LogisticRegression private (var stepSize: Double, var miniBatchFraction: Double,
+ var numIters: Int)
+ extends Logging {
+
+ /**
+ * Construct a LogisticRegression object with default parameters
+ */
+ def this() = this(1.0, 1.0, 100)
+
+ /**
+ * Set the step size per-iteration of SGD. Default 1.0.
+ */
+ def setStepSize(step: Double) = {
+ this.stepSize = step
+ this
+ }
+
+ /**
+ * Set fraction of data to be used for each SGD iteration. Default 1.0.
+ */
+ def setMiniBatchFraction(fraction: Double) = {
+ this.miniBatchFraction = fraction
+ this
+ }
+
+ /**
+ * Set the number of iterations for SGD. Default 100.
+ */
+ def setNumIterations(iters: Int) = {
+ this.numIters = iters
+ this
+ }
+
+ def train(input: RDD[(Double, Array[Double])]): LogisticRegressionModel = {
+ // Prepend a constant 1.0 feature to each example to model the intercept.
+ val data = input.map { case (y, features) =>
+ (y, Array(1.0, features:_*))
+ }
+
+ val (weights, losses) = GradientDescent.runMiniBatchSGD(
+ data, new LogisticGradient(), new SimpleUpdater(), stepSize, numIters, miniBatchFraction)
+
+ val weightsScaled = weights.getRange(1, weights.length)
+ val intercept = weights.get(0)
+
+ val model = new LogisticRegressionModel(weightsScaled, intercept, losses)
+
+ logInfo("Final model weights " + model.weights)
+ logInfo("Final model intercept " + model.intercept)
+ logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(", "))
+ model
+ }
+}
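+
+// Usage sketch mirroring the setter pattern above (values hypothetical):
+//   val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(20)
+//   val model = lr.train(trainingData) // trainingData: RDD[(Double, Array[Double])]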
+
+/**
+ * Top-level methods for calling Logistic Regression.
+ */
+object LogisticRegression {
+
+ /**
+ * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. Each iteration uses
+ * `miniBatchFraction` fraction of the data to calculate the gradient.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param miniBatchFraction Fraction of data to be used per iteration.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double,
+ miniBatchFraction: Double)
+ : LogisticRegressionModel =
+ {
+ new LogisticRegression(stepSize, miniBatchFraction, numIterations).train(input)
+ }
+
+ /**
+ * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param stepSize Step size to be used for each iteration of Gradient Descent.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return a LogisticRegressionModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double)
+ : LogisticRegressionModel =
+ {
+ train(input, numIterations, stepSize, 1.0)
+ }
+
+ /**
+ * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using a step size of 1.0. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return a LogisticRegressionModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int)
+ : LogisticRegressionModel =
+ {
+ train(input, numIterations, 1.0, 1.0)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 4) {
+ println("Usage: LogisticRegression <master> <input_dir> <step_size> <niters>")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "LogisticRegression")
+ val data = MLUtils.loadData(sc, args(1))
+ val model = LogisticRegression.train(data, args(3).toInt, args(2).toDouble)
+
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala
new file mode 100644
index 0000000000..9f6abab70b
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala
@@ -0,0 +1,41 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.jblas.DoubleMatrix
+
+import spark.{RDD, SparkContext}
+import spark.mllib.util.MLUtils
+
+object LogisticRegressionGenerator {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: LogisticRegressionGenerator " +
+ "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
+ System.exit(1)
+ }
+
+ val sparkMaster: String = args(0)
+ val outputPath: String = args(1)
+ val nexamples: Int = args(2).toInt
+ val nfeatures: Int = args(3).toInt
+ val parts: Int = args(4).toInt
+ val eps = 3
+
+ val sc = new SparkContext(sparkMaster, "LogisticRegressionGenerator")
+
+ val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
+ val rnd = new Random(42 + idx)
+
+ val y = if (idx % 2 == 0) 0 else 1
+ val x = Array.fill[Double](nfeatures) {
+ rnd.nextGaussian() + (y * eps)
+ }
+ (y, x)
+ }
+
+ MLUtils.saveData(data, outputPath)
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/spark/mllib/regression/Regression.scala b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
new file mode 100644
index 0000000000..f79974c191
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
@@ -0,0 +1,21 @@
+package spark.mllib.regression
+
+import spark.RDD
+
+trait RegressionModel {
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param testData RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(testData: RDD[Array[Double]]): RDD[Double]
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param testData array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ def predict(testData: Array[Double]): Double
+}
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
new file mode 100644
index 0000000000..a6ececbeb6
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -0,0 +1,183 @@
+package spark.mllib.regression
+
+import spark.{Logging, RDD, SparkContext}
+import spark.SparkContext._
+import spark.mllib.util.MLUtils
+
+import org.jblas.DoubleMatrix
+import org.jblas.Solve
+
+/**
+ * Ridge Regression from Joseph Gonzalez's implementation in MLBase
+ */
+class RidgeRegressionModel(
+ val weights: DoubleMatrix,
+ val intercept: Double,
+ val lambdaOpt: Double,
+ val lambdas: List[(Double, Double, DoubleMatrix)])
+ extends RegressionModel {
+
+ override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+ testData.map { x =>
+ (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept
+ }
+ }
+
+ override def predict(testData: Array[Double]): Double = {
+ (new DoubleMatrix(1, testData.length, testData:_*).mmul(this.weights)).get(0) + this.intercept
+ }
+}
+
+class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
+ extends Logging {
+
+ def this() = this(0.0, 100.0)
+
+ /**
+ * Set the lower bound on the binary search for lambda. Default is 0.0.
+ */
+ def setLowLambda(low: Double) = {
+ this.lambdaLow = low
+ this
+ }
+
+ /**
+ * Set the upper bound on the binary search for lambda. Default is 100.0.
+ */
+ def setHighLambda(hi: Double) = {
+ this.lambdaHigh = hi
+ this
+ }
+
+ def train(input: RDD[(Double, Array[Double])]): RidgeRegressionModel = {
+ val nfeatures: Int = input.take(1)(0)._2.length
+ val nexamples: Long = input.count()
+
+ val (yMean, xColMean, xColSd) = MLUtils.computeStats(input, nfeatures, nexamples)
+
+ val data = input.map { case(y, features) =>
+ val yNormalized = y - yMean
+ val featuresMat = new DoubleMatrix(nfeatures, 1, features:_*)
+ val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
+ (yNormalized, featuresNormalized.toArray)
+ }
+
+ // Compute XtX - Size of XtX is nfeatures by nfeatures
+ val XtX: DoubleMatrix = data.map { case (y, features) =>
+ val x = new DoubleMatrix(1, features.length, features:_*)
+ x.transpose().mmul(x)
+ }.reduce(_.addi(_))
+
+ // Compute Xt*y - Size of Xty is nfeatures by 1
+ val Xty: DoubleMatrix = data.map { case (y, features) =>
+ new DoubleMatrix(features.length, 1, features:_*).mul(y)
+ }.reduce(_.addi(_))
+
+ // Define a function to compute the leave one out cross validation error
+ // for a single example
+ def crossValidate(lambda: Double): (Double, Double, DoubleMatrix) = {
+ // Compute the MLE ridge regression parameter value
+
+ // Ridge Regression parameter = inv(XtX + \lambda*I) * Xty
+ val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX)
+ val w = Solve.solveSymmetric(XtXlambda, Xty)
+
+ val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures))
+
+ // Compute the leave-one-out cross validation score
+ val cverror = data.map {
+ case (y, features) =>
+ val x = new DoubleMatrix(features.length, 1, features:_*)
+ val yhat = w.transpose().mmul(x).get(0)
+ val H_ii = x.transpose().mmul(invXtX).mmul(x).get(0)
+ val residual = (y - yhat) / (1.0 - H_ii)
+ residual * residual
+ }.reduce(_ + _) / nexamples
+
+ (lambda, cverror, w)
+ }
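+ // The score above is the standard closed-form leave-one-out estimate for ridge:
+ // each residual is inflated by 1 / (1 - H_ii), where H_ii is the example's
+ // leverage under H = X * inv(XtX + lambda*I) * Xt, so no refitting is needed.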
+
+ // Binary search for the best assignment to lambda.
+ def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
+ val mid = (high - low) / 2 + low
+ val lowValue = crossValidate((mid - low) / 2 + low)
+ val highValue = crossValidate((high - mid) / 2 + mid)
+ val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
+ (low, mid + (high-low)/4)
+ } else {
+ (mid - (high-low)/4, high)
+ }
+ if (newHigh - newLow > 1.0E-7) {
+ // :: is list prepend in Scala.
+ lowValue :: highValue :: binSearch(newLow, newHigh)
+ } else {
+ List(lowValue, highValue)
+ }
+ }
+
+ // Actually compute the best lambda
+ val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1)
+
+ // Find the best parameter set by taking the lowest cverror.
+ val (lambdaOpt, cverror, weights) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b)
+
+ // Return the model which contains the solution
+ val weightsScaled = weights.div(xColSd)
+ val intercept = yMean - (weights.transpose().mmul(xColMean.div(xColSd)).get(0))
+ val model = new RidgeRegressionModel(weightsScaled, intercept, lambdaOpt, lambdas)
+
+ logInfo("RidgeRegression: optimal lambda " + model.lambdaOpt)
+ logInfo("RidgeRegression: optimal weights " + model.weights)
+ logInfo("RidgeRegression: optimal intercept " + model.intercept)
+ logInfo("RidgeRegression: cross-validation error " + cverror)
+
+ model
+ }
+}
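+
+// Usage sketch mirroring the setter pattern above (bounds hypothetical):
+//   val ridge = new RidgeRegression().setLowLambda(0.0).setHighLambda(10.0)
+//   val model = ridge.train(trainingData) // trainingData: RDD[(Double, Array[Double])]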
+/**
+ * Top-level methods for calling Ridge Regression.
+ */
+object RidgeRegression {
+
+ /**
+ * Train a ridge regression model given an RDD of (response, features) pairs.
+ * We use the closed form solution to compute the cross-validation score for
+ * a given lambda. The optimal lambda is computed by performing binary search
+ * between the provided bounds of lambda.
+ *
+ * @param input RDD of (response, array of features) pairs.
+ * @param lambdaLow lower bound used in binary search for lambda
+ * @param lambdaHigh upper bound used in binary search for lambda
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ lambdaLow: Double,
+ lambdaHigh: Double)
+ : RidgeRegressionModel =
+ {
+ new RidgeRegression(lambdaLow, lambdaHigh).train(input)
+ }
+
+ /**
+ * Train a ridge regression model given an RDD of (response, features) pairs.
+ * We use the closed form solution to compute the cross-validation score for
+ * a given lambda. The optimal lambda is computed by performing binary search
+ * between lambda values of 0 and 100.
+ *
+ * @param input RDD of (response, array of features) pairs.
+ */
+ def train(input: RDD[(Double, Array[Double])]): RidgeRegressionModel = {
+ train(input, 0.0, 100.0)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 2) {
+ println("Usage: RidgeRegression <master> <input_dir>")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "RidgeRegression")
+ val data = MLUtils.loadData(sc, args(1))
+ val model = RidgeRegression.train(data, 0, 1000)
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegressionGenerator.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegressionGenerator.scala
new file mode 100644
index 0000000000..c9ac4a8b07
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegressionGenerator.scala
@@ -0,0 +1,55 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.jblas.DoubleMatrix
+
+import spark.{RDD, SparkContext}
+import spark.mllib.util.MLUtils
+
+
+object RidgeRegressionGenerator {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: RidgeRegressionGenerator " +
+ "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
+ System.exit(1)
+ }
+
+ val sparkMaster: String = args(0)
+ val outputPath: String = args(1)
+ val nexamples: Int = args(2).toInt
+ val nfeatures: Int = args(3).toInt
+ val parts: Int = args(4).toInt
+ val eps = 10
+
+ org.jblas.util.Random.seed(42)
+ val sc = new SparkContext(sparkMaster, "RidgeRegressionGenerator")
+
+ // Random values distributed uniformly in [-0.5, 0.5]
+ val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
+ w.put(0, 0, 10)
+ w.put(1, 0, 10)
+
+ val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until parts, parts).flatMap { p =>
+ org.jblas.util.Random.seed(42 + p)
+ val examplesInPartition = nexamples / parts
+
+ val X = DoubleMatrix.rand(examplesInPartition, nfeatures)
+ val y = X.mmul(w)
+
+ val rnd = new Random(42 + p)
+
+ val normalValues = Array.fill[Double](examplesInPartition)(rnd.nextGaussian() * eps)
+ val yObs = new DoubleMatrix(normalValues).addi(y)
+
+ Iterator.tabulate(examplesInPartition) { i =>
+ (yObs.get(i, 0), X.getRow(i).toArray)
+ }
+ }
+
+ MLUtils.saveData(data, outputPath)
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
new file mode 100644
index 0000000000..0a4a037c71
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
@@ -0,0 +1,95 @@
+package spark.mllib.util
+
+import spark.{RDD, SparkContext}
+import spark.SparkContext._
+
+import org.jblas.DoubleMatrix
+
+/**
+ * Helper methods to load and save data
+ * Data format:
+ * <l>, <f1> <f2> ...
+ * where <f1>, <f2> are feature values in Double and <l> is the corresponding label as Double.
+ */
+object MLUtils {
+
+ /**
+ * @param sc SparkContext
+ * @param dir Directory to the input data files.
+ * @return An RDD of tuples. For each tuple, the first element is the label, and the second
+ * element represents the feature values (an array of Double).
+ */
+ def loadData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
+ sc.textFile(dir).map { line =>
+ val parts = line.split(",")
+ val label = parts(0).toDouble
+ val features = parts(1).trim().split(" ").map(_.toDouble)
+ (label, features)
+ }
+ }
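+ // For example, the line "1.0, 2.5 3.0" parses to (1.0, Array(2.5, 3.0)).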
+
+ def saveData(data: RDD[(Double, Array[Double])], dir: String) {
+ val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
+ dataStr.saveAsTextFile(dir)
+ }
+
+ /**
+ * Utility function to compute mean and standard deviation on a given dataset.
+ *
+ * @param data - input data set whose statistics are computed
+ * @param nfeatures - number of features
+ * @param nexamples - number of examples in input dataset
+ *
+ * @return (yMean, xColMean, xColSd) - Tuple consisting of
+ * yMean - mean of the labels
+ * xColMean - Row vector with mean for every column (or feature) of the input data
+ * xColSd - Row vector with the standard deviation for every column (or feature) of the input data.
+ */
+ def computeStats(data: RDD[(Double, Array[Double])], nfeatures: Int, nexamples: Long):
+ (Double, DoubleMatrix, DoubleMatrix) = {
+ val yMean: Double = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
+
+ // NOTE: We shuffle X by column here to compute column sum and sum of squares.
+ val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
+ val nCols = features.length
+ // Traverse over every column and emit (col, value, value^2)
+ Iterator.tabulate(nCols) { i =>
+ (i, (features(i), features(i)*features(i)))
+ }
+ }.reduceByKey { case(x1, x2) =>
+ (x1._1 + x2._1, x1._2 + x2._2)
+ }
+ val xColSumsMap = xColSumSq.collectAsMap()
+
+ val xColMean = DoubleMatrix.zeros(nfeatures, 1)
+ val xColSd = DoubleMatrix.zeros(nfeatures, 1)
+
+ // Compute the mean and population variance using the column sums
+ var col = 0
+ while (col < nfeatures) {
+ xColMean.put(col, xColSumsMap(col)._1 / nexamples)
+ val variance =
+ (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples
+ xColSd.put(col, math.sqrt(variance))
+ col += 1
+ }
+
+ (yMean, xColMean, xColSd)
+ }
+
+ /**
+ * Return the squared Euclidean distance between two vectors.
+ */
+ def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = {
+ if (v1.length != v2.length) {
+ throw new IllegalArgumentException("Vector sizes don't match")
+ }
+ var i = 0
+ var sum = 0.0
+ while (i < v1.length) {
+ sum += (v1(i) - v2(i)) * (v1(i) - v2(i))
+ i += 1
+ }
+ sum
+ }
+}
diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..390c92763c
--- /dev/null
+++ b/mllib/src/test/resources/log4j.properties
@@ -0,0 +1,11 @@
+# Set everything to be logged to the file ml/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=ml/target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+
diff --git a/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala
new file mode 100644
index 0000000000..cb096f39a9
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala
@@ -0,0 +1,153 @@
+package spark.mllib.clustering
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+import org.jblas._
+
+
+class KMeansSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ val EPSILON = 1e-4
+
+ import KMeans.{RANDOM, K_MEANS_PARALLEL}
+
+ def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")")
+
+ def prettyPrint(points: Array[Array[Double]]): String = {
+ points.map(prettyPrint).mkString("(", "; ", ")")
+ }
+
+ // L-infinity (maximum coordinate difference) distance between two points
+ def distance1(v1: Array[Double], v2: Array[Double]): Double = {
+ v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max
+ }
+
+ // Assert that two vectors are equal within tolerance EPSILON
+ def assertEqual(v1: Array[Double], v2: Array[Double]) {
+ def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2)
+ assert(v1.length == v2.length, errorMessage)
+ assert(distance1(v1, v2) <= EPSILON, errorMessage)
+ }
+
+ // Assert that two sets of points are equal, within EPSILON tolerance
+ def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) {
+ def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2)
+ assert(set1.length == set2.length, errorMessage)
+ for (v <- set1) {
+ val closestDistance = set2.map(w => distance1(v, w)).min
+ if (closestDistance > EPSILON) {
+ fail(errorMessage)
+ }
+ }
+ for (v <- set2) {
+ val closestDistance = set1.map(w => distance1(v, w)).min
+ if (closestDistance > EPSILON) {
+ fail(errorMessage)
+ }
+ }
+ }
+
+ test("single cluster") {
+ val data = sc.parallelize(Array(
+ Array(1.0, 2.0, 6.0),
+ Array(1.0, 3.0, 0.0),
+ Array(1.0, 4.0, 6.0)
+ ))
+
+ // No matter how many runs or iterations we use, we should get one cluster,
+ // centered at the mean of the points
+
+ var model = KMeans.train(data, k=1, maxIterations=1)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=2)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(
+ data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+ }
+
+ test("single cluster with big dataset") {
+ val smallData = Array(
+ Array(1.0, 2.0, 6.0),
+ Array(1.0, 3.0, 0.0),
+ Array(1.0, 4.0, 6.0)
+ )
+ val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4)
+
+ // No matter how many runs or iterations we use, we should get one cluster,
+ // centered at the mean of the points
+
+ var model = KMeans.train(data, k=1, maxIterations=1)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=2)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+
+ model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+ }
+
+ test("k-means|| initialization") {
+ val points = Array(
+ Array(1.0, 2.0, 6.0),
+ Array(1.0, 3.0, 0.0),
+ Array(1.0, 4.0, 6.0),
+ Array(1.0, 0.0, 1.0),
+ Array(1.0, 1.0, 1.0)
+ )
+ val rdd = sc.parallelize(points)
+
+ // K-means|| initialization should place all clusters into distinct centers because
+ // it will make at least five passes, and it will give non-zero probability to each
+ // unselected point as long as it hasn't yet selected all of them
+
+ var model = KMeans.train(rdd, k=5, maxIterations=1)
+ assertSetsEqual(model.clusterCenters, points)
+
+ // Iterations of Lloyd's should not change the answer either
+ model = KMeans.train(rdd, k=5, maxIterations=10)
+ assertSetsEqual(model.clusterCenters, points)
+
+ // Neither should more runs
+ model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
+ assertSetsEqual(model.clusterCenters, points)
+ }
+}
diff --git a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala
new file mode 100644
index 0000000000..2ada9ae76b
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala
@@ -0,0 +1,80 @@
+package spark.mllib.recommendation
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+import org.jblas._
+
+
+class ALSSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("rank-1 matrices") {
+ testALS(10, 20, 1, 15, 0.7, 0.3)
+ }
+
+ test("rank-2 matrices") {
+ testALS(20, 30, 2, 15, 0.7, 0.3)
+ }
+
+ /**
+ * Test if we can correctly factorize R = U * P where U and P are of known rank.
+ *
+ * @param users number of users
+ * @param products number of products
+ * @param features number of features (rank of problem)
+ * @param iterations number of iterations to run
+ * @param samplingRate what fraction of the user-product pairs are known
+ * @param matchThreshold max difference allowed to consider a predicted rating correct
+ */
+ def testALS(users: Int, products: Int, features: Int, iterations: Int,
+ samplingRate: Double, matchThreshold: Double)
+ {
+ val rand = new Random(42)
+
+ // Create a random matrix with uniform values from -1 to 1
+ def randomMatrix(m: Int, n: Int) =
+ new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
+
+ val userMatrix = randomMatrix(users, features)
+ val productMatrix = randomMatrix(features, products)
+ val trueRatings = userMatrix.mmul(productMatrix)
+
+ val sampledRatings = {
+ for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
+ yield (u, p, trueRatings.get(u, p))
+ }
+
+ val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
+
+ val predictedU = new DoubleMatrix(users, features)
+ for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
+ predictedU.put(u, i, vec(i))
+ }
+ val predictedP = new DoubleMatrix(products, features)
+ for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) {
+ predictedP.put(p, i, vec(i))
+ }
+ val predictedRatings = predictedU.mmul(predictedP.transpose)
+
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val correct = trueRatings.get(u, p)
+ if (math.abs(prediction - correct) > matchThreshold) {
+ fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ }
+ }
+ }
+}
+
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
new file mode 100644
index 0000000000..04d3400cb4
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
@@ -0,0 +1,57 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+
+class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ // Test if we can correctly learn A, B where Y = logistic(A + B*X)
+ test("logistic regression") {
+ val nPoints = 10000
+ val rnd = new Random(42)
+
+ val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+
+ val A = 2.0
+ val B = -1.5
+
+ // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1)
+ val unifRand = new scala.util.Random(45)
+ val rLogis = (0 until nPoints).map { i =>
+ val u = unifRand.nextDouble()
+ math.log(u) - math.log(1.0-u)
+ }
+
+ // y <- A + B*x + rlogis(100)
+ // y <- as.numeric(y > 0)
+ val y = (0 until nPoints).map { i =>
+ val yVal = A + B * x1(i) + rLogis(i)
+ if (yVal > 0) 1.0 else 0.0
+ }
+
+ val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i)))).toArray
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+ val lr = new LogisticRegression().setStepSize(10.0)
+ .setNumIterations(20)
+
+ val model = lr.train(testRDD)
+
+ val weight0 = model.weights.get(0)
+ assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
+ assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ }
+}
diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
new file mode 100644
index 0000000000..df41dbbdff
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -0,0 +1,47 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+
+class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ // Test if we can correctly learn Y = 3 + X1 + X2 when
+ // X1 and X2 are collinear.
+ test("multi-collinear variables") {
+ val rnd = new Random(43)
+ val x1 = Array.fill[Double](20)(rnd.nextGaussian())
+
+ // Pick a mean close to mean of x1
+ val rnd1 = new Random(42) //new NormalDistribution(0.1, 0.01)
+ val x2 = Array.fill[Double](20)(0.1 + rnd1.nextGaussian() * 0.01)
+
+ val xMat = (0 until 20).map(i => Array(x1(i), x2(i))).toArray
+
+ val y = xMat.map(i => 3 + i(0) + i(1))
+ val testData = (0 until 20).map(i => (y(i), xMat(i))).toArray
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+ val ridgeReg = new RidgeRegression().setLowLambda(0)
+ .setHighLambda(10)
+
+ val model = ridgeReg.train(testRDD)
+
+ assert(model.intercept >= 2.9 && model.intercept <= 3.1)
+ assert(model.weights.length === 2)
+ assert(model.weights.get(0) >= 0.9 && model.weights.get(0) <= 1.1)
+ assert(model.weights.get(1) >= 0.9 && model.weights.get(1) <= 1.1)
+ }
+}
diff --git a/pom.xml b/pom.xml
index 3bcb2a3f34..48e623fa1c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -56,7 +56,7 @@
<akka.version>2.0.3</akka.version>
<spray.version>1.0-M2.1</spray.version>
<spray.json.version>1.1.1</spray.json.version>
- <slf4j.version>1.6.1</slf4j.version>
+ <slf4j.version>1.7.2</slf4j.version>
<cdh.version>4.1.2</cdh.version>
<log4j.version>1.2.17</log4j.version>
@@ -109,17 +109,6 @@
<enabled>false</enabled>
</snapshots>
</repository>
- <repository>
- <id>twitter4j-repo</id>
- <name>Twitter4J Repository</name>
- <url>http://twitter4j.org/maven2/</url>
- <releases>
- <enabled>true</enabled>
- </releases>
- <snapshots>
- <enabled>false</enabled>
- </snapshots>
- </repository>
</repositories>
<pluginRepositories>
<pluginRepository>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b1f3f9a2ea..c487f34d4a 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -25,7 +25,7 @@ object SparkBuild extends Build {
//val HADOOP_MAJOR_VERSION = "2"
//val HADOOP_YARN = true
- lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
+ lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming, mllib)
lazy val core = Project("core", file("core"), settings = coreSettings)
@@ -37,6 +37,8 @@ object SparkBuild extends Build {
lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core)
+ lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn (core)
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -54,7 +56,7 @@ object SparkBuild extends Build {
// Fork new JVMs for tests and set Java options for those
fork := true,
- javaOptions += "-Xmx2g",
+ javaOptions += "-Xmx2500m",
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@@ -125,20 +127,20 @@ object SparkBuild extends Build {
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
- )
+ ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings
- val slf4jVersion = "1.6.1"
+ val slf4jVersion = "1.7.2"
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
+ val excludeAsm = ExclusionRule(organization = "asm")
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
resolvers ++= Seq(
"JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
"Spray Repository" at "http://repo.spray.cc/",
- "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
- "Twitter4J Repository" at "http://twitter4j.org/maven2/"
+ "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/"
),
libraryDependencies ++= Seq(
@@ -201,11 +203,10 @@ object SparkBuild extends Build {
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
- resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"),
libraryDependencies ++= Seq(
"com.twitter" % "algebird-core_2.9.2" % "0.1.11",
- "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty),
+ "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
"org.apache.cassandra" % "cassandra-all" % "1.2.5"
exclude("com.google.guava", "guava")
@@ -220,11 +221,21 @@ object SparkBuild extends Build {
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
+ def mllibSettings = sharedSettings ++ Seq(
+ name := "spark-mllib",
+ libraryDependencies ++= Seq(
+ "org.jblas" % "jblas" % "1.2.3"
+ )
+ )
+
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
+ resolvers ++= Seq(
+ "Akka Repository" at "http://repo.akka.io/releases/"
+ ),
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
- "com.github.sgroschupf" % "zkclient" % "0.1",
+ "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty)
)
@@ -233,7 +244,7 @@ object SparkBuild extends Build {
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
- case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard
+ case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}
diff --git a/project/plugins.sbt b/project/plugins.sbt
index d4f2442872..f806e66481 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
+
+addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3")
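With this plugin on the build classpath and graphSettings mixed into sharedSettings above, dependency reports become available from the sbt prompt — for example `sbt dependency-tree` (task name as of sbt-dependency-graph 0.7.3) prints the resolved dependency tree per project, which helps verify that the Netty and asm exclusions introduced in this change actually take effect.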
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
new file mode 100644
index 0000000000..78c9457b84
--- /dev/null
+++ b/python/pyspark/daemon.py
@@ -0,0 +1,164 @@
+import os
+import signal
+import socket
+import sys
+import traceback
+import multiprocessing
+from ctypes import c_bool
+from errno import EINTR, ECHILD
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from pyspark.worker import main as worker_main
+from pyspark.serializers import write_int
+
+try:
+ POOLSIZE = multiprocessing.cpu_count()
+except NotImplementedError:
+ POOLSIZE = 4
+
+exit_flag = multiprocessing.Value(c_bool, False)
+
+
+def should_exit():
+ global exit_flag
+ return exit_flag.value
+
+
+def compute_real_exit_code(exit_code):
+ # SystemExit's code can be integer or string, but os._exit only accepts integers
+ import numbers
+ if isinstance(exit_code, numbers.Integral):
+ return exit_code
+ else:
+ return 1
+
+
+def worker(listen_sock):
+ # Redirect stdout to stderr
+ os.dup2(2, 1)
+ sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
+
+ # Manager sends SIGHUP to request termination of workers in the pool
+ def handle_sighup(*args):
+ assert should_exit()
+ signal.signal(SIGHUP, handle_sighup)
+
+ # Cleanup zombie children
+ def handle_sigchld(*args):
+ pid = status = None
+ try:
+ while (pid, status) != (0, 0):
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except EnvironmentError as err:
+ if err.errno == EINTR:
+ # retry
+ handle_sigchld()
+ elif err.errno != ECHILD:
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Handle clients
+ while not should_exit():
+ # Wait until a client arrives or we have to exit
+ sock = None
+ while not should_exit() and sock is None:
+ try:
+ sock, addr = listen_sock.accept()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ raise
+
+ if sock is not None:
+ # Fork a child to handle the client.
+ # The client is handled in the child so that the manager
+ # never receives SIGCHLD unless a worker crashes.
+ if os.fork() == 0:
+ # Leave the worker pool
+ signal.signal(SIGHUP, SIG_DFL)
+ listen_sock.close()
+ # Read the socket using fdopen instead of socket.makefile() because the latter
+ # seems to be very slow; note that we need to dup() the file descriptor because
+ # otherwise writes also cause a seek that makes us miss data on the read side.
+ infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ exit_code = 0
+ try:
+ worker_main(infile, outfile)
+ except SystemExit as exc:
+ exit_code = exc.code
+ finally:
+ outfile.flush()
+ sock.close()
+ os._exit(compute_real_exit_code(exit_code))
+ else:
+ sock.close()
+
+
+def launch_worker(listen_sock):
+ if os.fork() == 0:
+ try:
+ worker(listen_sock)
+ except Exception as err:
+ traceback.print_exc()
+ os._exit(1)
+ else:
+ assert should_exit()
+ os._exit(0)
+
+
+def manager():
+ # Create a new process group to corral our children
+ os.setpgid(0, 0)
+
+ # Create a listening socket on the AF_INET loopback interface
+ listen_sock = socket.socket(AF_INET, SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
+ listen_host, listen_port = listen_sock.getsockname()
+ write_int(listen_port, sys.stdout)
+
+ # Launch initial worker pool
+ for idx in range(POOLSIZE):
+ launch_worker(listen_sock)
+ listen_sock.close()
+
+ def shutdown():
+ global exit_flag
+ exit_flag.value = True
+
+ # Gracefully exit on SIGTERM, don't die on SIGHUP
+ signal.signal(SIGTERM, lambda signum, frame: shutdown())
+ signal.signal(SIGHUP, SIG_IGN)
+
+ # Cleanup zombie children
+ def handle_sigchld(*args):
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ if status != 0 and not should_exit():
+ raise RuntimeError("worker crashed: %s, %s" % (pid, status))
+ except EnvironmentError as err:
+ if err.errno not in (ECHILD, EINTR):
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Initialization complete
+ sys.stdout.close()
+ try:
+ while not should_exit():
+ try:
+ # Spark tells us to exit by closing stdin
+ if os.read(0, 512) == '':
+ shutdown()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ shutdown()
+ raise
+ finally:
+ signal.signal(SIGTERM, SIG_DFL)
+ exit_flag.value = True
+ # Send SIGHUP to notify workers of shutdown
+ os.kill(0, SIGHUP)
+
+
+if __name__ == '__main__':
+ manager()
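For context, a hedged Scala sketch of the handshake a JVM-side caller performs with this daemon (the real logic lives elsewhere in Spark; the python binary and module invocation here are assumptions): the daemon frames its ephemeral port on stdout as a 4-byte big-endian int via write_int, one socket is opened per worker, and closing the daemon's stdin requests shutdown.

    import java.io.DataInputStream
    import java.net.{InetAddress, Socket}

    object DaemonHandshakeSketch {
      def main(args: Array[String]) {
        // Assumes pyspark is importable; daemon.py runs manager() when executed
        val daemon = new ProcessBuilder("python", "-m", "pyspark.daemon").start()
        // write_int(listen_port, sys.stdout) framed the port as 4 big-endian bytes
        val port = new DataInputStream(daemon.getInputStream).readInt()
        // Each accepted connection is served by a forked Python worker
        val worker = new Socket(InetAddress.getByName("127.0.0.1"), port)
        worker.close()
        // The daemon treats EOF on stdin as a shutdown request
        daemon.getOutputStream.close()
      }
    }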
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 115cf28cc2..5a95144983 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -46,6 +46,10 @@ def read_long(stream):
return struct.unpack("!q", length)[0]
+def write_long(value, stream):
+ stream.write(struct.pack("!q", value))
+
+
def read_int(stream):
length = stream.read(4)
if length == "":
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+ def connect(self, port):
+ from socket import socket, AF_INET, SOCK_STREAM
+ sock = socket(AF_INET, SOCK_STREAM)
+ sock.connect(('127.0.0.1', port))
+ # send a split index of -1 to shut down the worker
+ sock.send("\xFF\xFF\xFF\xFF")
+ sock.close()
+ return True
+
+ def do_termination_test(self, terminator):
+ from subprocess import Popen, PIPE
+ from errno import ECONNREFUSED
+
+ # start daemon
+ daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+ daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+ # read the port number
+ port = read_int(daemon.stdout)
+
+ # daemon should accept connections
+ self.assertTrue(self.connect(port))
+
+ # request shutdown
+ terminator(daemon)
+ time.sleep(1)
+
+ # daemon should no longer accept connections
+ with self.assertRaises(EnvironmentError) as trap:
+ self.connect(port)
+ self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+ def test_termination_stdin(self):
+ """Ensure that daemon and workers terminate when stdin is closed."""
+ self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+ def test_termination_sigterm(self):
+ """Ensure that daemon and workers terminate on SIGTERM."""
+ from signal import SIGTERM
+ self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 812e7a9da5..379bbfd4c2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -3,6 +3,7 @@ Worker that receives input from Piped RDD.
"""
import os
import sys
+import time
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+ read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
-# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = os.fdopen(os.dup(1), 'w')
-os.dup2(2, 1)
+def load_obj(infile):
+ return load_pickle(standard_b64decode(infile.readline().strip()))
-def load_obj():
- return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+def report_times(outfile, boot, init, finish):
+ write_int(-3, outfile)
+ write_long(1000 * boot, outfile)
+ write_long(1000 * init, outfile)
+ write_long(1000 * finish, outfile)
-def main():
- split_index = read_int(sys.stdin)
- spark_files_dir = load_pickle(read_with_length(sys.stdin))
+def main(infile, outfile):
+ boot_time = time.time()
+ split_index = read_int(infile)
+ if split_index == -1: # for unit tests
+ return
+ spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
- num_broadcast_variables = read_int(sys.stdin)
+ num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
- bid = read_long(sys.stdin)
- value = read_with_length(sys.stdin)
+ bid = read_long(infile)
+ value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
- func = load_obj()
- bypassSerializer = load_obj()
+ func = load_obj(infile)
+ bypassSerializer = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
- iterator = read_from_pickle_file(sys.stdin)
+ init_time = time.time()
+ iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ write_with_length(dumps(obj), outfile)
except Exception as e:
- write_int(-2, old_stdout)
- write_with_length(traceback.format_exc(), old_stdout)
+ write_int(-2, outfile)
+ write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
+ finish_time = time.time()
+ report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, old_stdout)
+ write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), old_stdout)
+ write_with_length(dump_pickle((aid, accum._value)), outfile)
+ write_int(-1, outfile)
if __name__ == '__main__':
- main()
+ # Redirect stdout to stderr so that users must return values from functions.
+ old_stdout = os.fdopen(os.dup(1), 'w')
+ os.dup2(2, 1)
+ main(sys.stdin, old_stdout)
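The framing above gives the JVM side three sentinel values to dispatch on while draining a worker's output: -3 precedes three big-endian longs (boot, init and finish epoch timestamps in milliseconds, from report_times), -2 precedes a traceback string on failure, and -1 opens the accumulator section, which ends with a final -1. A hedged Scala sketch of that dispatch (the actual reader lives in PythonRDD):

    import java.io.DataInputStream

    object WorkerTrailerSketch {
      // Returns true once the accumulator-section marker has been consumed
      def handleSpecialFrame(in: DataInputStream): Boolean = in.readInt() match {
        case -3 =>
          // report_times wrote 1000 * time.time() for boot, init and finish
          val (boot, init, finish) = (in.readLong(), in.readLong(), in.readLong())
          println("worker times (ms): boot=" + boot + " init=" + init + " finish=" + finish)
          false
        case -2 => false // a framed traceback string follows; the task failed
        case -1 => true  // accumulator updates follow, terminated by another -1
        case length => false // an ordinary result frame of `length` bytes
      }
    }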
diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala
index 23556dbc8f..59f9d05683 100644
--- a/repl/src/main/scala/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/spark/repl/SparkILoop.scala
@@ -822,7 +822,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
spark.repl.Main.interp.out.println("Spark context available as sc.");
spark.repl.Main.interp.out.flush();
""")
- command("import spark.SparkContext._");
+ command("import spark.SparkContext._")
}
echo("Type in expressions to have them evaluated.")
echo("Type :help for more information.")
@@ -838,7 +838,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
if (prop != null) prop else "local"
}
}
- sparkContext = new SparkContext(master, "Spark shell")
+ val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
+ .getOrElse(new Array[String](0))
+ .map(new java.io.File(_).getAbsolutePath)
+ sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
sparkContext
}
@@ -850,6 +853,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
printWelcome()
echo("Initializing interpreter...")
+ // Add JARS specified in Spark's ADD_JARS variable to classpath
+ val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0))
+ jars.foreach(settings.classpath.append(_))
+
this.settings = settings
createInterpreter()
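Usage sketch for the change above, assuming the usual shell wrapper script: ADD_JARS=/path/a.jar,/path/b.jar ./spark-shell now both appends the JARs to the interpreter classpath and, through the new SparkContext constructor call, ships them to executors.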
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
index 1c64f9b98d..f46e6d8be4 100644
--- a/repl/src/test/scala/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/spark/repl/ReplSuite.scala
@@ -28,24 +28,25 @@ class ReplSuite extends FunSuite {
val separator = System.getProperty("path.separator")
interp.process(Array("-classpath", paths.mkString(separator)))
spark.repl.Main.interp = null
- if (interp.sparkContext != null)
+ if (interp.sparkContext != null) {
interp.sparkContext.stop()
+ }
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
return out.toString
}
-
+
def assertContains(message: String, output: String) {
- assert(output contains message,
+ assert(output.contains(message),
"Interpreter output did not contain '" + message + "':\n" + output)
}
-
+
def assertDoesNotContain(message: String, output: String) {
- assert(!(output contains message),
+ assert(!output.contains(message),
"Interpreter output contained '" + message + "':\n" + output)
}
-
+
test ("simple foreach with accumulator") {
val output = runInterpreter("local", """
val accum = sc.accumulator(0)
@@ -56,7 +57,7 @@ class ReplSuite extends FunSuite {
assertDoesNotContain("Exception", output)
assertContains("res1: Int = 55", output)
}
-
+
test ("external vars") {
val output = runInterpreter("local", """
var v = 7
@@ -105,7 +106,7 @@ class ReplSuite extends FunSuite {
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
-
+
test ("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
@@ -143,6 +144,27 @@ class ReplSuite extends FunSuite {
assertContains("res2: Long = 3", output)
}
+ test ("local-cluster mode") {
+ val output = runInterpreter("local-cluster[1,1,512]", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ var array = new Array[Int](5)
+ val broadcastArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
test ("running on Mesos") {
val output = runInterpreter("localquiet", """
diff --git a/run b/run
index c0065c53f1..805466ea2c 100755
--- a/run
+++ b/run
@@ -23,29 +23,38 @@ fi
if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
- SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
+ # Do not overwrite SPARK_JAVA_OPTS environment variable in this script
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
+else
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
fi
# Add java opts for master, worker, executor. The opts may be null
case "$1" in
'spark.deploy.master.Master')
- SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_MASTER_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
;;
'spark.deploy.worker.Worker')
- SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_WORKER_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
;;
'spark.executor.StandaloneExecutorBackend')
- SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.executor.MesosExecutorBackend')
- SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.repl.Main')
- SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
;;
esac
+# Figure out whether to run our class with java or with the scala launcher.
+# In most cases, we'd prefer to execute our process with java because scala
+# creates a shell script as the parent of its Java process, which makes it
+# hard to kill the child with stuff like Process.destroy(). However, for
+# the Spark shell, the wrapper is necessary to properly reset the terminal
+# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
if [ "$SCALA_HOME" ]; then
RUNNER="${SCALA_HOME}/bin/scala"
@@ -58,14 +67,15 @@ if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
fi
fi
else
- if [ `command -v java` ]; then
- RUNNER="java"
+ if [ -n "${JAVA_HOME}" ]; then
+ RUNNER="${JAVA_HOME}/bin/java"
else
- if [ -z "$JAVA_HOME" ]; then
+ if [ `command -v java` ]; then
+ RUNNER="java"
+ else
echo "JAVA_HOME is not set" >&2
exit 1
fi
- RUNNER="${JAVA_HOME}/bin/java"
fi
if [ -z "$SCALA_LIBRARY_PATH" ]; then
if [ -z "$SCALA_HOME" ]; then
@@ -84,7 +94,7 @@ fi
export SPARK_MEM
# Set JAVA_OPTS to be able to load native libraries and to set heap size
-JAVA_OPTS="$SPARK_JAVA_OPTS"
+JAVA_OPTS="$OUR_JAVA_OPTS"
JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
@@ -92,14 +102,11 @@ if [ -e $FWDIR/conf/java-opts ] ; then
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
fi
export JAVA_OPTS
+# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
CORE_DIR="$FWDIR/core"
-REPL_DIR="$FWDIR/repl"
-REPL_BIN_DIR="$FWDIR/repl-bin"
EXAMPLES_DIR="$FWDIR/examples"
-BAGEL_DIR="$FWDIR/bagel"
-STREAMING_DIR="$FWDIR/streaming"
-PYSPARK_DIR="$FWDIR/python"
+REPL_DIR="$FWDIR/repl"
# Exit if the user hasn't compiled Spark
if [ ! -e "$CORE_DIR/target" ]; then
@@ -114,33 +121,9 @@ if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then
exit 1
fi
-# Build up classpath
-CLASSPATH="$SPARK_CLASSPATH"
-CLASSPATH="$CLASSPATH:$FWDIR/conf"
-CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
-if [ -n "$SPARK_TESTING" ] ; then
- CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
-fi
-CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
-CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
-if [ -e "$FWDIR/lib_managed" ]; then
- CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
- CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
-fi
-CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
-if [ -e $REPL_BIN_DIR/target ]; then
- for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
- CLASSPATH="$CLASSPATH:$jar"
- done
-fi
-CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
-for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
- CLASSPATH="$CLASSPATH:$jar"
-done
+# Compute classpath using external script
+CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
+export CLASSPATH
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
@@ -148,37 +131,16 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ];
# Use the JAR from the SBT build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
fi
-if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then
+if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the Maven build
- export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar`
-fi
-
-# Add hadoop conf dir - else FileSystem.*, etc fail !
-# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-# the configurtion files.
-if [ "x" != "x$HADOOP_CONF_DIR" ]; then
- CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
-fi
-if [ "x" != "x$YARN_CONF_DIR" ]; then
- CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
fi
-
-# Figure out whether to run our class with java or with the scala launcher.
-# In most cases, we'd prefer to execute our process with java because scala
-# creates a shell script as the parent of its Java process, which makes it
-# hard to kill the child with stuff like Process.destroy(). However, for
-# the Spark shell, the wrapper is necessary to properly reset the terminal
-# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS
else
- CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
- CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
- CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
# The JVM doesn't read JAVA_OPTS by default so we need to pass it in
EXTRA_ARGS="$JAVA_OPTS"
fi
-export CLASSPATH # Needed for spark-shell
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"
diff --git a/run2.cmd b/run2.cmd
index c6f43dde5b..a9c4df180f 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -23,7 +23,9 @@ if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
-if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
+rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
+if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
+if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
rem Check that SCALA_HOME has been specified
if not "x%SCALA_HOME%"=="x" goto scala_exists
@@ -31,50 +33,22 @@ if not "x%SCALA_HOME%"=="x" goto scala_exists
goto exit
:scala_exists
-rem If the user specifies a Mesos JAR, put it before our included one on the classpath
-set MESOS_CLASSPATH=
-if not "x%MESOS_JAR%"=="x" set MESOS_CLASSPATH=%MESOS_JAR%
-
rem Figure out how much memory to use per executor and set it as an environment
rem variable so that our process sees it and can report it to Mesos
if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
-set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
-rem Load extra JAVA_OPTS from conf/java-opts, if it exists
-if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd"
+set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
+rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
set CORE_DIR=%FWDIR%core
-set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
-set BAGEL_DIR=%FWDIR%bagel
-set STREAMING_DIR=%FWDIR%streaming
-set PYSPARK_DIR=%FWDIR%python
-
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
-set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
-set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
-set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
-set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
-set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
-set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
+set REPL_DIR=%FWDIR%repl
+rem Compute classpath using external script
+set DONT_PRINT_CLASSPATH=1
+call "%FWDIR%bin\compute-classpath.cmd"
+set DONT_PRINT_CLASSPATH=0
rem Figure out the JAR file that our examples were packaged into.
rem First search in the build path from SBT:
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index e1be5ef51c..9be7926a4a 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] (
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
*/
- def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
+ def count(): DStream[Long] = {
+ this.map(_ => (null, 1L))
+ .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1)))
+ .reduceByKey(_ + _)
+ .map(_._2)
+ }
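The union with a one-element (null, 0L) RDD guarantees that reduceByKey always sees at least one record, so a batch with no data now yields a count of 0 instead of an empty RDD; the updated expectation in BasicOperationsSuite further down exercises exactly this case.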
/**
* Return a new DStream in which each RDD contains the counts of each distinct value in
@@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] (
* this DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: RDD[T] => Unit) {
- foreach((r: RDD[T], t: Time) => foreachFunc(r))
+ this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
}
/**
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index b8b60aab43..36b841af8f 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
+import twitter4j.auth.Authorization
+
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -186,10 +188,11 @@ class StreamingContext private (
* should be same.
*/
def actorStream[T: ClassManifest](
- props: Props,
- name: String,
- storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
- supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = {
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
+ supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
+ ): DStream[T] = {
networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
}
@@ -197,9 +200,10 @@ class StreamingContext private (
* Create an input stream that receives messages pushed by a zeromq publisher.
* @param publisherUrl Url of remote zeromq publisher
* @param subscribe topic to subscribe to
- * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
- * of byte thus it needs the converter(which might be deserializer of bytes)
- * to translate from sequence of sequence of bytes, where sequence refer to a frame
+ * @param bytesToObjects A ZeroMQ stream publishes a sequence of frames for each topic,
+ * and each frame is a sequence of bytes, so it needs the converter
+ * (which might be a deserializer of bytes) to translate from a
+ * sequence of byte sequences, where the outer sequence is a frame
 * and each inner sequence is its payload.
* @param storageLevel RDD storage level. Defaults to memory-only.
*/
@@ -215,24 +219,39 @@ class StreamingContext private (
}
/**
- * Create an input stream that pulls messages form a Kafka Broker.
+ * Create an input stream that pulls messages from a Kafka Broker.
 * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
- * in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
+ * in its own thread.
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
- def kafkaStream[T: ClassManifest](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: Map[String, Int],
- initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[String] = {
+ val kafkaParams = Map[String, String](
+ "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
+ kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+ * @param kafkaParams Map of Kafka configuration parameters.
+ * See: http://kafka.apache.org/configuration.html
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
): DStream[T] = {
- val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel)
+ val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
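Usage sketch of the two overloads above, mirroring the new InputStreamsSuite test further down (assumes an in-scope StreamingContext ssc and spark.storage.StorageLevel; host:port and topic names are placeholders):

    val topics = Map("my-topic" -> 1)
    // Simple form: a String stream via StringDecoder, given the ZK quorum and group id
    val lines = ssc.kafkaStream("localhost:2181", "my-group", topics)
    // General form: explicit Kafka parameters and decoder type
    val kafkaParams = Map("zk.connect" -> "localhost:2181", "groupid" -> "my-group")
    val typed = ssc.kafkaStream[String, kafka.serializer.StringDecoder](
      kafkaParams, topics, StorageLevel.MEMORY_ONLY_SER_2)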
@@ -362,18 +381,18 @@ class StreamingContext private (
/**
* Create a input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth
+ * authorization; this uses the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret.
* @param filters Set of filter strings to get only those tweets that match them
* @param storageLevel Storage level to use for storing the received objects
*/
def twitterStream(
- username: String,
- password: String,
+ twitterAuth: Option[Authorization] = None,
filters: Seq[String] = Nil,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): DStream[Status] = {
- val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel)
+ val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel)
registerInputStream(inputStream)
inputStream
}
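Usage sketch: with no Authorization argument, the stream falls back to Twitter4J's OAuth system properties (the values below are placeholders):

    System.setProperty("twitter4j.oauth.consumerKey", "...")
    System.setProperty("twitter4j.oauth.consumerSecret", "...")
    System.setProperty("twitter4j.oauth.accessToken", "...")
    System.setProperty("twitter4j.oauth.accessTokenSecret", "...")
    val tweets = ssc.twitterStream(None, Seq("spark"))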
@@ -397,7 +416,8 @@ class StreamingContext private (
* it will process either one or all of the RDDs returned by the queue.
* @param queue Queue of RDDs
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
- * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty
+ * @param defaultRDD Default RDD returned by the DStream when the queue is empty.
+ * Set to null if no RDD should be returned when empty
* @tparam T Type of objects in the RDD
*/
def queueStream[T: ClassManifest](
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
index 3d149a742c..ed7b789d98 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -4,23 +4,18 @@ import spark.streaming._
import receivers.{ActorReceiver, ReceiverSupervisorStrategy}
import spark.streaming.dstream._
import spark.storage.StorageLevel
-
import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
import spark.api.java.{JavaSparkContext, JavaRDD}
-
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-
import twitter4j.Status
-
import akka.actor.Props
import akka.actor.SupervisorStrategy
import akka.zeromq.Subscribe
-
import scala.collection.JavaConversions._
-
import java.lang.{Long => JLong, Integer => JInt}
import java.io.InputStream
import java.util.{Map => JMap}
+import twitter4j.auth.Authorization
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -121,14 +116,15 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
*/
- def kafkaStream[T](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
- : JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
+ : JavaDStream[String] = {
+ implicit val cmt: ClassManifest[String] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+ ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
+ StorageLevel.MEMORY_ONLY_SER_2)
}
/**
@@ -136,49 +132,45 @@ class JavaStreamingContext(val ssc: StreamingContext) {
 * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
- * in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
+ * in its own thread.
+ * @param storageLevel RDD storage level. Defaults to memory-only
+ *
*/
- def kafkaStream[T](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt],
- initialOffsets: JMap[KafkaPartitionKey, JLong])
- : JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- ssc.kafkaStream[T](
- zkQuorum,
- groupId,
- Map(topics.mapValues(_.intValue()).toSeq: _*),
- Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
+ storageLevel: StorageLevel)
+ : JavaDStream[String] = {
+ implicit val cmt: ClassManifest[String] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+ ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
+ storageLevel)
}
/**
 * Create an input stream that pulls messages from a Kafka Broker.
- * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
- * @param groupId The group id for this consumer.
+ * @param typeClass Type of the RDD elements
+ * @param decoderClass Type of kafka decoder
+ * @param kafkaParams Map of Kafka configuration parameters.
+ * See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
- def kafkaStream[T](
- zkQuorum: String,
- groupId: String,
+ def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
+ typeClass: Class[T],
+ decoderClass: Class[D],
+ kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
- initialOffsets: JMap[KafkaPartitionKey, JLong],
storageLevel: StorageLevel)
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- ssc.kafkaStream[T](
- zkQuorum,
- groupId,
+ implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
+ ssc.kafkaStream[T, D](
+ kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
- Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
storageLevel)
}
@@ -315,47 +307,76 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create a input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization object
+ * @param filters Set of filter strings to get only those tweets that match them
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def twitterStream(
+ twitterAuth: Authorization,
+ filters: Array[String],
+ storageLevel: StorageLevel
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(Some(twitterAuth), filters, storageLevel)
+ }
+
+ /**
+ * Create a input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
* @param filters Set of filter strings to get only those tweets that match them
* @param storageLevel Storage level to use for storing the received objects
*/
def twitterStream(
- username: String,
- password: String,
filters: Array[String],
storageLevel: StorageLevel
): JavaDStream[Status] = {
- ssc.twitterStream(username, password, filters, storageLevel)
+ ssc.twitterStream(None, filters, storageLevel)
}
/**
* Create a input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization
* @param filters Set of filter strings to get only those tweets that match them
*/
def twitterStream(
- username: String,
- password: String,
+ twitterAuth: Authorization,
filters: Array[String]
): JavaDStream[Status] = {
- ssc.twitterStream(username, password, filters)
+ ssc.twitterStream(Some(twitterAuth), filters)
+ }
+
+ /**
+ * Create a input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
+ * @param filters Set of filter strings to get only those tweets that match them
+ */
+ def twitterStream(
+ filters: Array[String]
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(None, filters)
}
/**
* Create a input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization
*/
def twitterStream(
- username: String,
- password: String
+ twitterAuth: Authorization
): JavaDStream[Status] = {
- ssc.twitterStream(username, password)
+ ssc.twitterStream(Some(twitterAuth))
}
/**
+ * Create a input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
+ */
+ def twitterStream(): JavaDStream[Status] = {
+ ssc.twitterStream()
+ }
+
+ /**
* Create an input stream with any arbitrary user implemented actor receiver.
* @param props Props object defining creation of the actor
* @param name Name of the actor
diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
index ddd9becf32..55d2957be4 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -9,58 +9,51 @@ import java.util.concurrent.Executors
import kafka.consumer._
import kafka.message.{Message, MessageSet, MessageAndMetadata}
-import kafka.serializer.StringDecoder
+import kafka.serializer.Decoder
import kafka.utils.{Utils, ZKGroupTopicDirs}
import kafka.utils.ZkUtils._
+import kafka.utils.ZKStringSerializer
+import org.I0Itec.zkclient._
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
-// Key for a specific Kafka Partition: (broker, topic, group, part)
-case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
-
/**
* Input stream that pulls messages from a Kafka Broker.
*
- * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
- * @param groupId The group id for this consumer.
+ * @param kafkaParams Map of Kafka configuration parameters. See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
* @param storageLevel RDD storage level.
*/
private[streaming]
-class KafkaInputDStream[T: ClassManifest](
+class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
- zkQuorum: String,
- groupId: String,
+ kafkaParams: Map[String, String],
topics: Map[String, Int],
- initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
def getReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel)
+ new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
-class KafkaReceiver(zkQuorum: String, groupId: String,
- topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
- storageLevel: StorageLevel) extends NetworkReceiver[Any] {
-
- // Timeout for establishing a connection to Zookeper in ms.
- val ZK_TIMEOUT = 10000
+class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
// Connection to Kafka
- var consumerConnector : ZookeeperConsumerConnector = null
+ var consumerConnector : ConsumerConnector = null
def onStop() {
blockGenerator.stop()
@@ -73,54 +66,59 @@ class KafkaReceiver(zkQuorum: String, groupId: String,
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- logInfo("Starting Kafka Consumer Stream with group: " + groupId)
- logInfo("Initial offsets: " + initialOffsets.toString)
+ logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
- // Zookeper connection properties
+ // Kafka connection properties
val props = new Properties()
- props.put("zk.connect", zkQuorum)
- props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
- props.put("groupid", groupId)
+ kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + zkQuorum)
+ logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
val consumerConfig = new ConsumerConfig(props)
- consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
- logInfo("Connected to " + zkQuorum)
+ consumerConnector = Consumer.create(consumerConfig)
+ logInfo("Connected to " + kafkaParams("zk.connect"))
- // If specified, set the topic offset
- setOffsets(initialOffsets)
+ // When autooffset.reset is defined, it is our responsibility to try and whack the
+ // consumer group zk node.
+ if (kafkaParams.contains("autooffset.reset")) {
+ tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
+ }
// Create Threads for each Topic/Message Stream we are listening
- val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+ val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
+ val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
}
-
- }
-
- // Overwrites the offets in Zookeper.
- private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) {
- offsets.foreach { case(key, offset) =>
- val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
- val partitionName = key.brokerId + "-" + key.partId
- updatePersistentPath(consumerConnector.zkClient,
- topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
- }
}
// Handles Kafka Messages
- private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+ private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
- stream.takeWhile { msgAndMetadata =>
+ for (msgAndMetadata <- stream) {
blockGenerator += msgAndMetadata.message
- // Keep on handling messages
-
- true
}
}
}
+
+ // It is our responsibility to delete the consumer group when specifying autooffset.reset. This
+ // is because Kafka 0.7.2 only honors this param when the group is not in ZooKeeper.
+ //
+ // The Kafka high-level consumer doesn't currently expose setting offsets, so this is a trick
+ // copied from Kafka's ConsoleConsumer. See the code related to 'autooffset.reset' when it is
+ // set to 'smallest'/'largest':
+ // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
+ private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
+ try {
+ val dir = "/consumers/" + groupId
+ logInfo("Cleaning up temporary zookeeper data under " + dir + ".")
+ val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
+ zk.deleteRecursive(dir)
+ zk.close()
+ } catch {
+ case _ => // swallow
+ }
+ }
}
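A hedged usage sketch of the cleanup path above: supplying autooffset.reset through kafkaParams makes the receiver delete the group's ZooKeeper node on start, so Kafka 0.7.2 honors the setting and the group re-reads from the requested offset (ssc, host:port and topic names are placeholders):

    val kafkaParams = Map(
      "zk.connect" -> "localhost:2181",
      "groupid" -> "my-group",
      "autooffset.reset" -> "smallest")
    val stream = ssc.kafkaStream[String, kafka.serializer.StringDecoder](
      kafkaParams, Map("my-topic" -> 1), StorageLevel.MEMORY_ONLY_SER_2)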
diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
index 7385474963..122a529bb7 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -140,12 +140,10 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/**
- * Pushes a block (as iterator of values) into the block manager.
+ * Pushes a block (as an ArrayBuffer filled with data) into the block manager.
*/
- def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) {
- val buffer = new ArrayBuffer[T] ++ iterator
- env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
-
+ def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
+ env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
actor ! ReportBlock(blockId, metadata)
}
@@ -195,10 +193,10 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
class BlockGenerator(storageLevel: StorageLevel)
extends Serializable with Logging {
- case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+ case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null)
val clock = new SystemClock()
- val blockInterval = 200L
+ val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
val blockStorageLevel = storageLevel
val blocksForPushing = new ArrayBlockingQueue[Block](1000)
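The block interval is now tunable: setting System.setProperty("spark.streaming.blockInterval", "100") (or passing -Dspark.streaming.blockInterval=100) before the receiver starts makes the generator cut a block every 100 ms instead of the 200 ms default, trading per-block overhead against latency and the number of partitions per batch.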
@@ -222,17 +220,13 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
currentBuffer += obj
}
- private def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
- new Block(blockId, iterator)
- }
-
private def updateCurrentBuffer(time: Long) {
try {
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
if (newBlockBuffer.size > 0) {
val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval)
- val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+ val newBlock = new Block(blockId, newBlockBuffer)
blocksForPushing.add(newBlock)
}
} catch {
@@ -248,7 +242,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
try {
while(true) {
val block = blocksForPushing.take()
- NetworkReceiver.this.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+ NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
}
} catch {
case ie: InterruptedException =>
diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
index c697498862..ff7a58be45 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
@@ -3,34 +3,45 @@ package spark.streaming.dstream
import spark._
import spark.streaming._
import storage.StorageLevel
-
import twitter4j._
-import twitter4j.auth.BasicAuthorization
+import twitter4j.auth.Authorization
+import java.util.prefs.Preferences
+import twitter4j.conf.ConfigurationBuilder
+import twitter4j.conf.PropertyConfiguration
+import twitter4j.auth.OAuthAuthorization
+import twitter4j.auth.AccessToken
/* A stream of Twitter statuses, potentially filtered by one or more keywords.
*
-* @constructor create a new Twitter stream using the supplied username and password to authenticate.
+* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials.
* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is
* such that this may return a sampled subset of all tweets during each interval.
+*
+* If no Authorization object is provided, initializes OAuth authorization using the system
+* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret.
*/
private[streaming]
class TwitterInputDStream(
@transient ssc_ : StreamingContext,
- username: String,
- password: String,
+ twitterAuth: Option[Authorization],
filters: Seq[String],
storageLevel: StorageLevel
) extends NetworkInputDStream[Status](ssc_) {
+
+ private def createOAuthAuthorization(): Authorization = {
+ new OAuthAuthorization(new ConfigurationBuilder().build())
+ }
+ private val authorization = twitterAuth.getOrElse(createOAuthAuthorization())
+
override def getReceiver(): NetworkReceiver[Status] = {
- new TwitterReceiver(username, password, filters, storageLevel)
+ new TwitterReceiver(authorization, filters, storageLevel)
}
}
private[streaming]
class TwitterReceiver(
- username: String,
- password: String,
+ twitterAuth: Authorization,
filters: Seq[String],
storageLevel: StorageLevel
) extends NetworkReceiver[Status] {
@@ -40,8 +51,7 @@ class TwitterReceiver(
protected override def onStart() {
blockGenerator.start()
- twitterStream = new TwitterStreamFactory()
- .getInstance(new BasicAuthorization(username, password))
+ twitterStream = new TwitterStreamFactory().getInstance(twitterAuth)
twitterStream.addListener(new StatusListener {
def onStatus(status: Status) = {
blockGenerator += status
diff --git a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala
index b3201d0b28..036c95a860 100644
--- a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala
@@ -9,6 +9,8 @@ import spark.streaming.dstream.NetworkReceiver
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable.ArrayBuffer
+
/** A helper with set of defaults for supervisor strategy **/
object ReceiverSupervisorStrategy {
@@ -136,8 +138,9 @@ private[streaming] class ActorReceiver[T: ClassManifest](
}
protected def pushBlock(iter: Iterator[T]) {
- pushBlock("block-" + streamId + "-" + System.nanoTime(),
- iter, null, storageLevel)
+ val buffer = new ArrayBuffer[T]
+ buffer ++= iter
+ pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel)
}
protected def onStart() = {
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index 3bed500f73..4cf10582a9 100644
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -4,6 +4,7 @@ import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
+import kafka.serializer.StringDecoder;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.junit.After;
import org.junit.Assert;
@@ -23,7 +24,6 @@ import spark.streaming.api.java.JavaPairDStream;
import spark.streaming.api.java.JavaStreamingContext;
import spark.streaming.JavaTestUtils;
import spark.streaming.JavaCheckpointTestUtils;
-import spark.streaming.dstream.KafkaPartitionKey;
import spark.streaming.InputStreamsSuite;
import java.io.*;
@@ -1203,10 +1203,14 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
- HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
- JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
- JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
+ JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+ StorageLevel.MEMORY_AND_DISK());
+
+ HashMap<String, String> kafkaParams = Maps.newHashMap();
+ kafkaParams.put("zk.connect","localhost:12345");
+ kafkaParams.put("groupid","consumer-group");
+ JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
StorageLevel.MEMORY_AND_DISK());
}
@@ -1263,7 +1267,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void testTwitterStream() {
String[] filters = new String[] { "good", "bad", "ugly" };
- JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY());
+ JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY());
}
@Test
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index e7352deb81..565089a853 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -93,9 +93,9 @@ class BasicOperationsSuite extends TestSuiteBase {
test("count") {
testOperation(
- Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
+ Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4),
(s: DStream[Int]) => s.count(),
- Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
+ Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L))
)
}
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index 0acb6db6f2..b024fc9dcc 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -243,6 +243,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
assert(output(i) === expectedOutput(i))
}
}
+
+ test("kafka input stream") {
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val topics = Map("my-topic" -> 1)
+ val test1 = ssc.kafkaStream("localhost:12345", "group", topics)
+ val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
+
+ // Test specifying decoder
+ val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
+ val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ }
}